diff --git a/.circleci/config.yml b/.circleci/config.yml
index bcbacbe3bd..bde4a90784 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1,5 +1,5 @@
# Setup in CircleCI account the following ENV variables:
-# PACKAGECLOUD_ORGANIZATION (default: stackstorm)
+# PACKAGECLOUD_ORGANIZATION (default: stackstorm)
# PACKAGECLOUD_TOKEN
version: 2
jobs:
diff --git a/.git-blame-ignore-rev b/.git-blame-ignore-rev
new file mode 100644
index 0000000000..2e9f4011b2
--- /dev/null
+++ b/.git-blame-ignore-rev
@@ -0,0 +1,5 @@
+# Code was auto formatted using black
+8496bb2407b969f0937431992172b98b545f6756
+
+# Files were auto formatted to remove trailing whitespace
+969793f1fdbdd2c228e59ab112189166530d2680
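For git blame to actually skip the two bulk-reformatting commits listed above, each clone has
to be pointed at this file. A minimal sketch of the expected usage (blame.ignoreRevsFile and
--ignore-revs-file are standard git >= 2.23 features; the file path is the one added here, and
the blamed path below is only illustrative):

    # one-time, per clone: always consult the ignore file
    git config blame.ignoreRevsFile .git-blame-ignore-rev

    # or ad hoc, for a single invocation
    git blame --ignore-revs-file .git-blame-ignore-rev st2common/st2common/util/date.py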
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000..c539e0b854
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,42 @@
+# pre-commit hook which runs all the various lint checks + black auto formatting on the added
+# files.
+# This hook relies on the development virtual environment being present in virtualenv/.
+default_language_version:
+ python: python3.6
+
+exclude: '(build|dist)'
+repos:
+ - repo: local
+ hooks:
+ - id: black
+ name: black
+ entry: ./virtualenv/bin/python -m black --config pyproject.toml
+ language: script
+ types: [file, python]
+ - repo: local
+ hooks:
+ - id: flake8
+ name: flake8
+ entry: ./virtualenv/bin/python -m flake8 --config lint-configs/python/.flake8
+ language: script
+ types: [file, python]
+ - repo: local
+ hooks:
+ - id: pylint
+ name: pylint
+ entry: ./virtualenv/bin/python -m pylint -E --rcfile=./lint-configs/python/.pylintrc
+ language: script
+ types: [file, python]
+ - repo: local
+ hooks:
+ - id: bandit
+ name: bandit
+ entry: ./virtualenv/bin/python -m bandit -lll -x build,dist
+ language: script
+ types: [file, python]
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v2.5.0
+ hooks:
+ - id: trailing-whitespace
+ exclude: (^conf/|^st2common/st2common/openapi.yaml|^st2client/tests/fixtures|^st2tests/st2tests/fixtures)
+ - id: check-yaml
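The black/flake8/pylint/bandit hooks above are declared as repo-local scripts, so they run out
of the project's own virtualenv/ and match what the Makefile targets execute. A minimal sketch
of enabling and exercising the hooks, assuming pre-commit itself is installed into that
virtualenv (standard pre-commit CLI):

    # install the git commit hook so the checks run on every commit
    ./virtualenv/bin/pre-commit install

    # run every configured hook against the whole tree once
    ./virtualenv/bin/pre-commit run --all-files

    # run a single hook by id, e.g. only black
    ./virtualenv/bin/pre-commit run black --all-files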
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 11ecb52105..0fa0a562b8 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -4,6 +4,13 @@ Changelog
in development
--------------
+Changed
+~~~~~~~
+
+* All the code has been reformatted using black, and the black code style is now automatically
+ enforced and required for all new code. (#5156)
+
+ Contributed by @Kami.
3.4.0 - March 02, 2021
----------------------
@@ -22,7 +29,8 @@ Added
* Added st2-auth-ldap pip requirements for LDAP auth integration. (new feature) #5082
Contributed by @hnanchahal
-* Added --register-recreate-virtualenvs flag to st2ctl reload to recreate virtualenvs from scratch. (part of upgrade instructions) [#5167]
+* Added --register-recreate-virtualenvs flag to st2ctl reload to recreate virtualenvs from scratch.
+ (part of upgrade instructions) [#5167]
Contributed by @winem and @blag
Changed
@@ -34,7 +42,7 @@ Changed
* Improve the st2-self-check script to echo to stderr and exit if it isn't run with an
ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068
-* Added timeout parameter for packs.install action to help with long running installs that exceed the
+* Added timeout parameter for packs.install action to help with long running installs that exceed the
default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084
Contributed by @hnanchahal
diff --git a/Makefile b/Makefile
index 65ca6204ae..91439ae15b 100644
--- a/Makefile
+++ b/Makefile
@@ -326,6 +326,73 @@ schemasgen: requirements .schemasgen
. $(VIRTUALENV_DIR)/bin/activate; pylint -j $(PYLINT_CONCURRENCY) -E --rcfile=./lint-configs/python/.pylintrc --load-plugins=pylint_plugins.api_models tools/*.py || exit 1;
. $(VIRTUALENV_DIR)/bin/activate; pylint -j $(PYLINT_CONCURRENCY) -E --rcfile=./lint-configs/python/.pylintrc pylint_plugins/*.py || exit 1;
+# Black task which checks if the code conforms to the black code style
+.PHONY: black-check
+black-check: requirements .black-check
+
+.PHONY: .black-check
+.black-check:
+ @echo
+ @echo "================== black-check ===================="
+ @echo
+ # st2 components
+ @for component in $(COMPONENTS); do\
+ echo "==========================================================="; \
+ echo "Running black on" $$component; \
+ echo "==========================================================="; \
+ . $(VIRTUALENV_DIR)/bin/activate ; black --check --config pyproject.toml $$component/ || exit 1; \
+ done
+ # runner modules and packages
+ @for component in $(COMPONENTS_RUNNERS); do\
+ echo "==========================================================="; \
+ echo "Running black on" $$component; \
+ echo "==========================================================="; \
+ . $(VIRTUALENV_DIR)/bin/activate ; black --check --config pyproject.toml $$component/ || exit 1; \
+ done
+ . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml contrib/ || exit 1;
+ . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml scripts/*.py || exit 1;
+ . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml tools/*.py || exit 1;
+ . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml pylint_plugins/*.py || exit 1;
+
+# Black task which reformats the code using black
+.PHONY: black-format
+black-format: requirements .black-format
+
+.PHONY: .black-format
+.black-format:
+ @echo
+ @echo "================== black ===================="
+ @echo
+ # st2 components
+ @for component in $(COMPONENTS); do\
+ echo "==========================================================="; \
+ echo "Running black on" $$component; \
+ echo "==========================================================="; \
+ . $(VIRTUALENV_DIR)/bin/activate ; black --config pyproject.toml $$component/ || exit 1; \
+ done
+ # runner modules and packages
+ @for component in $(COMPONENTS_RUNNERS); do\
+ echo "==========================================================="; \
+ echo "Running black on" $$component; \
+ echo "==========================================================="; \
+ . $(VIRTUALENV_DIR)/bin/activate ; black --config pyproject.toml $$component/ || exit 1; \
+ done
+ . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml contrib/ || exit 1;
+ . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml scripts/*.py || exit 1;
+ . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml tools/*.py || exit 1;
+ . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml pylint_plugins/*.py || exit 1;
+
+.PHONY: pre-commit-checks
+pre-commit-checks: requirements .pre-commit-checks
+
+# Ensure all files contain no trailing whitespace + that all YAML files are valid.
+.PHONY: .pre-commit-checks
+.pre-commit-checks:
+ @echo
+ @echo "================== pre-commit-checks ===================="
+ @echo
+ . $(VIRTUALENV_DIR)/bin/activate; pre-commit run trailing-whitespace --all --show-diff-on-failure
+ . $(VIRTUALENV_DIR)/bin/activate; pre-commit run check-yaml --all --show-diff-on-failure
.PHONY: lint-api-spec
lint-api-spec: requirements .lint-api-spec
@@ -418,7 +485,7 @@ bandit: requirements .bandit
lint: requirements .lint
.PHONY: .lint
-.lint: .generate-api-spec .flake8 .pylint .st2client-dependencies-check .st2common-circular-dependencies-check .rst-check .st2client-install-check
+.lint: .generate-api-spec .black-check .pre-commit-checks .flake8 .pylint .st2client-dependencies-check .st2common-circular-dependencies-check .rst-check .st2client-install-check
.PHONY: clean
clean: .cleanpycs
@@ -979,7 +1046,7 @@ debs:
ci: ci-checks ci-unit ci-integration ci-packs-tests
.PHONY: ci-checks
-ci-checks: .generated-files-check .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages
+ci-checks: .generated-files-check .black-check .pre-commit-checks .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages
.PHONY: .rst-check
.rst-check:
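With .black-check and .pre-commit-checks wired into .lint and ci-checks above, every CI run now
fails on formatting drift; locally the intended flow is to check first and auto-format on
failure. A sketch, assuming the development virtualenv has been built so the requirements
prerequisite is satisfied:

    # fail if any file deviates from black style (same check CI runs)
    make black-check

    # rewrite the offending files in place
    make black-format

    # trailing-whitespace and YAML validity checks via pre-commit
    make pre-commit-checks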
diff --git a/OWNERS.md b/OWNERS.md
index dfb8fb87bc..501e7b8f14 100644
--- a/OWNERS.md
+++ b/OWNERS.md
@@ -74,7 +74,7 @@ Thank you, Friends!
* Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/).
* Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/).
* Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times.
-* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today.
+* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today.
* Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features.
* Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA.
* Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure.
diff --git a/README.md b/README.md
index 4d84895bbd..b22e908d5c 100644
--- a/README.md
+++ b/README.md
@@ -4,12 +4,12 @@
[](https://github.com/StackStorm/st2/actions?query=branch%3Amaster)
[](https://travis-ci.org/StackStorm/st2)
-[](https://circleci.com/gh/StackStorm/st2)
-[](https://codecov.io/github/StackStorm/st2?branch=master)
-[](https://bestpractices.coreinfrastructure.org/projects/1833)
-
-[](LICENSE)
-[](https://stackstorm.com/community-signup)
+[](https://circleci.com/gh/StackStorm/st2)
+[](https://codecov.io/github/StackStorm/st2?branch=master)
+[](https://bestpractices.coreinfrastructure.org/projects/1833)
+
+[](LICENSE)
+[](https://stackstorm.com/community-signup)
[](https://forum.stackstorm.com/)
---
diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample
index 758b743e75..488939eb55 100644
--- a/conf/st2.conf.sample
+++ b/conf/st2.conf.sample
@@ -2,7 +2,7 @@
# Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY
[action_sensor]
-# List of execution statuses for which a trigger will be emitted.
+# List of execution statuses for which a trigger will be emitted.
emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here.
# Whether to enable or disable the ability to post a trigger on action.
enable = True
@@ -170,7 +170,7 @@ trigger_instances_ttl = None
# Allow encryption of values in the key value store qualified as "secret".
enable_encryption = True
# Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool.
-encryption_key_path =
+encryption_key_path =
[log]
# Exclusion list of loggers to omit.
diff --git a/contrib/chatops/actions/format_execution_result.py b/contrib/chatops/actions/format_execution_result.py
index 8790ae4ae7..d6830df004 100755
--- a/contrib/chatops/actions/format_execution_result.py
+++ b/contrib/chatops/actions/format_execution_result.py
@@ -23,51 +23,50 @@
class FormatResultAction(Action):
def __init__(self, config=None, action_service=None):
- super(FormatResultAction, self).__init__(config=config, action_service=action_service)
- api_url = os.environ.get('ST2_ACTION_API_URL', None)
- token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None)
+ super(FormatResultAction, self).__init__(
+ config=config, action_service=action_service
+ )
+ api_url = os.environ.get("ST2_ACTION_API_URL", None)
+ token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None)
self.client = Client(api_url=api_url, token=token)
self.jinja = jinja_utils.get_jinja_environment(allow_undefined=True)
- self.jinja.tests['in'] = lambda item, list: item in list
+ self.jinja.tests["in"] = lambda item, list: item in list
path = os.path.dirname(os.path.realpath(__file__))
- with open(os.path.join(path, 'templates/default.j2'), 'r') as f:
+ with open(os.path.join(path, "templates/default.j2"), "r") as f:
self.default_template = f.read()
def run(self, execution_id):
execution = self._get_execution(execution_id)
- context = {
- 'six': six,
- 'execution': execution
- }
+ context = {"six": six, "execution": execution}
template = self.default_template
result = {"enabled": True}
- alias_id = execution['context'].get('action_alias_ref', {}).get('id', None)
+ alias_id = execution["context"].get("action_alias_ref", {}).get("id", None)
if alias_id:
- alias = self.client.managers['ActionAlias'].get_by_id(alias_id)
+ alias = self.client.managers["ActionAlias"].get_by_id(alias_id)
- context.update({
- 'alias': alias
- })
+ context.update({"alias": alias})
- result_params = getattr(alias, 'result', None)
+ result_params = getattr(alias, "result", None)
if result_params:
- if not result_params.get('enabled', True):
+ if not result_params.get("enabled", True):
result["enabled"] = False
else:
- if 'format' in alias.result:
- template = alias.result['format']
- if 'extra' in alias.result:
- result['extra'] = jinja_utils.render_values(alias.result['extra'], context)
+ if "format" in alias.result:
+ template = alias.result["format"]
+ if "extra" in alias.result:
+ result["extra"] = jinja_utils.render_values(
+ alias.result["extra"], context
+ )
- result['message'] = self.jinja.from_string(template).render(context)
+ result["message"] = self.jinja.from_string(template).render(context)
return result
def _get_execution(self, execution_id):
if not execution_id:
- raise ValueError('Invalid execution_id provided.')
+ raise ValueError("Invalid execution_id provided.")
execution = self.client.liveactions.get_by_id(id=execution_id)
if not execution:
return None
diff --git a/contrib/chatops/actions/match.py b/contrib/chatops/actions/match.py
index 46dac1ff64..7ee2154b42 100644
--- a/contrib/chatops/actions/match.py
+++ b/contrib/chatops/actions/match.py
@@ -23,23 +23,16 @@
class MatchAction(Action):
def __init__(self, config=None):
super(MatchAction, self).__init__(config=config)
- api_url = os.environ.get('ST2_ACTION_API_URL', None)
- token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None)
+ api_url = os.environ.get("ST2_ACTION_API_URL", None)
+ token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None)
self.client = Client(api_url=api_url, token=token)
def run(self, text):
alias_match = ActionAliasMatch()
alias_match.command = text
- matches = self.client.managers['ActionAlias'].match(alias_match)
- return {
- 'alias': _format_match(matches[0]),
- 'representation': matches[1]
- }
+ matches = self.client.managers["ActionAlias"].match(alias_match)
+ return {"alias": _format_match(matches[0]), "representation": matches[1]}
def _format_match(match):
- return {
- 'name': match.name,
- 'pack': match.pack,
- 'action_ref': match.action_ref
- }
+ return {"name": match.name, "pack": match.pack, "action_ref": match.action_ref}
diff --git a/contrib/chatops/actions/match_and_execute.py b/contrib/chatops/actions/match_and_execute.py
index 11388e599b..5e90080f03 100644
--- a/contrib/chatops/actions/match_and_execute.py
+++ b/contrib/chatops/actions/match_and_execute.py
@@ -19,25 +19,26 @@
from st2common.runners.base_action import Action
from st2client.models.action_alias import ActionAliasMatch
from st2client.models.aliasexecution import ActionAliasExecution
-from st2client.commands.action import (LIVEACTION_STATUS_REQUESTED,
- LIVEACTION_STATUS_SCHEDULED,
- LIVEACTION_STATUS_RUNNING,
- LIVEACTION_STATUS_CANCELING)
+from st2client.commands.action import (
+ LIVEACTION_STATUS_REQUESTED,
+ LIVEACTION_STATUS_SCHEDULED,
+ LIVEACTION_STATUS_RUNNING,
+ LIVEACTION_STATUS_CANCELING,
+)
from st2client.client import Client
class ExecuteActionAliasAction(Action):
def __init__(self, config=None):
super(ExecuteActionAliasAction, self).__init__(config=config)
- api_url = os.environ.get('ST2_ACTION_API_URL', None)
- token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None)
+ api_url = os.environ.get("ST2_ACTION_API_URL", None)
+ token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None)
self.client = Client(api_url=api_url, token=token)
def run(self, text, source_channel=None, user=None):
alias_match = ActionAliasMatch()
alias_match.command = text
- alias, representation = self.client.managers['ActionAlias'].match(
- alias_match)
+ alias, representation = self.client.managers["ActionAlias"].match(alias_match)
execution = ActionAliasExecution()
execution.name = alias.name
@@ -48,20 +49,20 @@ def run(self, text, source_channel=None, user=None):
execution.notification_route = None
execution.user = user
- action_exec_mgr = self.client.managers['ActionAliasExecution']
+ action_exec_mgr = self.client.managers["ActionAliasExecution"]
execution = action_exec_mgr.create(execution)
- self._wait_execution_to_finish(execution.execution['id'])
- return execution.execution['id']
+ self._wait_execution_to_finish(execution.execution["id"])
+ return execution.execution["id"]
def _wait_execution_to_finish(self, execution_id):
pending_statuses = [
LIVEACTION_STATUS_REQUESTED,
LIVEACTION_STATUS_SCHEDULED,
LIVEACTION_STATUS_RUNNING,
- LIVEACTION_STATUS_CANCELING
+ LIVEACTION_STATUS_CANCELING,
]
- action_exec_mgr = self.client.managers['LiveAction']
+ action_exec_mgr = self.client.managers["LiveAction"]
execution = action_exec_mgr.get_by_id(execution_id)
while execution.status in pending_statuses:
time.sleep(1)
diff --git a/contrib/chatops/tests/test_format_result.py b/contrib/chatops/tests/test_format_result.py
index e700af7454..05114cb361 100644
--- a/contrib/chatops/tests/test_format_result.py
+++ b/contrib/chatops/tests/test_format_result.py
@@ -20,9 +20,7 @@
from format_execution_result import FormatResultAction
-__all__ = [
- 'FormatResultActionTestCase'
-]
+__all__ = ["FormatResultActionTestCase"]
class FormatResultActionTestCase(BaseActionTestCase):
@@ -30,47 +28,45 @@ class FormatResultActionTestCase(BaseActionTestCase):
def test_rendering_works_remote_shell_cmd(self):
remote_shell_cmd_execution_model = json.loads(
- self.get_fixture_content('remote_cmd_execution.json')
+ self.get_fixture_content("remote_cmd_execution.json")
)
action = self.get_action_instance()
action._get_execution = mock.MagicMock(
return_value=remote_shell_cmd_execution_model
)
- result = action.run(execution_id='57967f9355fc8c19a96d9e4f')
+ result = action.run(execution_id="57967f9355fc8c19a96d9e4f")
self.assertTrue(result)
- self.assertIn('web_url', result['message'])
- self.assertIn('Took 2s to complete', result['message'])
+ self.assertIn("web_url", result["message"])
+ self.assertIn("Took 2s to complete", result["message"])
def test_rendering_local_shell_cmd(self):
local_shell_cmd_execution_model = json.loads(
- self.get_fixture_content('local_cmd_execution.json')
+ self.get_fixture_content("local_cmd_execution.json")
)
action = self.get_action_instance()
action._get_execution = mock.MagicMock(
return_value=local_shell_cmd_execution_model
)
- self.assertTrue(action.run(execution_id='5799522f55fc8c2d33ac03e0'))
+ self.assertTrue(action.run(execution_id="5799522f55fc8c2d33ac03e0"))
def test_rendering_http_request(self):
http_execution_model = json.loads(
- self.get_fixture_content('http_execution.json')
+ self.get_fixture_content("http_execution.json")
)
action = self.get_action_instance()
- action._get_execution = mock.MagicMock(
- return_value=http_execution_model
- )
- self.assertTrue(action.run(execution_id='579955f055fc8c2d33ac03e3'))
+ action._get_execution = mock.MagicMock(return_value=http_execution_model)
+ self.assertTrue(action.run(execution_id="579955f055fc8c2d33ac03e3"))
def test_rendering_python_action(self):
python_action_execution_model = json.loads(
- self.get_fixture_content('python_action_execution.json')
+ self.get_fixture_content("python_action_execution.json")
)
action = self.get_action_instance()
action._get_execution = mock.MagicMock(
return_value=python_action_execution_model
)
- self.assertTrue(action.run(execution_id='5799572a55fc8c2d33ac03ec'))
+ self.assertTrue(action.run(execution_id="5799572a55fc8c2d33ac03ec"))
diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md
index b9c04efa88..c0b1692b03 100644
--- a/contrib/core/CHANGES.md
+++ b/contrib/core/CHANGES.md
@@ -1,5 +1,5 @@
# Changelog
-
+
## 0.3.1
* Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs.
diff --git a/contrib/core/actions/generate_uuid.py b/contrib/core/actions/generate_uuid.py
index 972b7cb552..88d8125549 100644
--- a/contrib/core/actions/generate_uuid.py
+++ b/contrib/core/actions/generate_uuid.py
@@ -18,16 +18,14 @@
from st2common.runners.base_action import Action
-__all__ = [
- 'GenerateUUID'
-]
+__all__ = ["GenerateUUID"]
class GenerateUUID(Action):
def run(self, uuid_type):
- if uuid_type == 'uuid1':
+ if uuid_type == "uuid1":
return str(uuid.uuid1())
- elif uuid_type == 'uuid4':
+ elif uuid_type == "uuid4":
return str(uuid.uuid4())
else:
raise ValueError("Unknown uuid_type. Only uuid1 and uuid4 are supported")
diff --git a/contrib/core/actions/inject_trigger.py b/contrib/core/actions/inject_trigger.py
index 706e2165db..a6b2e68317 100644
--- a/contrib/core/actions/inject_trigger.py
+++ b/contrib/core/actions/inject_trigger.py
@@ -17,9 +17,7 @@
from st2common.runners.base_action import Action
-__all__ = [
- 'InjectTriggerAction'
-]
+__all__ = ["InjectTriggerAction"]
class InjectTriggerAction(Action):
@@ -34,8 +32,11 @@ def run(self, trigger, payload=None, trace_tag=None):
# results in a TriggerInstanceDB database object creation or not. The object is created
# inside rulesengine service and could fail due to the user providing an invalid trigger
# reference or similar.
- self.logger.debug('Injecting trigger "%s" with payload="%s"' % (trigger, str(payload)))
- result = client.webhooks.post_generic_webhook(trigger=trigger, payload=payload,
- trace_tag=trace_tag)
+ self.logger.debug(
+ 'Injecting trigger "%s" with payload="%s"' % (trigger, str(payload))
+ )
+ result = client.webhooks.post_generic_webhook(
+ trigger=trigger, payload=payload, trace_tag=trace_tag
+ )
return result
diff --git a/contrib/core/actions/pause.py b/contrib/core/actions/pause.py
index 99b9ed9e9b..7ef8b4eccb 100755
--- a/contrib/core/actions/pause.py
+++ b/contrib/core/actions/pause.py
@@ -19,9 +19,7 @@
from st2common.runners.base_action import Action
-__all__ = [
- 'PauseAction'
-]
+__all__ = ["PauseAction"]
class PauseAction(Action):
diff --git a/contrib/core/tests/test_action_inject_trigger.py b/contrib/core/tests/test_action_inject_trigger.py
index 4e0c3b1a29..7c8e44ac98 100644
--- a/contrib/core/tests/test_action_inject_trigger.py
+++ b/contrib/core/tests/test_action_inject_trigger.py
@@ -27,50 +27,46 @@
class InjectTriggerActionTestCase(BaseActionTestCase):
action_cls = InjectTriggerAction
- @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client')
+ @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client")
def test_inject_trigger_only_trigger_no_payload(self, mock_get_api_client):
mock_api_client = mock.Mock()
mock_get_api_client.return_value = mock_api_client
action = self.get_action_instance()
- action.run(trigger='dummy_pack.trigger1')
+ action.run(trigger="dummy_pack.trigger1")
mock_api_client.webhooks.post_generic_webhook.assert_called_with(
- trigger='dummy_pack.trigger1',
- payload={},
- trace_tag=None
+ trigger="dummy_pack.trigger1", payload={}, trace_tag=None
)
mock_api_client.webhooks.post_generic_webhook.reset()
- @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client')
+ @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client")
def test_inject_trigger_trigger_and_payload(self, mock_get_api_client):
mock_api_client = mock.Mock()
mock_get_api_client.return_value = mock_api_client
action = self.get_action_instance()
- action.run(trigger='dummy_pack.trigger2', payload={'foo': 'bar'})
+ action.run(trigger="dummy_pack.trigger2", payload={"foo": "bar"})
mock_api_client.webhooks.post_generic_webhook.assert_called_with(
- trigger='dummy_pack.trigger2',
- payload={'foo': 'bar'},
- trace_tag=None
+ trigger="dummy_pack.trigger2", payload={"foo": "bar"}, trace_tag=None
)
mock_api_client.webhooks.post_generic_webhook.reset()
- @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client')
+ @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client")
def test_inject_trigger_trigger_payload_trace_tag(self, mock_get_api_client):
mock_api_client = mock.Mock()
mock_get_api_client.return_value = mock_api_client
action = self.get_action_instance()
- action.run(trigger='dummy_pack.trigger3', payload={'foo': 'bar'}, trace_tag='Tag1')
+ action.run(
+ trigger="dummy_pack.trigger3", payload={"foo": "bar"}, trace_tag="Tag1"
+ )
mock_api_client.webhooks.post_generic_webhook.assert_called_with(
- trigger='dummy_pack.trigger3',
- payload={'foo': 'bar'},
- trace_tag='Tag1'
+ trigger="dummy_pack.trigger3", payload={"foo": "bar"}, trace_tag="Tag1"
)
diff --git a/contrib/core/tests/test_action_sendmail.py b/contrib/core/tests/test_action_sendmail.py
index 241fd35d68..b821ca5f12 100644
--- a/contrib/core/tests/test_action_sendmail.py
+++ b/contrib/core/tests/test_action_sendmail.py
@@ -33,12 +33,10 @@
from local_runner.local_shell_script_runner import LocalShellScriptRunner
-__all__ = [
- 'SendmailActionTestCase'
-]
+__all__ = ["SendmailActionTestCase"]
MOCK_EXECUTION = mock.Mock()
-MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b'
+MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b"
HOSTNAME = socket.gethostname()
@@ -47,134 +45,151 @@ class SendmailActionTestCase(RunnerTestCase, CleanDbTestCase, CleanFilesTestCase
NOTE: These tests rely on the stanley user being available on the system and having passwordless
sudo access.
"""
+
fixtures_loader = FixturesLoader()
def test_sendmail_default_text_html_content_type(self):
action_parameters = {
- 'sendmail_binary': 'cat',
-
- 'from': 'from.user@example.tld1',
- 'to': 'to.user@example.tld2',
- 'subject': 'this is subject 1',
- 'send_empty_body': False,
- 'content_type': 'text/html',
- 'body': 'Hello there html.',
- 'attachments': ''
+ "sendmail_binary": "cat",
+ "from": "from.user@example.tld1",
+ "to": "to.user@example.tld2",
+ "subject": "this is subject 1",
+ "send_empty_body": False,
+ "content_type": "text/html",
+ "body": "Hello there html.",
+ "attachments": "",
}
- expected_body = ('Hello there html.\n'
- '<br>\n'
- 'This message was generated by StackStorm action '
- 'send_mail running on %s' % (HOSTNAME))
+ expected_body = (
+ "Hello there html.\n"
+ "
\n"
+ "This message was generated by StackStorm action "
+ "send_mail running on %s" % (HOSTNAME)
+ )
- status, _, email_data, message = self._run_action(action_parameters=action_parameters)
+ status, _, email_data, message = self._run_action(
+ action_parameters=action_parameters
+ )
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
# Verify subject contains utf-8 charset and is base64 encoded
- self.assertIn('SUBJECT: =?UTF-8?B?', email_data)
+ self.assertIn("SUBJECT: =?UTF-8?B?", email_data)
- self.assertEqual(message.to[0][1], action_parameters['to'])
- self.assertEqual(message.from_[0][1], action_parameters['from'])
- self.assertEqual(message.subject, action_parameters['subject'])
+ self.assertEqual(message.to[0][1], action_parameters["to"])
+ self.assertEqual(message.from_[0][1], action_parameters["from"])
+ self.assertEqual(message.subject, action_parameters["subject"])
self.assertEqual(message.body, expected_body)
- self.assertEqual(message.content_type, 'text/html; charset=UTF-8')
+ self.assertEqual(message.content_type, "text/html; charset=UTF-8")
def test_sendmail_text_plain_content_type(self):
action_parameters = {
- 'sendmail_binary': 'cat',
-
- 'from': 'from.user@example.tld1',
- 'to': 'to.user@example.tld2',
- 'subject': 'this is subject 2',
- 'send_empty_body': False,
- 'content_type': 'text/plain',
- 'body': 'Hello there plain.',
- 'attachments': ''
+ "sendmail_binary": "cat",
+ "from": "from.user@example.tld1",
+ "to": "to.user@example.tld2",
+ "subject": "this is subject 2",
+ "send_empty_body": False,
+ "content_type": "text/plain",
+ "body": "Hello there plain.",
+ "attachments": "",
}
- expected_body = ('Hello there plain.\n\n'
- 'This message was generated by StackStorm action '
- 'send_mail running on %s' % (HOSTNAME))
+ expected_body = (
+ "Hello there plain.\n\n"
+ "This message was generated by StackStorm action "
+ "send_mail running on %s" % (HOSTNAME)
+ )
- status, _, email_data, message = self._run_action(action_parameters=action_parameters)
+ status, _, email_data, message = self._run_action(
+ action_parameters=action_parameters
+ )
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
# Verify subject contains utf-8 charset and is base64 encoded
- self.assertIn('SUBJECT: =?UTF-8?B?', email_data)
+ self.assertIn("SUBJECT: =?UTF-8?B?", email_data)
- self.assertEqual(message.to[0][1], action_parameters['to'])
- self.assertEqual(message.from_[0][1], action_parameters['from'])
- self.assertEqual(message.subject, action_parameters['subject'])
+ self.assertEqual(message.to[0][1], action_parameters["to"])
+ self.assertEqual(message.from_[0][1], action_parameters["from"])
+ self.assertEqual(message.subject, action_parameters["subject"])
self.assertEqual(message.body, expected_body)
- self.assertEqual(message.content_type, 'text/plain; charset=UTF-8')
+ self.assertEqual(message.content_type, "text/plain; charset=UTF-8")
def test_sendmail_utf8_subject_and_body(self):
# 1. tex/html
action_parameters = {
- 'sendmail_binary': 'cat',
-
- 'from': 'from.user@example.tld1',
- 'to': 'to.user@example.tld2',
- 'subject': u'Å unicode subject 😃😃',
- 'send_empty_body': False,
- 'content_type': 'text/html',
- 'body': u'Hello there 😃😃.',
- 'attachments': ''
+ "sendmail_binary": "cat",
+ "from": "from.user@example.tld1",
+ "to": "to.user@example.tld2",
+ "subject": "Å unicode subject 😃😃",
+ "send_empty_body": False,
+ "content_type": "text/html",
+ "body": "Hello there 😃😃.",
+ "attachments": "",
}
if six.PY2:
- expected_body = (u'Hello there 😃😃.\n'
- u'<br>\n'
- u'This message was generated by StackStorm action '
- u'send_mail running on %s' % (HOSTNAME))
+ expected_body = (
+ "Hello there 😃😃.\n"
+ "
\n"
+ "This message was generated by StackStorm action "
+ "send_mail running on %s" % (HOSTNAME)
+ )
else:
- expected_body = (u'Hello there \\U0001f603\\U0001f603.\n'
- u'<br>\n'
- u'This message was generated by StackStorm action '
- u'send_mail running on %s' % (HOSTNAME))
-
- status, _, email_data, message = self._run_action(action_parameters=action_parameters)
+ expected_body = (
+ "Hello there \\U0001f603\\U0001f603.\n"
+ "
\n"
+ "This message was generated by StackStorm action "
+ "send_mail running on %s" % (HOSTNAME)
+ )
+
+ status, _, email_data, message = self._run_action(
+ action_parameters=action_parameters
+ )
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
# Verify subject contains utf-8 charset and is base64 encoded
- self.assertIn('SUBJECT: =?UTF-8?B?', email_data)
+ self.assertIn("SUBJECT: =?UTF-8?B?", email_data)
- self.assertEqual(message.to[0][1], action_parameters['to'])
- self.assertEqual(message.from_[0][1], action_parameters['from'])
- self.assertEqual(message.subject, action_parameters['subject'])
+ self.assertEqual(message.to[0][1], action_parameters["to"])
+ self.assertEqual(message.from_[0][1], action_parameters["from"])
+ self.assertEqual(message.subject, action_parameters["subject"])
self.assertEqual(message.body, expected_body)
- self.assertEqual(message.content_type, 'text/html; charset=UTF-8')
+ self.assertEqual(message.content_type, "text/html; charset=UTF-8")
# 2. text/plain
action_parameters = {
- 'sendmail_binary': 'cat',
-
- 'from': 'from.user@example.tld1',
- 'to': 'to.user@example.tld2',
- 'subject': u'Å unicode subject 😃😃',
- 'send_empty_body': False,
- 'content_type': 'text/plain',
- 'body': u'Hello there 😃😃.',
- 'attachments': ''
+ "sendmail_binary": "cat",
+ "from": "from.user@example.tld1",
+ "to": "to.user@example.tld2",
+ "subject": "Å unicode subject 😃😃",
+ "send_empty_body": False,
+ "content_type": "text/plain",
+ "body": "Hello there 😃😃.",
+ "attachments": "",
}
if six.PY2:
- expected_body = (u'Hello there 😃😃.\n\n'
- u'This message was generated by StackStorm action '
- u'send_mail running on %s' % (HOSTNAME))
+ expected_body = (
+ "Hello there 😃😃.\n\n"
+ "This message was generated by StackStorm action "
+ "send_mail running on %s" % (HOSTNAME)
+ )
else:
- expected_body = (u'Hello there \\U0001f603\\U0001f603.\n\n'
- u'This message was generated by StackStorm action '
- u'send_mail running on %s' % (HOSTNAME))
-
- status, _, email_data, message = self._run_action(action_parameters=action_parameters)
+ expected_body = (
+ "Hello there \\U0001f603\\U0001f603.\n\n"
+ "This message was generated by StackStorm action "
+ "send_mail running on %s" % (HOSTNAME)
+ )
+
+ status, _, email_data, message = self._run_action(
+ action_parameters=action_parameters
+ )
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(message.to[0][1], action_parameters['to'])
- self.assertEqual(message.from_[0][1], action_parameters['from'])
- self.assertEqual(message.subject, action_parameters['subject'])
+ self.assertEqual(message.to[0][1], action_parameters["to"])
+ self.assertEqual(message.from_[0][1], action_parameters["from"])
+ self.assertEqual(message.subject, action_parameters["subject"])
self.assertEqual(message.body, expected_body)
- self.assertEqual(message.content_type, 'text/plain; charset=UTF-8')
+ self.assertEqual(message.content_type, "text/plain; charset=UTF-8")
def test_sendmail_with_attachments(self):
_, path_1 = tempfile.mkstemp()
@@ -185,48 +200,52 @@ def test_sendmail_with_attachments(self):
self.to_delete_files.append(path_1)
self.to_delete_files.append(path_2)
- with open(path_1, 'w') as fp:
- fp.write('content 1')
+ with open(path_1, "w") as fp:
+ fp.write("content 1")
- with open(path_2, 'w') as fp:
- fp.write('content 2')
+ with open(path_2, "w") as fp:
+ fp.write("content 2")
action_parameters = {
- 'sendmail_binary': 'cat',
-
- 'from': 'from.user@example.tld1',
- 'to': 'to.user@example.tld2',
- 'subject': 'this is email with attachments',
- 'send_empty_body': False,
- 'content_type': 'text/plain',
- 'body': 'Hello there plain.',
- 'attachments': '%s,%s' % (path_1, path_2)
+ "sendmail_binary": "cat",
+ "from": "from.user@example.tld1",
+ "to": "to.user@example.tld2",
+ "subject": "this is email with attachments",
+ "send_empty_body": False,
+ "content_type": "text/plain",
+ "body": "Hello there plain.",
+ "attachments": "%s,%s" % (path_1, path_2),
}
- expected_body = ('Hello there plain.\n\n'
- 'This message was generated by StackStorm action '
- 'send_mail running on %s' % (HOSTNAME))
+ expected_body = (
+ "Hello there plain.\n\n"
+ "This message was generated by StackStorm action "
+ "send_mail running on %s" % (HOSTNAME)
+ )
- status, _, email_data, message = self._run_action(action_parameters=action_parameters)
+ status, _, email_data, message = self._run_action(
+ action_parameters=action_parameters
+ )
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
# Verify subject contains utf-8 charset and is base64 encoded
- self.assertIn('SUBJECT: =?UTF-8?B?', email_data)
+ self.assertIn("SUBJECT: =?UTF-8?B?", email_data)
- self.assertEqual(message.to[0][1], action_parameters['to'])
- self.assertEqual(message.from_[0][1], action_parameters['from'])
- self.assertEqual(message.subject, action_parameters['subject'])
+ self.assertEqual(message.to[0][1], action_parameters["to"])
+ self.assertEqual(message.from_[0][1], action_parameters["from"])
+ self.assertEqual(message.subject, action_parameters["subject"])
self.assertEqual(message.body, expected_body)
- self.assertEqual(message.content_type,
- 'multipart/mixed; boundary="ZZ_/afg6432dfgkl.94531q"')
+ self.assertEqual(
+ message.content_type, 'multipart/mixed; boundary="ZZ_/afg6432dfgkl.94531q"'
+ )
# There should be 3 message parts - 2 for attachments, one for body
- self.assertEqual(email_data.count('--ZZ_/afg6432dfgkl.94531q'), 3)
+ self.assertEqual(email_data.count("--ZZ_/afg6432dfgkl.94531q"), 3)
# There should be 2 attachments
- self.assertEqual(email_data.count('Content-Transfer-Encoding: base64'), 2)
- self.assertIn(base64.b64encode(b'content 1').decode('utf-8'), email_data)
- self.assertIn(base64.b64encode(b'content 2').decode('utf-8'), email_data)
+ self.assertEqual(email_data.count("Content-Transfer-Encoding: base64"), 2)
+ self.assertIn(base64.b64encode(b"content 1").decode("utf-8"), email_data)
+ self.assertIn(base64.b64encode(b"content 2").decode("utf-8"), email_data)
def _run_action(self, action_parameters):
"""
@@ -234,10 +253,12 @@ def _run_action(self, action_parameters):
parse the output email data.
"""
models = self.fixtures_loader.load_models(
- fixtures_pack='packs/core', fixtures_dict={'actions': ['sendmail.yaml']})
- action_db = models['actions']['sendmail.yaml']
+ fixtures_pack="packs/core", fixtures_dict={"actions": ["sendmail.yaml"]}
+ )
+ action_db = models["actions"]["sendmail.yaml"]
entry_point = self.fixtures_loader.get_fixture_file_path_abs(
- 'packs/core', 'actions', 'send_mail/send_mail')
+ "packs/core", "actions", "send_mail/send_mail"
+ )
runner = self._get_runner(action_db, entry_point=entry_point)
runner.pre_run()
@@ -246,13 +267,13 @@ def _run_action(self, action_parameters):
# Remove footer added by the action which is not part of raw email data and parse
# the message
- if 'stdout' in result:
- email_data = result['stdout']
- email_data = email_data.split('\n')[:-2]
- email_data = '\n'.join(email_data)
+ if "stdout" in result:
+ email_data = result["stdout"]
+ email_data = email_data.split("\n")[:-2]
+ email_data = "\n".join(email_data)
if six.PY2 and isinstance(email_data, six.text_type):
- email_data = email_data.encode('utf-8')
+ email_data = email_data.encode("utf-8")
message = mailparser.parse_from_string(email_data)
else:
@@ -273,5 +294,5 @@ def _get_runner(self, action_db, entry_point):
runner.callback = dict()
runner.libs_dir_path = None
runner.auth_token = mock.Mock()
- runner.auth_token.token = 'mock-token'
+ runner.auth_token.token = "mock-token"
return runner
diff --git a/contrib/core/tests/test_action_uuid.py b/contrib/core/tests/test_action_uuid.py
index 4e946f3062..e13cfd4c18 100644
--- a/contrib/core/tests/test_action_uuid.py
+++ b/contrib/core/tests/test_action_uuid.py
@@ -28,13 +28,13 @@ def test_run(self):
action = self.get_action_instance()
# accepts uuid1 as a type
- result = action.run(uuid_type='uuid1')
+ result = action.run(uuid_type="uuid1")
self.assertTrue(result)
# accepts uuid4 as a type
- result = action.run(uuid_type='uuid4')
+ result = action.run(uuid_type="uuid4")
self.assertTrue(result)
# fails on incorrect type
with self.assertRaises(ValueError):
- result = action.run(uuid_type='foobar')
+ result = action.run(uuid_type="foobar")
diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml
index f226eae420..86ead5303a 100644
--- a/contrib/examples/actions/forloop_chain.yaml
+++ b/contrib/examples/actions/forloop_chain.yaml
@@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml"
enabled: true
parameters:
github_organization_url:
- type: "string"
+ type: "string"
description: "Organization url to parse data from"
default: "https://github.com/StackStorm-Exchange"
required: false
diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml
index 3ff06eabc3..878772636a 100644
--- a/contrib/examples/actions/forloop_push_github_repos.yaml
+++ b/contrib/examples/actions/forloop_push_github_repos.yaml
@@ -5,7 +5,7 @@ description: "Action to push the data to an external service"
enabled: true
entry_point: "pythonactions/forloop_push_github_repos.py"
parameters:
- data_to_push:
+ data_to_push:
type: "object"
description: "Dictonary of the data to be pushed"
required: true
diff --git a/contrib/examples/actions/noop.py b/contrib/examples/actions/noop.py
index 0283499ce1..bbdf5e67e6 100644
--- a/contrib/examples/actions/noop.py
+++ b/contrib/examples/actions/noop.py
@@ -5,6 +5,6 @@
class PrintParametersAction(Action):
def run(self, **parameters):
- print('=========')
+ print("=========")
pprint(parameters)
- print('=========')
+ print("=========")
diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml
index 85e774a702..35c5ab26d8 100644
--- a/contrib/examples/actions/orquesta-mock-create-vm.yaml
+++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml
@@ -15,7 +15,7 @@ parameters:
required: true
type: string
ip:
- default: "10.1.23.99"
+ default: "10.1.23.99"
required: true
type: string
meta:
diff --git a/contrib/examples/actions/print_config.py b/contrib/examples/actions/print_config.py
index 68bdf1e2d6..15b3103b61 100644
--- a/contrib/examples/actions/print_config.py
+++ b/contrib/examples/actions/print_config.py
@@ -5,6 +5,6 @@
class PrintConfigAction(Action):
def run(self):
- print('=========')
+ print("=========")
pprint(self.config)
- print('=========')
+ print("=========")
diff --git a/contrib/examples/actions/print_to_stdout_and_stderr.py b/contrib/examples/actions/print_to_stdout_and_stderr.py
index da31dc14b4..124a32a67c 100644
--- a/contrib/examples/actions/print_to_stdout_and_stderr.py
+++ b/contrib/examples/actions/print_to_stdout_and_stderr.py
@@ -23,12 +23,12 @@ class PrintToStdoutAndStderrAction(Action):
def run(self, count=100, sleep_delay=0.5):
for i in range(0, count):
if i % 2 == 0:
- text = 'stderr'
+ text = "stderr"
stream = sys.stderr
else:
- text = 'stdout'
+ text = "stdout"
stream = sys.stdout
- stream.write('%s -> Line: %s\n' % (text, (i + 1)))
+ stream.write("%s -> Line: %s\n" % (text, (i + 1)))
stream.flush()
time.sleep(sleep_delay)
diff --git a/contrib/examples/actions/python-mock-core-remote.py b/contrib/examples/actions/python-mock-core-remote.py
index cd4d44500e..52c13d804e 100644
--- a/contrib/examples/actions/python-mock-core-remote.py
+++ b/contrib/examples/actions/python-mock-core-remote.py
@@ -2,7 +2,6 @@
class MockCoreRemoteAction(Action):
-
def run(self, cmd, hosts, hosts_dict):
if hosts_dict:
return hosts_dict
@@ -10,14 +9,14 @@ def run(self, cmd, hosts, hosts_dict):
if not hosts:
return None
- host_list = hosts.split(',')
+ host_list = hosts.split(",")
results = {}
for h in host_list:
results[h] = {
- 'failed': False,
- 'return_code': 0,
- 'stderr': '',
- 'succeeded': True,
- 'stdout': cmd,
+ "failed": False,
+ "return_code": 0,
+ "stderr": "",
+ "succeeded": True,
+ "stdout": cmd,
}
return results
diff --git a/contrib/examples/actions/python-mock-create-vm.py b/contrib/examples/actions/python-mock-create-vm.py
index 60a88b7967..62fdaa36c1 100644
--- a/contrib/examples/actions/python-mock-create-vm.py
+++ b/contrib/examples/actions/python-mock-create-vm.py
@@ -5,17 +5,12 @@
class MockCreateVMAction(Action):
-
def run(self, cpu_cores, memory_mb, vm_name, ip):
eventlet.sleep(5)
data = {
- 'vm_id': 'vm' + str(random.randint(0, 10000)),
- ip: {
- 'cpu_cores': cpu_cores,
- 'memory_mb': memory_mb,
- 'vm_name': vm_name
- }
+ "vm_id": "vm" + str(random.randint(0, 10000)),
+ ip: {"cpu_cores": cpu_cores, "memory_mb": memory_mb, "vm_name": vm_name},
}
return data
diff --git a/contrib/examples/actions/pythonactions/fibonacci.py b/contrib/examples/actions/pythonactions/fibonacci.py
index afab612161..bd9a479f35 100755
--- a/contrib/examples/actions/pythonactions/fibonacci.py
+++ b/contrib/examples/actions/pythonactions/fibonacci.py
@@ -12,12 +12,13 @@ def fib(n):
return n
return fib(n - 2) + fib(n - 1)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
try:
startNumber = int(float(sys.argv[1]))
endNumber = int(float(sys.argv[2]))
results = map(str, map(fib, list(range(startNumber, endNumber))))
- results = ' '.join(results)
+ results = " ".join(results)
print(results)
except Exception as e:
traceback.print_exc(file=sys.stderr)
diff --git a/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py b/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py
index 989467570c..8cb3c42f4b 100644
--- a/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py
+++ b/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py
@@ -3,13 +3,13 @@
class IncreaseIndexAndCheckCondition(Action):
def run(self, index, pagesize, input):
- if pagesize and pagesize != '':
+ if pagesize and pagesize != "":
if len(input) < int(pagesize):
return (False, "Breaking out of the loop")
else:
pagesize = 0
- if not index or index == '':
+ if not index or index == "":
index = 1
- return(True, int(index) + 1)
+ return (True, int(index) + 1)
diff --git a/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py b/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py
index a2cdfd1063..dbefc1b07e 100644
--- a/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py
+++ b/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py
@@ -6,12 +6,12 @@
class ParseGithubRepos(Action):
def run(self, content):
try:
- soup = BeautifulSoup(content, 'html.parser')
+ soup = BeautifulSoup(content, "html.parser")
repo_list = soup.find_all("h3")
output = {}
for each_item in repo_list:
- repo_half_url = each_item.find("a")['href']
+ repo_half_url = each_item.find("a")["href"]
repo_name = repo_half_url.split("/")[-1]
repo_url = "https://github.com" + repo_half_url
output[repo_name] = repo_url
diff --git a/contrib/examples/actions/pythonactions/isprime.py b/contrib/examples/actions/pythonactions/isprime.py
index 911594a01e..e55d202922 100644
--- a/contrib/examples/actions/pythonactions/isprime.py
+++ b/contrib/examples/actions/pythonactions/isprime.py
@@ -6,18 +6,19 @@
class PrimeCheckerAction(Action):
def run(self, value=0):
- self.logger.debug('PYTHONPATH: %s', get_environ('PYTHONPATH'))
- self.logger.debug('value=%s' % (value))
+ self.logger.debug("PYTHONPATH: %s", get_environ("PYTHONPATH"))
+ self.logger.debug("value=%s" % (value))
if math.floor(value) != value:
- raise ValueError('%s should be an integer.' % value)
+ raise ValueError("%s should be an integer." % value)
if value < 2:
return False
- for test in range(2, int(math.floor(math.sqrt(value)))+1):
+ for test in range(2, int(math.floor(math.sqrt(value))) + 1):
if value % test == 0:
return False
return True
-if __name__ == '__main__':
+
+if __name__ == "__main__":
checker = PrimeCheckerAction()
for i in range(0, 10):
- print('%s : %s' % (i, checker.run(value=1)))
+ print("%s : %s" % (i, checker.run(value=1)))
diff --git a/contrib/examples/actions/pythonactions/json_string_to_object.py b/contrib/examples/actions/pythonactions/json_string_to_object.py
index 1072c4554e..e3c492d7a2 100644
--- a/contrib/examples/actions/pythonactions/json_string_to_object.py
+++ b/contrib/examples/actions/pythonactions/json_string_to_object.py
@@ -4,6 +4,5 @@
class JsonStringToObject(Action):
-
def run(self, json_str):
return json.loads(json_str)
diff --git a/contrib/examples/actions/pythonactions/object_return.py b/contrib/examples/actions/pythonactions/object_return.py
index ecaaf57391..f8a008b73d 100644
--- a/contrib/examples/actions/pythonactions/object_return.py
+++ b/contrib/examples/actions/pythonactions/object_return.py
@@ -2,6 +2,5 @@
class ObjectReturnAction(Action):
-
def run(self):
- return {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}}
+ return {"a": "b", "c": {"d": "e", "f": 1, "g": True}}
diff --git a/contrib/examples/actions/pythonactions/print_python_environment.py b/contrib/examples/actions/pythonactions/print_python_environment.py
index 9c070cc1c0..dd92bfc202 100644
--- a/contrib/examples/actions/pythonactions/print_python_environment.py
+++ b/contrib/examples/actions/pythonactions/print_python_environment.py
@@ -6,10 +6,9 @@
class PrintPythonEnvironmentAction(Action):
-
def run(self):
- print('Using Python executable: %s' % (sys.executable))
- print('Using Python version: %s' % (sys.version))
- print('Platform: %s' % (platform.platform()))
- print('PYTHONPATH: %s' % (os.environ.get('PYTHONPATH')))
- print('sys.path: %s' % (sys.path))
+ print("Using Python executable: %s" % (sys.executable))
+ print("Using Python version: %s" % (sys.version))
+ print("Platform: %s" % (platform.platform()))
+ print("PYTHONPATH: %s" % (os.environ.get("PYTHONPATH")))
+ print("sys.path: %s" % (sys.path))
diff --git a/contrib/examples/actions/pythonactions/print_python_version.py b/contrib/examples/actions/pythonactions/print_python_version.py
index 0ae2a27b18..201c68dd5f 100644
--- a/contrib/examples/actions/pythonactions/print_python_version.py
+++ b/contrib/examples/actions/pythonactions/print_python_version.py
@@ -4,7 +4,6 @@
class PrintPythonVersionAction(Action):
-
def run(self):
- print('Using Python executable: %s' % (sys.executable))
- print('Using Python version: %s' % (sys.version))
+ print("Using Python executable: %s" % (sys.executable))
+ print("Using Python version: %s" % (sys.version))
diff --git a/contrib/examples/actions/pythonactions/yaml_string_to_object.py b/contrib/examples/actions/pythonactions/yaml_string_to_object.py
index 297451cdad..aa888ce408 100644
--- a/contrib/examples/actions/pythonactions/yaml_string_to_object.py
+++ b/contrib/examples/actions/pythonactions/yaml_string_to_object.py
@@ -4,6 +4,5 @@
class YamlStringToObject(Action):
-
def run(self, yaml_str):
return yaml.safe_load(yaml_str)
diff --git a/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py b/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py
index c2c0198a42..14f19582fd 100644
--- a/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py
+++ b/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py
@@ -5,11 +5,11 @@ def to_json(out, err, code):
payload = {}
if err:
- payload['err'] = err
- payload['exit_code'] = code
+ payload["err"] = err
+ payload["exit_code"] = code
return json.dumps(payload)
- payload['pkg_info'] = out
- payload['exit_code'] = code
+ payload["pkg_info"] = out
+ payload["exit_code"] = code
return json.dumps(payload)
diff --git a/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py b/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py
index d8213f4342..ec5e5f7ace 100755
--- a/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py
+++ b/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py
@@ -7,17 +7,20 @@
def main(args):
- command_list = shlex.split('apt-cache policy ' + ' '.join(args[1:]))
- process = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ command_list = shlex.split("apt-cache policy " + " ".join(args[1:]))
+ process = subprocess.Popen(
+ command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
command_stdout, command_stderr = process.communicate()
command_exitcode = process.returncode
try:
payload = transformer.to_json(command_stdout, command_stderr, command_exitcode)
except Exception as e:
- sys.stderr.write('JSON conversion failed. %s' % six.text_type(e))
+ sys.stderr.write("JSON conversion failed. %s" % six.text_type(e))
sys.exit(1)
sys.stdout.write(payload)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
main(sys.argv)
diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml
index a0793f8bf6..82a131712c 100644
--- a/contrib/examples/actions/workflows/orquesta-delay.yaml
+++ b/contrib/examples/actions/workflows/orquesta-delay.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic sequential workflow.
input:
diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml
index 5d9c6f22a0..80047d2e5e 100644
--- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml
+++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow that demonstrates error handler with continue.
input:
diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml
index da9179b5ed..4e3dfa38c2 100644
--- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml
+++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow that demonstrates error handler with remediation and explicit fail.
input:
diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml
index 61b14a3c11..e949dc3742 100644
--- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml
+++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow that demonstrates error handler with noop to ignore error.
input:
diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml
index 936db68ff3..b86d8ef25b 100644
--- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml
+++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml
@@ -11,7 +11,7 @@ tasks:
- when: <% failed() %>
publish:
- task_name: <% task().task_name %>
- - task_exit_code: <% task().result.stdout %>
+ - task_exit_code: <% task().result.stdout %>
do:
- log
- fail
diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml
index eaf09fed66..a247423948 100644
--- a/contrib/examples/actions/workflows/orquesta-join.yaml
+++ b/contrib/examples/actions/workflows/orquesta-join.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow that demonstrate branching and join.
vars:
diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml
index 936db68ff3..b86d8ef25b 100644
--- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml
+++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml
@@ -11,7 +11,7 @@ tasks:
- when: <% failed() %>
publish:
- task_name: <% task().task_name %>
- - task_exit_code: <% task().result.stdout %>
+ - task_exit_code: <% task().result.stdout %>
do:
- log
- fail
diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml
index 0d80b0dbcb..a1f203fb09 100644
--- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml
+++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: >
A sample workflow that demonstrates how to handle rollback and retry on error. In this example,
the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time
diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml
index 3a03409d36..404681a369 100644
--- a/contrib/examples/actions/workflows/orquesta-sequential.yaml
+++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic sequential workflow.
input:
diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml
index e20b907898..6bcbb82c58 100644
--- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml
+++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A workflow demonstrating with items and concurrent processing.
input:
diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml
index 6a2cc4af49..5833e27051 100644
--- a/contrib/examples/actions/workflows/orquesta-with-items.yaml
+++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A workflow demonstrating with items.
input:
diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml
index ce935f62f7..907a18e8bf 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow with error while evaluating input.
input:
diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml
index c0322d025e..a8be531180 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic sequential workflow with inspection error(s).
input:
diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml
index dd1e516441..003ab8b69d 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow with error while evaluating output.
input:
diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml
index a0deab1d8f..0c23ee6a82 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow with error in the rendering of the starting task.
input:
diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml
index 0887d4a7be..149fb93b97 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow that fails on publish during task transition.
input:
diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml
index 8fd2a94d8a..4d4d9e5f39 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow with error while evaluating task transition.
input:
diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml
index 403728100a..4ddd986755 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A basic workflow with error while evaluating vars.
input:
diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml
index 7123727cc3..285bf972d7 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml
@@ -19,4 +19,4 @@ tasks:
task2:
action: core.local
input:
- cmd: 'echo "<% $.var1 %>"'
+ cmd: 'echo "<% $.var1 %>"'
diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml
index 11eb22a721..3a4b20cee0 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml
@@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature.
input:
- tempfile
-
+
tasks:
task1:
action: core.noop
diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml
index 8af6899b59..6e24c0ec41 100644
--- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml
+++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml
@@ -1,5 +1,5 @@
version: 1.0
-
+
description: A workflow for testing with items and concurrency.
input:
diff --git a/contrib/examples/sensors/echo_flask_app.py b/contrib/examples/sensors/echo_flask_app.py
index c4025ae441..5177df2306 100644
--- a/contrib/examples/sensors/echo_flask_app.py
+++ b/contrib/examples/sensors/echo_flask_app.py
@@ -6,13 +6,12 @@
class EchoFlaskSensor(Sensor):
def __init__(self, sensor_service, config):
super(EchoFlaskSensor, self).__init__(
- sensor_service=sensor_service,
- config=config
+ sensor_service=sensor_service, config=config
)
- self._host = '127.0.0.1'
+ self._host = "127.0.0.1"
self._port = 5000
- self._path = '/echo'
+ self._path = "/echo"
self._log = self._sensor_service.get_logger(__name__)
self._app = Flask(__name__)
@@ -21,15 +20,19 @@ def setup(self):
pass
def run(self):
- @self._app.route(self._path, methods=['POST'])
+ @self._app.route(self._path, methods=["POST"])
def echo():
payload = request.get_json(force=True)
- self._sensor_service.dispatch(trigger="examples.echoflasksensor",
- payload=payload)
+ self._sensor_service.dispatch(
+ trigger="examples.echoflasksensor", payload=payload
+ )
return request.data
- self._log.info('Listening for payload on http://{}:{}{}'.format(
- self._host, self._port, self._path))
+ self._log.info(
+ "Listening for payload on http://{}:{}{}".format(
+ self._host, self._port, self._path
+ )
+ )
self._app.run(host=self._host, port=self._port, threaded=False)
def cleanup(self):
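Reviewer note on the hunk above: black splits the `dispatch()` call because, with the method's indentation counted, the original line exceeds the default 88-column limit; the arguments then move to a single continuation line since they fit there together. A minimal sketch that reproduces this, assuming the `black` package is importable (`format_str` and `Mode` are its public API; the class and source line here are stand-ins, not the real sensor):

```python
# Run black on an over-long call like the one reformatted above.
import black

src = (
    "class EchoFlaskSensor:\n"
    "    def run(self, payload):\n"
    "        self._sensor_service.dispatch("
    'trigger="examples.echoflasksensor", payload=payload)\n'
)
# Mode() defaults to line_length=88, the same limit used for the repo-wide
# reformat; the indented call is 91 columns, so black explodes it.
print(black.format_str(src, mode=black.Mode()))
```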
diff --git a/contrib/examples/sensors/fibonacci_sensor.py b/contrib/examples/sensors/fibonacci_sensor.py
index 266e81aba3..2df956335b 100644
--- a/contrib/examples/sensors/fibonacci_sensor.py
+++ b/contrib/examples/sensors/fibonacci_sensor.py
@@ -4,12 +4,9 @@
class FibonacciSensor(PollingSensor):
-
def __init__(self, sensor_service, config, poll_interval=20):
super(FibonacciSensor, self).__init__(
- sensor_service=sensor_service,
- config=config,
- poll_interval=poll_interval
+ sensor_service=sensor_service, config=config, poll_interval=poll_interval
)
self.a = None
self.b = None
@@ -26,19 +23,21 @@ def setup(self):
def poll(self):
        # Reset a and b if they are large enough to avoid integer overflow problems
if self.a > 10000 or self.b > 10000:
- self.logger.debug('Reseting values to avoid integer overflow issues')
+            self.logger.debug("Resetting values to avoid integer overflow issues")
self.a = 0
self.b = 1
self.count = 2
- fib = (self.a + self.b)
- self.logger.debug('Count: %d, a: %d, b: %d, fib: %s', self.count, self.a, self.b, fib)
+ fib = self.a + self.b
+ self.logger.debug(
+ "Count: %d, a: %d, b: %d, fib: %s", self.count, self.a, self.b, fib
+ )
payload = {
"count": self.count,
"fibonacci": fib,
- "pythonpath": os.environ.get("PYTHONPATH", None)
+ "pythonpath": os.environ.get("PYTHONPATH", None),
}
self.sensor_service.dispatch(trigger="examples.fibonacci", payload=payload)
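For readers skimming the reformat: the class above follows StackStorm's polling sensor contract, where `poll()` runs every `poll_interval` seconds and dispatches trigger instances. A self-contained sketch of that contract, with a stand-in base class since `st2reactor.sensor.base` is only importable inside a StackStorm install (the trigger name and payload are illustrative):

```python
import time


class PollingSensor:  # stand-in for st2reactor.sensor.base.PollingSensor
    def __init__(self, sensor_service=None, config=None, poll_interval=20):
        self.sensor_service = sensor_service
        self.config = config
        self.poll_interval = poll_interval


class CounterSensor(PollingSensor):
    def setup(self):
        self.count = 0

    def poll(self):
        # st2 calls poll() every poll_interval seconds; each call dispatches
        # one trigger instance with a small JSON-serializable payload.
        self.count += 1
        print("would dispatch examples.counter with", {"count": self.count})


if __name__ == "__main__":
    sensor = CounterSensor(poll_interval=1)
    sensor.setup()
    for _ in range(3):
        sensor.poll()
        time.sleep(sensor.poll_interval)
```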
diff --git a/contrib/hello_st2/sensors/sensor1.py b/contrib/hello_st2/sensors/sensor1.py
index 501de54a98..a4914cdf8b 100644
--- a/contrib/hello_st2/sensors/sensor1.py
+++ b/contrib/hello_st2/sensors/sensor1.py
@@ -14,11 +14,11 @@ def setup(self):
def run(self):
while not self._stop:
- self._logger.debug('HelloSensor dispatching trigger...')
- count = self.sensor_service.get_value('hello_st2.count') or 0
- payload = {'greeting': 'Yo, StackStorm!', 'count': int(count) + 1}
- self.sensor_service.dispatch(trigger='hello_st2.event1', payload=payload)
- self.sensor_service.set_value('hello_st2.count', payload['count'])
+ self._logger.debug("HelloSensor dispatching trigger...")
+ count = self.sensor_service.get_value("hello_st2.count") or 0
+ payload = {"greeting": "Yo, StackStorm!", "count": int(count) + 1}
+ self.sensor_service.dispatch(trigger="hello_st2.event1", payload=payload)
+ self.sensor_service.set_value("hello_st2.count", payload["count"])
eventlet.sleep(60)
def cleanup(self):
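The quoting changes above leave the sensor's datastore round-trip intact: read a counter, dispatch, persist the new value. A condensed sketch of that pattern with an in-memory stand-in for the service st2 injects (`get_value`/`set_value`/`dispatch` are the real method names; the storage is fake):

```python
class FakeSensorService:
    """In-memory stand-in for the sensor_service object st2 injects."""

    def __init__(self):
        self._kv = {}

    def get_value(self, name):
        return self._kv.get(name)

    def set_value(self, name, value):
        self._kv[name] = value

    def dispatch(self, trigger, payload):
        print("dispatch", trigger, payload)


def bump_counter(sensor_service):
    # Datastore values come back as strings (or None on first run), hence
    # the `or 0` default and the int() conversion.
    count = sensor_service.get_value("hello_st2.count") or 0
    payload = {"greeting": "Yo, StackStorm!", "count": int(count) + 1}
    sensor_service.dispatch(trigger="hello_st2.event1", payload=payload)
    sensor_service.set_value("hello_st2.count", payload["count"])


service = FakeSensorService()
bump_counter(service)
bump_counter(service)  # the count increments across calls
```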
diff --git a/contrib/linux/README.md b/contrib/linux/README.md
index 33d872cf86..e2b9f09d44 100644
--- a/contrib/linux/README.md
+++ b/contrib/linux/README.md
@@ -55,4 +55,4 @@ Example trigger payload:
## Troubleshooting
-* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install.
\ No newline at end of file
+* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install.
\ No newline at end of file
diff --git a/contrib/linux/actions/checks/check_loadavg.py b/contrib/linux/actions/checks/check_loadavg.py
index fb7d3938cc..04036924e8 100755
--- a/contrib/linux/actions/checks/check_loadavg.py
+++ b/contrib/linux/actions/checks/check_loadavg.py
@@ -29,7 +29,7 @@
output = {}
try:
- fh = open(loadAvgFile, 'r')
+ fh = open(loadAvgFile, "r")
load = fh.readline().split()[0:3]
except:
print("Error opening %s" % loadAvgFile)
@@ -38,7 +38,7 @@
fh.close()
try:
- fh = open(cpuInfoFile, 'r')
+ fh = open(cpuInfoFile, "r")
for line in fh:
if "processor" in line:
cpus += 1
@@ -48,16 +48,16 @@
finally:
fh.close()
-output['1'] = str(float(load[0]) / cpus)
-output['5'] = str(float(load[1]) / cpus)
-output['15'] = str(float(load[2]) / cpus)
+output["1"] = str(float(load[0]) / cpus)
+output["5"] = str(float(load[1]) / cpus)
+output["15"] = str(float(load[2]) / cpus)
-if time == '1' or time == 'one':
- print(output['1'])
-elif time == '5' or time == 'five':
- print(output['5'])
-elif time == '15' or time == 'fifteen':
- print(output['15'])
+if time == "1" or time == "one":
+ print(output["1"])
+elif time == "5" or time == "five":
+ print(output["5"])
+elif time == "15" or time == "fifteen":
+ print(output["15"])
else:
print(json.dumps(output))
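The script above parses `/proc/loadavg` and counts processors from `/proc/cpuinfo` by hand. For comparison, a hedged sketch of the same per-CPU normalization using only the standard library (POSIX-only, like the original; the output keys match the script's):

```python
import json
import os

# os.getloadavg() reads the same three numbers as /proc/loadavg, and
# os.cpu_count() replaces the manual /proc/cpuinfo scan.
load1, load5, load15 = os.getloadavg()
cpus = os.cpu_count() or 1

output = {
    "1": str(load1 / cpus),
    "5": str(load5 / cpus),
    "15": str(load15 / cpus),
}
print(json.dumps(output))
```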
diff --git a/contrib/linux/actions/checks/check_processes.py b/contrib/linux/actions/checks/check_processes.py
index b1ff1af0ae..d2a7db195f 100755
--- a/contrib/linux/actions/checks/check_processes.py
+++ b/contrib/linux/actions/checks/check_processes.py
@@ -41,8 +41,11 @@ def setup(self, debug=False, pidlist=False):
if debug is True:
print("Debug is on")
- self.allProcs = [procs for procs in os.listdir(self.procDir) if procs.isdigit() and
- int(procs) != int(self.myPid)]
+ self.allProcs = [
+ procs
+ for procs in os.listdir(self.procDir)
+ if procs.isdigit() and int(procs) != int(self.myPid)
+ ]
def process(self, criteria):
for p in self.allProcs:
@@ -58,37 +61,37 @@ def process(self, criteria):
cmdfh.close()
fh.close()
- if criteria == 'state':
+ if criteria == "state":
if pInfo[2] == self.state:
self.interestingProcs.append(pInfo)
- elif criteria == 'name':
+ elif criteria == "name":
if re.search(self.name, pInfo[1]):
self.interestingProcs.append(pInfo)
- elif criteria == 'pid':
+ elif criteria == "pid":
if pInfo[0] == self.pid:
self.interestingProcs.append(pInfo)
def byState(self, state):
self.state = state
- self.process(criteria='state')
+ self.process(criteria="state")
self.show()
def byPid(self, pid):
self.pid = pid
- self.process(criteria='pid')
+ self.process(criteria="pid")
self.show()
def byName(self, name):
self.name = name
- self.process(criteria='name')
+ self.process(criteria="name")
self.show()
def run(self, foo, criteria):
- if foo == 'state':
+ if foo == "state":
self.byState(criteria)
- elif foo == 'name':
+ elif foo == "name":
self.byName(criteria)
- elif foo == 'pid':
+ elif foo == "pid":
self.byPid(criteria)
def show(self):
@@ -99,13 +102,13 @@ def show(self):
prettyOut[proc[0]] = proc[1]
if self.pidlist is True:
- pidlist = ' '.join(prettyOut.keys())
+ pidlist = " ".join(prettyOut.keys())
sys.stderr.write(pidlist)
print(json.dumps(prettyOut))
-if __name__ == '__main__':
+if __name__ == "__main__":
if "pidlist" in sys.argv:
pidlist = True
else:
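The comprehension black exploded above is the standard `/proc` scan: numeric directory names are PIDs, minus the scanner's own. A minimal Linux-only sketch of that scan plus the `stat` read the `process()` method builds on (the naive `split()` mis-handles command names containing spaces, which is acceptable for a sketch):

```python
import os

my_pid = os.getpid()

for entry in os.listdir("/proc"):
    if not entry.isdigit() or int(entry) == my_pid:
        continue
    try:
        with open(os.path.join("/proc", entry, "stat")) as fh:
            fields = fh.readline().split()
    except OSError:
        continue  # the process exited between listdir() and open()
    # fields are: pid, command name in parentheses (e.g. "(bash)"), state.
    print(fields[0], fields[1], fields[2])
```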
diff --git a/contrib/linux/actions/dig.py b/contrib/linux/actions/dig.py
index 9a3b58a5cd..7eb8518a2a 100644
--- a/contrib/linux/actions/dig.py
+++ b/contrib/linux/actions/dig.py
@@ -25,29 +25,28 @@
class DigAction(Action):
-
def run(self, rand, count, nameserver, hostname, queryopts):
opt_list = []
output = []
- cmd_args = ['dig']
+ cmd_args = ["dig"]
if nameserver:
- nameserver = '@' + nameserver
+ nameserver = "@" + nameserver
cmd_args.append(nameserver)
- if isinstance(queryopts, str) and ',' in queryopts:
- opt_list = queryopts.split(',')
+ if isinstance(queryopts, str) and "," in queryopts:
+ opt_list = queryopts.split(",")
else:
opt_list.append(queryopts)
- cmd_args.extend(['+' + option for option in opt_list])
+ cmd_args.extend(["+" + option for option in opt_list])
cmd_args.append(hostname)
try:
- raw_result = subprocess.Popen(cmd_args,
- stderr=subprocess.PIPE,
- stdout=subprocess.PIPE).communicate()[0]
+ raw_result = subprocess.Popen(
+ cmd_args, stderr=subprocess.PIPE, stdout=subprocess.PIPE
+ ).communicate()[0]
if sys.version_info >= (3,):
# This function might call getpreferred encoding unless we pass
@@ -57,16 +56,19 @@ def run(self, rand, count, nameserver, hostname, queryopts):
else:
result_list_str = str(raw_result)
- result_list = list(filter(None, result_list_str.split('\n')))
+ result_list = list(filter(None, result_list_str.split("\n")))
# NOTE: Python3 supports the FileNotFoundError, the errono.ENOENT is for py2 compat
# for Python3:
# except FileNotFoundError as e:
except OSError as e:
if e.errno == errno.ENOENT:
- return False, "Can't find dig installed in the path (usually /usr/bin/dig). If " \
- "dig isn't installed, you can install it with 'sudo yum install " \
- "bind-utils' or 'sudo apt install dnsutils'"
+ return (
+ False,
+ "Can't find dig installed in the path (usually /usr/bin/dig). If "
+ "dig isn't installed, you can install it with 'sudo yum install "
+ "bind-utils' or 'sudo apt install dnsutils'",
+ )
else:
raise e
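Independent of the formatting change, the `Popen(...).communicate()[0]` idiom above predates `subprocess.run` (Python 3.5+). A sketch of the same capture-stdout call in the newer style, assuming `dig` is on the PATH (the hostname and query option are illustrative):

```python
import subprocess

cmd_args = ["dig", "+short", "example.com"]
# encoding="utf-8" makes proc.stdout a str, replacing the manual decode the
# action performs for Python 3.
proc = subprocess.run(
    cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8"
)
result_list = list(filter(None, proc.stdout.split("\n")))
print(result_list)
```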
diff --git a/contrib/linux/actions/service.py b/contrib/linux/actions/service.py
index 3961438431..335e5038f6 100644
--- a/contrib/linux/actions/service.py
+++ b/contrib/linux/actions/service.py
@@ -26,20 +26,23 @@
distro = platform.linux_distribution()[0]
if len(sys.argv) < 3:
- raise ValueError('Usage: service.py <action> <service>')
+raise ValueError("Usage: service.py <action> <service>")
-args = {'act': quote_unix(sys.argv[1]), 'service': quote_unix(sys.argv[2])}
+args = {"act": quote_unix(sys.argv[1]), "service": quote_unix(sys.argv[2])}
-if re.search(distro, 'Ubuntu'):
- if os.path.isfile("/etc/init/%s.conf" % args['service']):
- cmd_args = ['service', args['service'], args['act']]
- elif os.path.isfile("/etc/init.d/%s" % args['service']):
- cmd_args = ['/etc/init.d/%s' % (args['service']), args['act']]
+if re.search(distro, "Ubuntu"):
+ if os.path.isfile("/etc/init/%s.conf" % args["service"]):
+ cmd_args = ["service", args["service"], args["act"]]
+ elif os.path.isfile("/etc/init.d/%s" % args["service"]):
+ cmd_args = ["/etc/init.d/%s" % (args["service"]), args["act"]]
else:
print("Unknown service")
sys.exit(2)
-elif re.search(distro, 'Redhat') or re.search(distro, 'Fedora') or \
- re.search(distro, 'CentOS Linux'):
- cmd_args = ['systemctl', args['act'], args['service']]
+elif (
+ re.search(distro, "Redhat")
+ or re.search(distro, "Fedora")
+ or re.search(distro, "CentOS Linux")
+):
+ cmd_args = ["systemctl", args["act"], args["service"]]
subprocess.call(cmd_args, shell=False)
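One thing the reformat faithfully preserves: `re.search(pattern, string)` takes the pattern first, so `re.search(distro, "Ubuntu")` uses the detected distro name as the regex rather than the other way around. A sketch of the apparent intent with plain substring checks; note that `platform.linux_distribution()` was removed in Python 3.8, so the fallback below is an assumption for newer interpreters, not part of the original script:

```python
import platform

try:
    distro = platform.linux_distribution()[0]  # removed in Python 3.8
except AttributeError:
    distro = "Ubuntu"  # hypothetical stand-in so the sketch still runs

# Substring membership makes the direction of the match explicit.
if "Ubuntu" in distro:
    print("service/init.d path")
elif any(name in distro for name in ("Redhat", "Fedora", "CentOS Linux")):
    print("systemctl path")
```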
diff --git a/contrib/linux/actions/wait_for_ssh.py b/contrib/linux/actions/wait_for_ssh.py
index 4ad4a66050..c29e91ba03 100644
--- a/contrib/linux/actions/wait_for_ssh.py
+++ b/contrib/linux/actions/wait_for_ssh.py
@@ -25,29 +25,47 @@
class BaseAction(Action):
- def run(self, hostname, port, username, password=None, keyfile=None, ssh_timeout=5,
- sleep_delay=20, retries=10):
+ def run(
+ self,
+ hostname,
+ port,
+ username,
+ password=None,
+ keyfile=None,
+ ssh_timeout=5,
+ sleep_delay=20,
+ retries=10,
+ ):
# Note: If neither password nor key file is provided, we try to use system user
# key file
if not password and not keyfile:
keyfile = cfg.CONF.system_user.ssh_key_file
- self.logger.info('Neither "password" nor "keyfile" parameter provided, '
- 'defaulting to using "%s" key file' % (keyfile))
+ self.logger.info(
+ 'Neither "password" nor "keyfile" parameter provided, '
+ 'defaulting to using "%s" key file' % (keyfile)
+ )
- client = ParamikoSSHClient(hostname=hostname, port=port, username=username,
- password=password, key_files=keyfile,
- timeout=ssh_timeout)
+ client = ParamikoSSHClient(
+ hostname=hostname,
+ port=port,
+ username=username,
+ password=password,
+ key_files=keyfile,
+ timeout=ssh_timeout,
+ )
for index in range(retries):
attempt = index + 1
try:
- self.logger.debug('SSH connection attempt: %s' % (attempt))
+ self.logger.debug("SSH connection attempt: %s" % (attempt))
client.connect()
return True
except Exception as e:
- self.logger.info('Attempt %s failed (%s), sleeping for %s seconds...' %
- (attempt, six.text_type(e), sleep_delay))
+ self.logger.info(
+ "Attempt %s failed (%s), sleeping for %s seconds..."
+ % (attempt, six.text_type(e), sleep_delay)
+ )
time.sleep(sleep_delay)
- raise Exception('Exceeded max retries (%s)' % (retries))
+ raise Exception("Exceeded max retries (%s)" % (retries))
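The loop above is a common retry shape: bounded attempts, a fixed sleep between failures, and a terminal exception. A generic, self-contained sketch of the same shape (the names here are chosen for the sketch, not taken from the action):

```python
import time


def wait_for(check, retries=10, sleep_delay=20):
    """Call check() until it succeeds or retries are exhausted."""
    for index in range(retries):
        attempt = index + 1
        try:
            check()
            return True
        except Exception as e:
            print(
                "Attempt %s failed (%s), sleeping for %s seconds..."
                % (attempt, e, sleep_delay)
            )
            time.sleep(sleep_delay)
    raise Exception("Exceeded max retries (%s)" % (retries,))


# Example: a check that succeeds on the third attempt.
state = {"calls": 0}


def flaky():
    state["calls"] += 1
    if state["calls"] < 3:
        raise RuntimeError("not yet")


wait_for(flaky, retries=5, sleep_delay=0)
```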
diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md
index 7924e91e17..084fcad6a6 100644
--- a/contrib/linux/sensors/README.md
+++ b/contrib/linux/sensors/README.md
@@ -1,6 +1,6 @@
## NOTICE
-File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated.
+The file watch sensor has been updated to use a trigger with parameters supplied via a rule. Tailing a file path supplied via a config file is now deprecated.
An example rule to supply a file path is as follows:
@@ -25,5 +25,5 @@ action:
```
-Trigger ``linux.file_watch.line`` still emits the same payload as it used to.
+Trigger ``linux.file_watch.line`` still emits the same payload as it used to.
Just the way to provide the file_path to tail has changed.
diff --git a/contrib/linux/sensors/file_watch_sensor.py b/contrib/linux/sensors/file_watch_sensor.py
index 2597d63926..52e2943116 100644
--- a/contrib/linux/sensors/file_watch_sensor.py
+++ b/contrib/linux/sensors/file_watch_sensor.py
@@ -24,8 +24,9 @@
class FileWatchSensor(Sensor):
def __init__(self, sensor_service, config=None):
- super(FileWatchSensor, self).__init__(sensor_service=sensor_service,
- config=config)
+ super(FileWatchSensor, self).__init__(
+ sensor_service=sensor_service, config=config
+ )
self._trigger = None
self._logger = self._sensor_service.get_logger(__name__)
self._tail = None
@@ -48,16 +49,16 @@ def cleanup(self):
pass
def add_trigger(self, trigger):
- file_path = trigger['parameters'].get('file_path', None)
+ file_path = trigger["parameters"].get("file_path", None)
if not file_path:
self._logger.error('Received trigger type without "file_path" field.')
return
- self._trigger = trigger.get('ref', None)
+ self._trigger = trigger.get("ref", None)
if not self._trigger:
- raise Exception('Trigger %s did not contain a ref.' % trigger)
+ raise Exception("Trigger %s did not contain a ref." % trigger)
# Wait a bit to avoid initialization race in logshipper library
eventlet.sleep(1.0)
@@ -69,7 +70,7 @@ def update_trigger(self, trigger):
pass
def remove_trigger(self, trigger):
- file_path = trigger['parameters'].get('file_path', None)
+ file_path = trigger["parameters"].get("file_path", None)
if not file_path:
self._logger.error('Received trigger type without "file_path" field.')
@@ -83,10 +84,11 @@ def remove_trigger(self, trigger):
def _handle_line(self, file_path, line):
trigger = self._trigger
payload = {
- 'file_path': file_path,
- 'file_name': os.path.basename(file_path),
- 'line': line
+ "file_path": file_path,
+ "file_name": os.path.basename(file_path),
+ "line": line,
}
- self._logger.debug('Sending payload %s for trigger %s to sensor_service.',
- payload, trigger)
+ self._logger.debug(
+ "Sending payload %s for trigger %s to sensor_service.", payload, trigger
+ )
self.sensor_service.dispatch(trigger=trigger, payload=payload)
diff --git a/contrib/linux/tests/test_action_dig.py b/contrib/linux/tests/test_action_dig.py
index 4f363521d9..008cf16e76 100644
--- a/contrib/linux/tests/test_action_dig.py
+++ b/contrib/linux/tests/test_action_dig.py
@@ -27,15 +27,18 @@ def test_run_with_empty_hostname(self):
action = self.get_action_instance()
# Use the defaults from dig.yaml
- result = action.run(rand=False, count=0, nameserver=None, hostname='', queryopts='short')
+ result = action.run(
+ rand=False, count=0, nameserver=None, hostname="", queryopts="short"
+ )
self.assertIsInstance(result, list)
self.assertEqual(len(result), 0)
def test_run_with_empty_queryopts(self):
action = self.get_action_instance()
- results = action.run(rand=False, count=0, nameserver=None, hostname='google.com',
- queryopts='')
+ results = action.run(
+ rand=False, count=0, nameserver=None, hostname="google.com", queryopts=""
+ )
self.assertIsInstance(results, list)
for result in results:
@@ -45,8 +48,13 @@ def test_run_with_empty_queryopts(self):
def test_run(self):
action = self.get_action_instance()
- results = action.run(rand=False, count=0, nameserver=None, hostname='google.com',
- queryopts='short')
+ results = action.run(
+ rand=False,
+ count=0,
+ nameserver=None,
+ hostname="google.com",
+ queryopts="short",
+ )
self.assertIsInstance(results, list)
self.assertGreater(len(results), 0)
diff --git a/contrib/packs/actions/get_config.py b/contrib/packs/actions/get_config.py
index 505ef683c4..07e4654cef 100755
--- a/contrib/packs/actions/get_config.py
+++ b/contrib/packs/actions/get_config.py
@@ -22,8 +22,8 @@
class RenderTemplateAction(Action):
def run(self):
result = {
- 'pack_group': utils.get_pack_group(),
- 'pack_path': utils.get_system_packs_base_path()
+ "pack_group": utils.get_pack_group(),
+ "pack_path": utils.get_system_packs_base_path(),
}
return result
diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml
index 1b8d0d572a..191accd1c3 100644
--- a/contrib/packs/actions/install.meta.yaml
+++ b/contrib/packs/actions/install.meta.yaml
@@ -35,6 +35,6 @@
timeout:
default: 600
required: false
- description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout
+    description: Action timeout in seconds. The action will be killed if it doesn't finish within the timeout
type: integer
diff --git a/contrib/packs/actions/pack_mgmt/delete.py b/contrib/packs/actions/pack_mgmt/delete.py
index 93bcc46044..ca0436e834 100644
--- a/contrib/packs/actions/pack_mgmt/delete.py
+++ b/contrib/packs/actions/pack_mgmt/delete.py
@@ -27,15 +27,18 @@
class UninstallPackAction(Action):
def __init__(self, config=None, action_service=None):
- super(UninstallPackAction, self).__init__(config=config, action_service=action_service)
- self._base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path,
- 'virtualenvs/')
+ super(UninstallPackAction, self).__init__(
+ config=config, action_service=action_service
+ )
+ self._base_virtualenvs_path = os.path.join(
+ cfg.CONF.system.base_path, "virtualenvs/"
+ )
def run(self, packs, abs_repo_base, delete_env=True):
intersection = BLOCKED_PACKS & frozenset(packs)
if len(intersection) > 0:
- names = ', '.join(list(intersection))
- raise ValueError('Uninstall includes an uninstallable pack - %s.' % (names))
+ names = ", ".join(list(intersection))
+ raise ValueError("Uninstall includes an uninstallable pack - %s." % (names))
# 1. Delete pack content
for fp in os.listdir(abs_repo_base):
@@ -51,6 +54,8 @@ def run(self, packs, abs_repo_base, delete_env=True):
virtualenv_path = os.path.join(self._base_virtualenvs_path, pack_name)
if os.path.isdir(virtualenv_path):
- self.logger.debug('Deleting virtualenv "%s" for pack "%s"' %
- (virtualenv_path, pack_name))
+ self.logger.debug(
+ 'Deleting virtualenv "%s" for pack "%s"'
+ % (virtualenv_path, pack_name)
+ )
shutil.rmtree(virtualenv_path)
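The guard above rejects protected packs in one pass via frozenset intersection. A minimal sketch of the check (the blocklist values here are illustrative; the action imports the real `BLOCKED_PACKS` from st2common):

```python
BLOCKED_PACKS = frozenset(["core", "linux"])  # illustrative values

packs = ["st2", "linux"]
intersection = BLOCKED_PACKS & frozenset(packs)
if intersection:
    print("refusing to uninstall: %s" % ", ".join(sorted(intersection)))
```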
diff --git a/contrib/packs/actions/pack_mgmt/download.py b/contrib/packs/actions/pack_mgmt/download.py
index b4d888630b..cc0f7cd8fb 100644
--- a/contrib/packs/actions/pack_mgmt/download.py
+++ b/contrib/packs/actions/pack_mgmt/download.py
@@ -21,68 +21,85 @@
from st2common.runners.base_action import Action
from st2common.util.pack_management import download_pack
-__all__ = [
- 'DownloadGitRepoAction'
-]
+__all__ = ["DownloadGitRepoAction"]
class DownloadGitRepoAction(Action):
def __init__(self, config=None, action_service=None):
- super(DownloadGitRepoAction, self).__init__(config=config, action_service=action_service)
+ super(DownloadGitRepoAction, self).__init__(
+ config=config, action_service=action_service
+ )
- self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None))
- self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None))
+ self.https_proxy = os.environ.get(
+ "https_proxy", self.config.get("https_proxy", None)
+ )
+ self.http_proxy = os.environ.get(
+ "http_proxy", self.config.get("http_proxy", None)
+ )
self.proxy_ca_bundle_path = os.environ.get(
- 'proxy_ca_bundle_path',
- self.config.get('proxy_ca_bundle_path', None)
+ "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None)
)
- self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None))
+ self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None))
self.proxy_config = None
if self.http_proxy or self.https_proxy:
- self.logger.debug('Using proxy %s',
- self.http_proxy if self.http_proxy else self.https_proxy)
+ self.logger.debug(
+ "Using proxy %s",
+ self.http_proxy if self.http_proxy else self.https_proxy,
+ )
self.proxy_config = {
- 'https_proxy': self.https_proxy,
- 'http_proxy': self.http_proxy,
- 'proxy_ca_bundle_path': self.proxy_ca_bundle_path,
- 'no_proxy': self.no_proxy
+ "https_proxy": self.https_proxy,
+ "http_proxy": self.http_proxy,
+ "proxy_ca_bundle_path": self.proxy_ca_bundle_path,
+ "no_proxy": self.no_proxy,
}
# This is needed for git binary to work with a proxy
- if self.https_proxy and not os.environ.get('https_proxy', None):
- os.environ['https_proxy'] = self.https_proxy
+ if self.https_proxy and not os.environ.get("https_proxy", None):
+ os.environ["https_proxy"] = self.https_proxy
- if self.http_proxy and not os.environ.get('http_proxy', None):
- os.environ['http_proxy'] = self.http_proxy
+ if self.http_proxy and not os.environ.get("http_proxy", None):
+ os.environ["http_proxy"] = self.http_proxy
- if self.no_proxy and not os.environ.get('no_proxy', None):
- os.environ['no_proxy'] = self.no_proxy
+ if self.no_proxy and not os.environ.get("no_proxy", None):
+ os.environ["no_proxy"] = self.no_proxy
- if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None):
- os.environ['no_proxy'] = self.no_proxy
+ if self.proxy_ca_bundle_path and not os.environ.get(
+ "proxy_ca_bundle_path", None
+ ):
+ os.environ["no_proxy"] = self.no_proxy
- def run(self, packs, abs_repo_base, verifyssl=True, force=False,
- dependency_list=None):
+ def run(
+ self, packs, abs_repo_base, verifyssl=True, force=False, dependency_list=None
+ ):
result = {}
pack_url = None
if dependency_list:
for pack_dependency in dependency_list:
- pack_result = download_pack(pack=pack_dependency, abs_repo_base=abs_repo_base,
- verify_ssl=verifyssl, force=force,
- proxy_config=self.proxy_config, force_permissions=True,
- logger=self.logger)
+ pack_result = download_pack(
+ pack=pack_dependency,
+ abs_repo_base=abs_repo_base,
+ verify_ssl=verifyssl,
+ force=force,
+ proxy_config=self.proxy_config,
+ force_permissions=True,
+ logger=self.logger,
+ )
pack_url, pack_ref, pack_result = pack_result
result[pack_ref] = pack_result
else:
for pack in packs:
- pack_result = download_pack(pack=pack, abs_repo_base=abs_repo_base,
- verify_ssl=verifyssl, force=force,
- proxy_config=self.proxy_config,
- force_permissions=True,
- logger=self.logger)
+ pack_result = download_pack(
+ pack=pack,
+ abs_repo_base=abs_repo_base,
+ verify_ssl=verifyssl,
+ force=force,
+ proxy_config=self.proxy_config,
+ force_permissions=True,
+ logger=self.logger,
+ )
pack_url, pack_ref, pack_result = pack_result
result[pack_ref] = pack_result
@@ -99,14 +116,16 @@ def _validate_result(result, repo_url):
if not atleast_one_success:
message_list = []
- message_list.append('The pack has not been downloaded from "%s".\n' % (repo_url))
- message_list.append('Errors:')
+ message_list.append(
+ 'The pack has not been downloaded from "%s".\n' % (repo_url)
+ )
+ message_list.append("Errors:")
for pack, value in result.items():
success, error = value
message_list.append(error)
- message = '\n'.join(message_list)
+ message = "\n".join(message_list)
raise Exception(message)
return sanitized_result
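All of the proxy plumbing above reduces to one pattern: prefer the process environment, fall back to pack config, and export the value so child processes (the git commands behind `clone_from`) inherit it if it is not already set (the last branch in the hunk assigns `no_proxy` where `proxy_ca_bundle_path` appears intended; it is preserved verbatim from the source). A condensed sketch of the pattern with a plain dict standing in for `self.config` (the proxy URL is illustrative):

```python
import os

config = {"https_proxy": "http://proxy.example:3128"}  # stand-in for self.config

# Environment wins; pack config is the fallback.
https_proxy = os.environ.get("https_proxy", config.get("https_proxy", None))

if https_proxy and not os.environ.get("https_proxy", None):
    # git honors https_proxy from the environment, so export it for the
    # subprocesses the pack download spawns.
    os.environ["https_proxy"] = https_proxy
```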
diff --git a/contrib/packs/actions/pack_mgmt/get_installed.py b/contrib/packs/actions/pack_mgmt/get_installed.py
index eaa88b6319..36f2504b85 100644
--- a/contrib/packs/actions/pack_mgmt/get_installed.py
+++ b/contrib/packs/actions/pack_mgmt/get_installed.py
@@ -28,6 +28,7 @@
class GetInstalled(Action):
""""Get information about installed pack."""
+
def run(self, pack):
"""
:param pack: Installed Pack Name to get info about
@@ -47,46 +48,42 @@ def run(self, pack):
# Pack doesn't exist, finish execution normally with empty metadata
if not os.path.isdir(pack_path):
- return {
- 'pack': None,
- 'git_status': None
- }
+ return {"pack": None, "git_status": None}
if not metadata_file:
- error = ('Pack "%s" doesn\'t contain pack.yaml file.' % (pack))
+ error = 'Pack "%s" doesn\'t contain pack.yaml file.' % (pack)
raise Exception(error)
try:
details = self._parse_yaml_file(metadata_file)
except Exception as e:
- error = ('Pack "%s" doesn\'t contain a valid pack.yaml file: %s' % (pack,
- six.text_type(e)))
+ error = 'Pack "%s" doesn\'t contain a valid pack.yaml file: %s' % (
+ pack,
+ six.text_type(e),
+ )
raise Exception(error)
try:
repo = Repo(pack_path)
git_status = "Status:\n%s\n\nRemotes:\n%s" % (
- repo.git.status().split('\n')[0],
- "\n".join([remote.url for remote in repo.remotes])
+ repo.git.status().split("\n")[0],
+ "\n".join([remote.url for remote in repo.remotes]),
)
ahead_behind = repo.git.rev_list(
- '--left-right', '--count', 'HEAD...origin/master'
+ "--left-right", "--count", "HEAD...origin/master"
).split()
# Dear god.
- if ahead_behind != [u'0', u'0']:
+ if ahead_behind != ["0", "0"]:
git_status += "\n\n"
- git_status += "%s commits ahead " if ahead_behind[0] != u'0' else ""
- git_status += "and " if u'0' not in ahead_behind else ""
- git_status += "%s commits behind " if ahead_behind[1] != u'0' else ""
+ git_status += "%s commits ahead " if ahead_behind[0] != "0" else ""
+ git_status += "and " if "0" not in ahead_behind else ""
+ git_status += "%s commits behind " if ahead_behind[1] != "0" else ""
git_status += "origin/master."
except InvalidGitRepositoryError:
git_status = None
- return {
- 'pack': details,
- 'git_status': git_status
- }
+ return {"pack": details, "git_status": git_status}
def _parse_yaml_file(self, file_path):
with open(file_path) as data_file:
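A hedged sketch of the ahead/behind probe above, assuming GitPython is installed and the working directory is a clone with an `origin/master` remote-tracking ref (`repo.git.rev_list` proxies the `git rev-list` CLI):

```python
from git import Repo

repo = Repo(".")
# --left-right --count prints "<only-in-HEAD>\t<only-in-origin/master>",
# i.e. commits ahead and commits behind respectively.
ahead, behind = repo.git.rev_list(
    "--left-right", "--count", "HEAD...origin/master"
).split()
print("ahead %s, behind %s relative to origin/master" % (ahead, behind))
```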
diff --git a/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py b/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py
index 60ab2c9503..b9168526a2 100644
--- a/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py
+++ b/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py
@@ -40,7 +40,7 @@ def run(self, packs_status, nested):
return result
for pack, status in six.iteritems(packs_status):
- if 'success' not in status.lower():
+ if "success" not in status.lower():
continue
dependency_packs = get_dependency_list(pack)
@@ -50,40 +50,51 @@ def run(self, packs_status, nested):
for dep_pack in dependency_packs:
name_or_url, pack_version = self.get_name_and_version(dep_pack)
- if len(name_or_url.split('/')) == 1:
+ if len(name_or_url.split("/")) == 1:
pack_name = name_or_url
else:
name_or_git = name_or_url.split("/")[-1]
- pack_name = name_or_git if '.git' not in name_or_git else \
- name_or_git.split('.')[0]
+ pack_name = (
+ name_or_git
+ if ".git" not in name_or_git
+ else name_or_git.split(".")[0]
+ )
# Check existing pack by pack name
existing_pack_version = get_pack_version(pack_name)
# Try one more time to get existing pack version by name if 'stackstorm-' is in
# pack name
- if not existing_pack_version and 'stackstorm-' in pack_name.lower():
- existing_pack_version = get_pack_version(pack_name.split('stackstorm-')[-1])
+ if not existing_pack_version and "stackstorm-" in pack_name.lower():
+ existing_pack_version = get_pack_version(
+ pack_name.split("stackstorm-")[-1]
+ )
if existing_pack_version:
- if existing_pack_version and not existing_pack_version.startswith('v'):
- existing_pack_version = 'v' + existing_pack_version
- if pack_version and not pack_version.startswith('v'):
- pack_version = 'v' + pack_version
- if pack_version and existing_pack_version != pack_version \
- and dep_pack not in conflict_list:
+ if existing_pack_version and not existing_pack_version.startswith(
+ "v"
+ ):
+ existing_pack_version = "v" + existing_pack_version
+ if pack_version and not pack_version.startswith("v"):
+ pack_version = "v" + pack_version
+ if (
+ pack_version
+ and existing_pack_version != pack_version
+ and dep_pack not in conflict_list
+ ):
conflict_list.append(dep_pack)
else:
- conflict = self.check_dependency_list_for_conflict(name_or_url, pack_version,
- dependency_list)
+ conflict = self.check_dependency_list_for_conflict(
+ name_or_url, pack_version, dependency_list
+ )
if conflict:
conflict_list.append(dep_pack)
elif dep_pack not in dependency_list:
dependency_list.append(dep_pack)
- result['dependency_list'] = dependency_list
- result['conflict_list'] = conflict_list
- result['nested'] = nested - 1
+ result["dependency_list"] = dependency_list
+ result["conflict_list"] = conflict_list
+ result["nested"] = nested - 1
return result
@@ -112,7 +123,7 @@ def get_pack_version(pack=None):
pack_path = get_pack_base_path(pack)
try:
pack_metadata = get_pack_metadata(pack_dir=pack_path)
- result = pack_metadata.get('version', None)
+ result = pack_metadata.get("version", None)
except Exception:
result = None
finally:
@@ -124,9 +135,9 @@ def get_dependency_list(pack=None):
try:
pack_metadata = get_pack_metadata(pack_dir=pack_path)
- result = pack_metadata.get('dependencies', None)
+ result = pack_metadata.get("dependencies", None)
except Exception:
- print('Could not open pack.yaml at location %s' % pack_path)
+ print("Could not open pack.yaml at location %s" % pack_path)
result = None
finally:
return result
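The version comparison above only behaves once both strings carry a leading `v`, which is what the two `startswith` branches normalize. A condensed sketch of that step:

```python
def normalize_version(version):
    # Prefix a "v" so "1.2.0" and "v1.2.0" compare equal.
    return version if version.startswith("v") else "v" + version


assert normalize_version("1.2.0") == normalize_version("v1.2.0") == "v1.2.0"
```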
diff --git a/contrib/packs/actions/pack_mgmt/get_pack_warnings.py b/contrib/packs/actions/pack_mgmt/get_pack_warnings.py
index 445a5df0c2..e8f42dcbb6 100755
--- a/contrib/packs/actions/pack_mgmt/get_pack_warnings.py
+++ b/contrib/packs/actions/pack_mgmt/get_pack_warnings.py
@@ -34,7 +34,7 @@ def run(self, packs_status):
return result
for pack, status in six.iteritems(packs_status):
- if 'success' not in status.lower():
+ if "success" not in status.lower():
continue
warning = get_warnings(pack)
@@ -42,7 +42,7 @@ def run(self, packs_status):
if warning:
warning_list.append(warning)
- result['warning_list'] = warning_list
+ result["warning_list"] = warning_list
return result
@@ -54,6 +54,6 @@ def get_warnings(pack=None):
pack_metadata = get_pack_metadata(pack_dir=pack_path)
result = get_pack_warnings(pack_metadata)
except Exception:
- print('Could not open pack.yaml at location %s' % pack_path)
+ print("Could not open pack.yaml at location %s" % pack_path)
finally:
return result
diff --git a/contrib/packs/actions/pack_mgmt/register.py b/contrib/packs/actions/pack_mgmt/register.py
index 220962f0f4..1587333d5b 100644
--- a/contrib/packs/actions/pack_mgmt/register.py
+++ b/contrib/packs/actions/pack_mgmt/register.py
@@ -19,21 +19,19 @@
from st2client.models.keyvalue import KeyValuePair # pylint: disable=no-name-in-module
from st2common.runners.base_action import Action
-__all__ = [
- 'St2RegisterAction'
-]
+__all__ = ["St2RegisterAction"]
COMPATIBILITY_TRANSFORMATIONS = {
- 'runners': 'runner',
- 'triggers': 'trigger',
- 'sensors': 'sensor',
- 'actions': 'action',
- 'rules': 'rule',
- 'rule_types': 'rule_type',
- 'aliases': 'alias',
- 'policiy_types': 'policy_type',
- 'policies': 'policy',
- 'configs': 'config',
+ "runners": "runner",
+ "triggers": "trigger",
+ "sensors": "sensor",
+ "actions": "action",
+ "rules": "rule",
+ "rule_types": "rule_type",
+ "aliases": "alias",
+ "policiy_types": "policy_type",
+ "policies": "policy",
+ "configs": "config",
}
@@ -63,23 +61,23 @@ def __init__(self, config):
def run(self, register, packs=None):
types = []
- for type in register.split(','):
+ for type in register.split(","):
if type in COMPATIBILITY_TRANSFORMATIONS:
types.append(COMPATIBILITY_TRANSFORMATIONS[type])
else:
types.append(type)
- method_kwargs = {
- 'types': types
- }
+ method_kwargs = {"types": types}
packs.reverse()
if packs:
- method_kwargs['packs'] = packs
+ method_kwargs["packs"] = packs
- result = self._run_client_method(method=self.client.packs.register,
- method_kwargs=method_kwargs,
- format_func=format_result)
+ result = self._run_client_method(
+ method=self.client.packs.register,
+ method_kwargs=method_kwargs,
+ format_func=format_result,
+ )
# TODO: make sure to return proper model
return result
@@ -90,42 +88,48 @@ def _get_client(self):
client_kwargs = {}
if cacert:
- client_kwargs['cacert'] = cacert
+ client_kwargs["cacert"] = cacert
- return self._client(base_url=base_url, api_url=api_url,
- auth_url=auth_url, token=token,
- **client_kwargs)
+ return self._client(
+ base_url=base_url,
+ api_url=api_url,
+ auth_url=auth_url,
+ token=token,
+ **client_kwargs,
+ )
def _get_st2_urls(self):
# First try to use base_url from config.
- base_url = self.config.get('base_url', None)
- api_url = self.config.get('api_url', None)
- auth_url = self.config.get('auth_url', None)
+ base_url = self.config.get("base_url", None)
+ api_url = self.config.get("api_url", None)
+ auth_url = self.config.get("auth_url", None)
        # If not found, look them up from env vars, assuming the pack is
        # configured to work with the current StackStorm instance.
if not base_url:
- api_url = os.environ.get('ST2_ACTION_API_URL', None)
- auth_url = os.environ.get('ST2_ACTION_AUTH_URL', None)
+ api_url = os.environ.get("ST2_ACTION_API_URL", None)
+ auth_url = os.environ.get("ST2_ACTION_AUTH_URL", None)
return base_url, api_url, auth_url
def _get_auth_token(self):
# First try to use auth_token from config.
- token = self.config.get('auth_token', None)
+ token = self.config.get("auth_token", None)
        # If not found, look it up from env vars, assuming the pack is
        # configured to work with the current StackStorm instance.
if not token:
- token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None)
+ token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None)
return token
def _get_cacert(self):
- cacert = self.config.get('cacert', None)
+ cacert = self.config.get("cacert", None)
return cacert
- def _run_client_method(self, method, method_kwargs, format_func, format_kwargs=None):
+ def _run_client_method(
+ self, method, method_kwargs, format_func, format_kwargs=None
+ ):
"""
Run the provided client method and format the result.
@@ -144,8 +148,9 @@ def _run_client_method(self, method, method_kwargs, format_func, format_kwargs=N
# This is a work around since the default values can only be strings
method_kwargs = filter_none_values(method_kwargs)
method_name = method.__name__
- self.logger.debug('Calling client method "%s" with kwargs "%s"' % (method_name,
- method_kwargs))
+ self.logger.debug(
+ 'Calling client method "%s" with kwargs "%s"' % (method_name, method_kwargs)
+ )
result = method(**method_kwargs)
result = format_func(result, **format_kwargs or {})
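The `run()` method above normalizes each comma-separated token through the plural-to-singular table before calling the API. A condensed sketch of that step, using `dict.get()` in place of the explicit membership test (the mapping here is an excerpt of `COMPATIBILITY_TRANSFORMATIONS`):

```python
TRANSFORMS = {"actions": "action", "rules": "rule", "sensors": "sensor"}

register = "actions,rules,sensor"
types = [TRANSFORMS.get(token, token) for token in register.split(",")]
assert types == ["action", "rule", "sensor"]
```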
diff --git a/contrib/packs/actions/pack_mgmt/search.py b/contrib/packs/actions/pack_mgmt/search.py
index b7cb07f7fc..dd732c1b29 100644
--- a/contrib/packs/actions/pack_mgmt/search.py
+++ b/contrib/packs/actions/pack_mgmt/search.py
@@ -22,43 +22,51 @@
class PackSearch(Action):
def __init__(self, config=None, action_service=None):
super(PackSearch, self).__init__(config=config, action_service=action_service)
- self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None))
- self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None))
+ self.https_proxy = os.environ.get(
+ "https_proxy", self.config.get("https_proxy", None)
+ )
+ self.http_proxy = os.environ.get(
+ "http_proxy", self.config.get("http_proxy", None)
+ )
self.proxy_ca_bundle_path = os.environ.get(
- 'proxy_ca_bundle_path',
- self.config.get('proxy_ca_bundle_path', None)
+ "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None)
)
- self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None))
+ self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None))
self.proxy_config = None
if self.http_proxy or self.https_proxy:
- self.logger.debug('Using proxy %s',
- self.http_proxy if self.http_proxy else self.https_proxy)
+ self.logger.debug(
+ "Using proxy %s",
+ self.http_proxy if self.http_proxy else self.https_proxy,
+ )
self.proxy_config = {
- 'https_proxy': self.https_proxy,
- 'http_proxy': self.http_proxy,
- 'proxy_ca_bundle_path': self.proxy_ca_bundle_path,
- 'no_proxy': self.no_proxy
+ "https_proxy": self.https_proxy,
+ "http_proxy": self.http_proxy,
+ "proxy_ca_bundle_path": self.proxy_ca_bundle_path,
+ "no_proxy": self.no_proxy,
}
- if self.https_proxy and not os.environ.get('https_proxy', None):
- os.environ['https_proxy'] = self.https_proxy
+ if self.https_proxy and not os.environ.get("https_proxy", None):
+ os.environ["https_proxy"] = self.https_proxy
- if self.http_proxy and not os.environ.get('http_proxy', None):
- os.environ['http_proxy'] = self.http_proxy
+ if self.http_proxy and not os.environ.get("http_proxy", None):
+ os.environ["http_proxy"] = self.http_proxy
- if self.no_proxy and not os.environ.get('no_proxy', None):
- os.environ['no_proxy'] = self.no_proxy
+ if self.no_proxy and not os.environ.get("no_proxy", None):
+ os.environ["no_proxy"] = self.no_proxy
- if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None):
- os.environ['no_proxy'] = self.no_proxy
+ if self.proxy_ca_bundle_path and not os.environ.get(
+ "proxy_ca_bundle_path", None
+ ):
+ os.environ["no_proxy"] = self.no_proxy
""""Search for packs in StackStorm Exchange and other directories."""
+
def run(self, query):
"""
:param query: A word or a phrase to search for
:type query: ``str``
"""
- self.logger.debug('Proxy config: %s', self.proxy_config)
+ self.logger.debug("Proxy config: %s", self.proxy_config)
return search_pack_index(query, proxy_config=self.proxy_config)
diff --git a/contrib/packs/actions/pack_mgmt/setup_virtualenv.py b/contrib/packs/actions/pack_mgmt/setup_virtualenv.py
index 23f8a75ef7..bf7a32ed7e 100644
--- a/contrib/packs/actions/pack_mgmt/setup_virtualenv.py
+++ b/contrib/packs/actions/pack_mgmt/setup_virtualenv.py
@@ -18,9 +18,7 @@
from st2common.runners.base_action import Action
from st2common.util.virtualenvs import setup_pack_virtualenv
-__all__ = [
- 'SetupVirtualEnvironmentAction'
-]
+__all__ = ["SetupVirtualEnvironmentAction"]
class SetupVirtualEnvironmentAction(Action):
@@ -37,42 +35,50 @@ class SetupVirtualEnvironmentAction(Action):
creation of the virtual environment and performs an update of the
current dependencies as well as an installation of new dependencies
"""
+
def __init__(self, config=None, action_service=None):
super(SetupVirtualEnvironmentAction, self).__init__(
- config=config,
- action_service=action_service)
+ config=config, action_service=action_service
+ )
- self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None))
- self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None))
+ self.https_proxy = os.environ.get(
+ "https_proxy", self.config.get("https_proxy", None)
+ )
+ self.http_proxy = os.environ.get(
+ "http_proxy", self.config.get("http_proxy", None)
+ )
self.proxy_ca_bundle_path = os.environ.get(
- 'proxy_ca_bundle_path',
- self.config.get('proxy_ca_bundle_path', None)
+ "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None)
)
- self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None))
+ self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None))
self.proxy_config = None
if self.http_proxy or self.https_proxy:
- self.logger.debug('Using proxy %s',
- self.http_proxy if self.http_proxy else self.https_proxy)
+ self.logger.debug(
+ "Using proxy %s",
+ self.http_proxy if self.http_proxy else self.https_proxy,
+ )
self.proxy_config = {
- 'https_proxy': self.https_proxy,
- 'http_proxy': self.http_proxy,
- 'proxy_ca_bundle_path': self.proxy_ca_bundle_path,
- 'no_proxy': self.no_proxy
+ "https_proxy": self.https_proxy,
+ "http_proxy": self.http_proxy,
+ "proxy_ca_bundle_path": self.proxy_ca_bundle_path,
+ "no_proxy": self.no_proxy,
}
- if self.https_proxy and not os.environ.get('https_proxy', None):
- os.environ['https_proxy'] = self.https_proxy
+ if self.https_proxy and not os.environ.get("https_proxy", None):
+ os.environ["https_proxy"] = self.https_proxy
- if self.http_proxy and not os.environ.get('http_proxy', None):
- os.environ['http_proxy'] = self.http_proxy
+ if self.http_proxy and not os.environ.get("http_proxy", None):
+ os.environ["http_proxy"] = self.http_proxy
- if self.no_proxy and not os.environ.get('no_proxy', None):
- os.environ['no_proxy'] = self.no_proxy
+ if self.no_proxy and not os.environ.get("no_proxy", None):
+ os.environ["no_proxy"] = self.no_proxy
- if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None):
- os.environ['no_proxy'] = self.no_proxy
+ if self.proxy_ca_bundle_path and not os.environ.get(
+ "proxy_ca_bundle_path", None
+ ):
+ os.environ["no_proxy"] = self.no_proxy
def run(self, packs, update=False, no_download=True):
"""
@@ -84,10 +90,15 @@ def run(self, packs, update=False, no_download=True):
"""
for pack_name in packs:
- setup_pack_virtualenv(pack_name=pack_name, update=update, logger=self.logger,
- proxy_config=self.proxy_config,
- no_download=no_download)
-
- message = ('Successfully set up virtualenv for the following packs: %s' %
- (', '.join(packs)))
+ setup_pack_virtualenv(
+ pack_name=pack_name,
+ update=update,
+ logger=self.logger,
+ proxy_config=self.proxy_config,
+ no_download=no_download,
+ )
+
+ message = "Successfully set up virtualenv for the following packs: %s" % (
+ ", ".join(packs)
+ )
return message
diff --git a/contrib/packs/actions/pack_mgmt/show_remote.py b/contrib/packs/actions/pack_mgmt/show_remote.py
index ba5bff8141..6b2f655594 100644
--- a/contrib/packs/actions/pack_mgmt/show_remote.py
+++ b/contrib/packs/actions/pack_mgmt/show_remote.py
@@ -19,11 +19,10 @@
class ShowRemote(Action):
"""Get detailed information about an available pack from the StackStorm Exchange index"""
+
def run(self, pack):
"""
:param pack: Pack Name to get info about
:type pack: ``str``
"""
- return {
- 'pack': get_pack_from_index(pack)
- }
+ return {"pack": get_pack_from_index(pack)}
diff --git a/contrib/packs/actions/pack_mgmt/unload.py b/contrib/packs/actions/pack_mgmt/unload.py
index c72cdf9ce1..46caf9cc7a 100644
--- a/contrib/packs/actions/pack_mgmt/unload.py
+++ b/contrib/packs/actions/pack_mgmt/unload.py
@@ -36,31 +36,48 @@
class UnregisterPackAction(BaseAction):
def __init__(self, config=None, action_service=None):
- super(UnregisterPackAction, self).__init__(config=config, action_service=action_service)
+ super(UnregisterPackAction, self).__init__(
+ config=config, action_service=action_service
+ )
self.initialize()
def initialize(self):
# 1. Setup db connection
- username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None
- password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None
- db_setup(cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port,
- username=username, password=password,
- ssl=cfg.CONF.database.ssl,
- ssl_keyfile=cfg.CONF.database.ssl_keyfile,
- ssl_certfile=cfg.CONF.database.ssl_certfile,
- ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs,
- ssl_ca_certs=cfg.CONF.database.ssl_ca_certs,
- authentication_mechanism=cfg.CONF.database.authentication_mechanism,
- ssl_match_hostname=cfg.CONF.database.ssl_match_hostname)
+ username = (
+ cfg.CONF.database.username
+ if hasattr(cfg.CONF.database, "username")
+ else None
+ )
+ password = (
+ cfg.CONF.database.password
+ if hasattr(cfg.CONF.database, "password")
+ else None
+ )
+ db_setup(
+ cfg.CONF.database.db_name,
+ cfg.CONF.database.host,
+ cfg.CONF.database.port,
+ username=username,
+ password=password,
+ ssl=cfg.CONF.database.ssl,
+ ssl_keyfile=cfg.CONF.database.ssl_keyfile,
+ ssl_certfile=cfg.CONF.database.ssl_certfile,
+ ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs,
+ ssl_ca_certs=cfg.CONF.database.ssl_ca_certs,
+ authentication_mechanism=cfg.CONF.database.authentication_mechanism,
+ ssl_match_hostname=cfg.CONF.database.ssl_match_hostname,
+ )
def run(self, packs):
intersection = BLOCKED_PACKS & frozenset(packs)
if len(intersection) > 0:
- names = ', '.join(list(intersection))
- raise ValueError('Unregister includes an unregisterable pack - %s.' % (names))
+ names = ", ".join(list(intersection))
+ raise ValueError(
+ "Unregister includes an unregisterable pack - %s." % (names)
+ )
for pack in packs:
- self.logger.debug('Removing pack %s.', pack)
+ self.logger.debug("Removing pack %s.", pack)
self._unregister_sensors(pack=pack)
self._unregister_trigger_types(pack=pack)
self._unregister_triggers(pack=pack)
@@ -69,21 +86,27 @@ def run(self, packs):
self._unregister_aliases(pack=pack)
self._unregister_policies(pack=pack)
self._unregister_pack(pack=pack)
- self.logger.info('Removed pack %s.', pack)
+ self.logger.info("Removed pack %s.", pack)
def _unregister_sensors(self, pack):
return self._delete_pack_db_objects(pack=pack, access_cls=SensorType)
def _unregister_trigger_types(self, pack):
- deleted_trigger_types_dbs = self._delete_pack_db_objects(pack=pack, access_cls=TriggerType)
+ deleted_trigger_types_dbs = self._delete_pack_db_objects(
+ pack=pack, access_cls=TriggerType
+ )
# 2. Check if deleted trigger is used by any other rules outside this pack
for trigger_type_db in deleted_trigger_types_dbs:
- rule_dbs = Rule.query(trigger=trigger_type_db.ref, pack__ne=trigger_type_db.pack)
+ rule_dbs = Rule.query(
+ trigger=trigger_type_db.ref, pack__ne=trigger_type_db.pack
+ )
for rule_db in rule_dbs:
- self.logger.warning('Rule "%s" references deleted trigger "%s"' %
- (rule_db.name, trigger_type_db.ref))
+ self.logger.warning(
+ 'Rule "%s" references deleted trigger "%s"'
+ % (rule_db.name, trigger_type_db.ref)
+ )
return deleted_trigger_types_dbs
@@ -136,25 +159,25 @@ def _delete_pack_db_object(self, pack):
pack_db = None
if not pack_db:
- self.logger.exception('Pack DB object not found')
+ self.logger.exception("Pack DB object not found")
return
try:
Pack.delete(pack_db)
except:
- self.logger.exception('Failed to remove DB object %s.', pack_db)
+ self.logger.exception("Failed to remove DB object %s.", pack_db)
def _delete_config_schema_db_object(self, pack):
try:
config_schema_db = ConfigSchema.get_by_pack(value=pack)
except StackStormDBObjectNotFoundError:
- self.logger.exception('ConfigSchemaDB object not found')
+ self.logger.exception("ConfigSchemaDB object not found")
return
try:
ConfigSchema.delete(config_schema_db)
except:
- self.logger.exception('Failed to remove DB object %s.', config_schema_db)
+ self.logger.exception("Failed to remove DB object %s.", config_schema_db)
def _delete_pack_db_objects(self, pack, access_cls):
db_objs = access_cls.get_all(pack=pack)
@@ -166,6 +189,6 @@ def _delete_pack_db_objects(self, pack, access_cls):
access_cls.delete(db_obj)
deleted_objs.append(db_obj)
except:
- self.logger.exception('Failed to remove DB object %s.', db_obj)
+ self.logger.exception("Failed to remove DB object %s.", db_obj)
return deleted_objs
diff --git a/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py b/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py
index aedc993f6b..abde082ed3 100644
--- a/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py
+++ b/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py
@@ -32,7 +32,7 @@ def run(self, packs_status, packs_list=None):
packs = []
for pack_name, status in six.iteritems(packs_status):
- if 'success' in status.lower():
+ if "success" in status.lower():
packs.append(pack_name)
packs_list.extend(packs)
diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml
index 18d1b3df15..47091705f3 100644
--- a/contrib/packs/actions/setup_virtualenv.yaml
+++ b/contrib/packs/actions/setup_virtualenv.yaml
@@ -27,5 +27,5 @@
timeout:
default: 600
required: false
- description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout
+    description: Action timeout in seconds. The action will be killed if it doesn't finish within the timeout
type: integer
diff --git a/contrib/packs/tests/test_action_aliases.py b/contrib/packs/tests/test_action_aliases.py
index 858a167751..ecfebe8b68 100644
--- a/contrib/packs/tests/test_action_aliases.py
+++ b/contrib/packs/tests/test_action_aliases.py
@@ -19,73 +19,65 @@ class PackGet(BaseActionAliasTestCase):
action_alias_name = "pack_get"
def test_alias_pack_get(self):
- format_string = self.action_alias_db.formats[0]['representation'][0]
+ format_string = self.action_alias_db.formats[0]["representation"][0]
format_strings = self.action_alias_db.get_format_strings()
command = "pack get st2"
- expected_parameters = {
- 'pack': "st2"
- }
+ expected_parameters = {"pack": "st2"}
- self.assertExtractedParametersMatch(format_string=format_string,
- command=command,
- parameters=expected_parameters)
+ self.assertExtractedParametersMatch(
+ format_string=format_string, command=command, parameters=expected_parameters
+ )
self.assertCommandMatchesExactlyOneFormatString(
- format_strings=format_strings,
- command=command)
+ format_strings=format_strings, command=command
+ )
class PackInstall(BaseActionAliasTestCase):
action_alias_name = "pack_install"
def test_alias_pack_install(self):
- format_string = self.action_alias_db.formats[0]['representation'][0]
+ format_string = self.action_alias_db.formats[0]["representation"][0]
command = "pack install st2"
- expected_parameters = {
- 'packs': "st2"
- }
+ expected_parameters = {"packs": "st2"}
- self.assertExtractedParametersMatch(format_string=format_string,
- command=command,
- parameters=expected_parameters)
+ self.assertExtractedParametersMatch(
+ format_string=format_string, command=command, parameters=expected_parameters
+ )
class PackSearch(BaseActionAliasTestCase):
action_alias_name = "pack_search"
def test_alias_pack_search(self):
- format_string = self.action_alias_db.formats[0]['representation'][0]
+ format_string = self.action_alias_db.formats[0]["representation"][0]
format_strings = self.action_alias_db.get_format_strings()
command = "pack search st2"
- expected_parameters = {
- 'query': "st2"
- }
+ expected_parameters = {"query": "st2"}
- self.assertExtractedParametersMatch(format_string=format_string,
- command=command,
- parameters=expected_parameters)
+ self.assertExtractedParametersMatch(
+ format_string=format_string, command=command, parameters=expected_parameters
+ )
self.assertCommandMatchesExactlyOneFormatString(
- format_strings=format_strings,
- command=command)
+ format_strings=format_strings, command=command
+ )
class PackShow(BaseActionAliasTestCase):
action_alias_name = "pack_show"
def test_alias_pack_show(self):
- format_string = self.action_alias_db.formats[0]['representation'][0]
+ format_string = self.action_alias_db.formats[0]["representation"][0]
format_strings = self.action_alias_db.get_format_strings()
command = "pack show st2"
- expected_parameters = {
- 'pack': "st2"
- }
+ expected_parameters = {"pack": "st2"}
- self.assertExtractedParametersMatch(format_string=format_string,
- command=command,
- parameters=expected_parameters)
+ self.assertExtractedParametersMatch(
+ format_string=format_string, command=command, parameters=expected_parameters
+ )
self.assertCommandMatchesExactlyOneFormatString(
- format_strings=format_strings,
- command=command)
+ format_strings=format_strings, command=command
+ )
diff --git a/contrib/packs/tests/test_action_download.py b/contrib/packs/tests/test_action_download.py
index 3eeda00886..c29e95fccc 100644
--- a/contrib/packs/tests/test_action_download.py
+++ b/contrib/packs/tests/test_action_download.py
@@ -22,6 +22,7 @@
import hashlib
from st2common.util.monkey_patch import use_select_poll_workaround
+
use_select_poll_workaround()
from lockfile import LockFile
@@ -46,7 +47,7 @@
"author": "st2-dev",
"keywords": ["some", "search", "another", "terms"],
"email": "info@stackstorm.com",
- "description": "st2 pack to test package management pipeline"
+ "description": "st2 pack to test package management pipeline",
},
"test2": {
"version": "0.5.0",
@@ -55,7 +56,7 @@
"author": "stanley",
"keywords": ["some", "special", "terms"],
"email": "info@stackstorm.com",
- "description": "another st2 pack to test package management pipeline"
+ "description": "another st2 pack to test package management pipeline",
},
"test3": {
"version": "0.5.0",
@@ -65,16 +66,17 @@
"author": "stanley",
"keywords": ["some", "special", "terms"],
"email": "info@stackstorm.com",
- "description": "another st2 pack to test package management pipeline"
+ "description": "another st2 pack to test package management pipeline",
},
"test4": {
"version": "0.5.0",
"name": "test4",
"repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4",
"author": "stanley",
- "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com",
- "description": "another st2 pack to test package management pipeline"
- }
+ "keywords": ["some", "special", "terms"],
+ "email": "info@stackstorm.com",
+ "description": "another st2 pack to test package management pipeline",
+ },
}
@@ -85,7 +87,7 @@ def mock_is_dir_func(path):
"""
Mock function which returns True if path ends with .git
"""
- if path.endswith('.git'):
+ if path.endswith(".git"):
return True
return original_is_dir_func(path)
@@ -95,9 +97,9 @@ def mock_get_gitref(repo, ref):
    Mock get_gitref function which returns a mocked object if the ref passed is
PACK_INDEX['test']['version']
"""
- if PACK_INDEX['test']['version'] in ref:
- if ref[0] == 'v':
- return mock.MagicMock(hexsha=PACK_INDEX['test']['version'])
+ if PACK_INDEX["test"]["version"] in ref:
+ if ref[0] == "v":
+ return mock.MagicMock(hexsha=PACK_INDEX["test"]["version"])
else:
return None
elif ref:
@@ -106,21 +108,24 @@ def mock_get_gitref(repo, ref):
return None
-@mock.patch.object(pack_service, 'fetch_pack_index', mock.MagicMock(return_value=(PACK_INDEX, {})))
+@mock.patch.object(
+ pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {}))
+)
class DownloadGitRepoActionTestCase(BaseActionTestCase):
action_cls = DownloadGitRepoAction
def setUp(self):
super(DownloadGitRepoActionTestCase, self).setUp()
- clone_from = mock.patch.object(Repo, 'clone_from')
+ clone_from = mock.patch.object(Repo, "clone_from")
self.addCleanup(clone_from.stop)
self.clone_from = clone_from.start()
self.expand_user_path = tempfile.mkdtemp()
- expand_user = mock.patch.object(os.path, 'expanduser',
- mock.MagicMock(return_value=self.expand_user_path))
+ expand_user = mock.patch.object(
+ os.path, "expanduser", mock.MagicMock(return_value=self.expand_user_path)
+ )
self.addCleanup(expand_user.stop)
self.expand_user = expand_user.start()
@@ -132,8 +137,10 @@ def setUp(self):
def side_effect(url, to_path, **kwargs):
            # Since we have no way to pass the pack name here, we have to derive it from the repo URL
- fixture_name = url.split('/')[-1]
- fixture_path = os.path.join(self._get_base_pack_path(), 'tests/fixtures', fixture_name)
+ fixture_name = url.split("/")[-1]
+ fixture_path = os.path.join(
+ self._get_base_pack_path(), "tests/fixtures", fixture_name
+ )
shutil.copytree(fixture_path, to_path)
return self.repo_instance
@@ -145,13 +152,15 @@ def tearDown(self):
def test_run_pack_download(self):
action = self.get_action_instance()
- result = action.run(packs=['test'], abs_repo_base=self.repo_base)
- temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest()
+ result = action.run(packs=["test"], abs_repo_base=self.repo_base)
+ temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest()
- self.assertEqual(result, {'test': 'Success.'})
- self.clone_from.assert_called_once_with(PACK_INDEX['test']['repo_url'],
- os.path.join(os.path.expanduser('~'), temp_dir))
- self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml')))
+ self.assertEqual(result, {"test": "Success."})
+ self.clone_from.assert_called_once_with(
+ PACK_INDEX["test"]["repo_url"],
+ os.path.join(os.path.expanduser("~"), temp_dir),
+ )
+ self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml")))
self.repo_instance.git.checkout.assert_called()
self.repo_instance.git.branch.assert_called()
@@ -159,65 +168,81 @@ def test_run_pack_download(self):
def test_run_pack_download_dependencies(self):
action = self.get_action_instance()
- result = action.run(packs=['test'], dependency_list=['test2', 'test4'],
- abs_repo_base=self.repo_base)
+ result = action.run(
+ packs=["test"],
+ dependency_list=["test2", "test4"],
+ abs_repo_base=self.repo_base,
+ )
temp_dirs = [
- hashlib.md5(PACK_INDEX['test2']['repo_url'].encode()).hexdigest(),
- hashlib.md5(PACK_INDEX['test4']['repo_url'].encode()).hexdigest()
+ hashlib.md5(PACK_INDEX["test2"]["repo_url"].encode()).hexdigest(),
+ hashlib.md5(PACK_INDEX["test4"]["repo_url"].encode()).hexdigest(),
]
- self.assertEqual(result, {'test2': 'Success.', 'test4': 'Success.'})
- self.clone_from.assert_any_call(PACK_INDEX['test2']['repo_url'],
- os.path.join(os.path.expanduser('~'), temp_dirs[0]))
- self.clone_from.assert_any_call(PACK_INDEX['test4']['repo_url'],
- os.path.join(os.path.expanduser('~'), temp_dirs[1]))
+ self.assertEqual(result, {"test2": "Success.", "test4": "Success."})
+ self.clone_from.assert_any_call(
+ PACK_INDEX["test2"]["repo_url"],
+ os.path.join(os.path.expanduser("~"), temp_dirs[0]),
+ )
+ self.clone_from.assert_any_call(
+ PACK_INDEX["test4"]["repo_url"],
+ os.path.join(os.path.expanduser("~"), temp_dirs[1]),
+ )
self.assertEqual(self.clone_from.call_count, 2)
- self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test2/pack.yaml')))
- self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test4/pack.yaml')))
+ self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test2/pack.yaml")))
+ self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test4/pack.yaml")))
def test_run_pack_download_existing_pack(self):
action = self.get_action_instance()
- action.run(packs=['test'], abs_repo_base=self.repo_base)
- self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml')))
+ action.run(packs=["test"], abs_repo_base=self.repo_base)
+ self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml")))
- result = action.run(packs=['test'], abs_repo_base=self.repo_base)
+ result = action.run(packs=["test"], abs_repo_base=self.repo_base)
- self.assertEqual(result, {'test': 'Success.'})
+ self.assertEqual(result, {"test": "Success."})
def test_run_pack_download_multiple_packs(self):
action = self.get_action_instance()
- result = action.run(packs=['test', 'test2'], abs_repo_base=self.repo_base)
+ result = action.run(packs=["test", "test2"], abs_repo_base=self.repo_base)
temp_dirs = [
- hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest(),
- hashlib.md5(PACK_INDEX['test2']['repo_url'].encode()).hexdigest()
+ hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest(),
+ hashlib.md5(PACK_INDEX["test2"]["repo_url"].encode()).hexdigest(),
]
- self.assertEqual(result, {'test': 'Success.', 'test2': 'Success.'})
- self.clone_from.assert_any_call(PACK_INDEX['test']['repo_url'],
- os.path.join(os.path.expanduser('~'), temp_dirs[0]))
- self.clone_from.assert_any_call(PACK_INDEX['test2']['repo_url'],
- os.path.join(os.path.expanduser('~'), temp_dirs[1]))
+ self.assertEqual(result, {"test": "Success.", "test2": "Success."})
+ self.clone_from.assert_any_call(
+ PACK_INDEX["test"]["repo_url"],
+ os.path.join(os.path.expanduser("~"), temp_dirs[0]),
+ )
+ self.clone_from.assert_any_call(
+ PACK_INDEX["test2"]["repo_url"],
+ os.path.join(os.path.expanduser("~"), temp_dirs[1]),
+ )
self.assertEqual(self.clone_from.call_count, 2)
- self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml')))
- self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test2/pack.yaml')))
+ self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml")))
+ self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test2/pack.yaml")))
- @mock.patch.object(Repo, 'clone_from')
+ @mock.patch.object(Repo, "clone_from")
def test_run_pack_download_error(self, clone_from):
- clone_from.side_effect = Exception('Something went terribly wrong during the clone')
+ clone_from.side_effect = Exception(
+ "Something went terribly wrong during the clone"
+ )
action = self.get_action_instance()
- self.assertRaises(Exception, action.run, packs=['test'], abs_repo_base=self.repo_base)
+ self.assertRaises(
+ Exception, action.run, packs=["test"], abs_repo_base=self.repo_base
+ )
def test_run_pack_download_no_tag(self):
self.repo_instance.commit.side_effect = BadName
action = self.get_action_instance()
- self.assertRaises(ValueError, action.run, packs=['test=1.2.3'],
- abs_repo_base=self.repo_base)
+ self.assertRaises(
+ ValueError, action.run, packs=["test=1.2.3"], abs_repo_base=self.repo_base
+ )
def test_run_pack_lock_is_already_acquired(self):
action = self.get_action_instance()
- temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest()
+ temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest()
original_acquire = LockFile.acquire
@@ -227,15 +252,20 @@ def mock_acquire(self, timeout=None):
LockFile.acquire = mock_acquire
try:
- lock_file = LockFile('/tmp/%s' % (temp_dir))
+ lock_file = LockFile("/tmp/%s" % (temp_dir))
# Acquire a lock (file) so acquire inside download will fail
- with open(lock_file.lock_file, 'w') as fp:
- fp.write('')
-
- expected_msg = 'Timeout waiting to acquire lock for'
- self.assertRaisesRegexp(LockTimeout, expected_msg, action.run, packs=['test'],
- abs_repo_base=self.repo_base)
+ with open(lock_file.lock_file, "w") as fp:
+ fp.write("")
+
+ expected_msg = "Timeout waiting to acquire lock for"
+ self.assertRaisesRegexp(
+ LockTimeout,
+ expected_msg,
+ action.run,
+ packs=["test"],
+ abs_repo_base=self.repo_base,
+ )
finally:
os.unlink(lock_file.lock_file)
LockFile.acquire = original_acquire
@@ -243,7 +273,7 @@ def mock_acquire(self, timeout=None):
def test_run_pack_lock_is_already_acquired_force_flag(self):
# Lock is already acquired but force is true so it should be deleted and released
action = self.get_action_instance()
- temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest()
+ temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest()
original_acquire = LockFile.acquire
@@ -253,194 +283,266 @@ def mock_acquire(self, timeout=None):
LockFile.acquire = mock_acquire
try:
- lock_file = LockFile('/tmp/%s' % (temp_dir))
+ lock_file = LockFile("/tmp/%s" % (temp_dir))
# Acquire a lock (file) so acquire inside download will fail
- with open(lock_file.lock_file, 'w') as fp:
- fp.write('')
+ with open(lock_file.lock_file, "w") as fp:
+ fp.write("")
- result = action.run(packs=['test'], abs_repo_base=self.repo_base, force=True)
+ result = action.run(
+ packs=["test"], abs_repo_base=self.repo_base, force=True
+ )
finally:
LockFile.acquire = original_acquire
- self.assertEqual(result, {'test': 'Success.'})
+ self.assertEqual(result, {"test": "Success."})
def test_run_pack_download_v_tag(self):
def side_effect(ref):
- if ref[0] != 'v':
+ if ref[0] != "v":
raise BadName()
- return mock.MagicMock(hexsha='abcdef')
+ return mock.MagicMock(hexsha="abcdef")
self.repo_instance.commit.side_effect = side_effect
self.repo_instance.git = mock.MagicMock(
- branch=(lambda *args: 'master'),
- checkout=(lambda *args: True)
+ branch=(lambda *args: "master"), checkout=(lambda *args: True)
)
action = self.get_action_instance()
- result = action.run(packs=['test=1.2.3'], abs_repo_base=self.repo_base)
+ result = action.run(packs=["test=1.2.3"], abs_repo_base=self.repo_base)
- self.assertEqual(result, {'test': 'Success.'})
+ self.assertEqual(result, {"test": "Success."})
- @mock.patch.object(st2common.util.pack_management, 'get_valid_versions_for_repo',
- mock.Mock(return_value=['1.0.0', '2.0.0']))
+ @mock.patch.object(
+ st2common.util.pack_management,
+ "get_valid_versions_for_repo",
+ mock.Mock(return_value=["1.0.0", "2.0.0"]),
+ )
def test_run_pack_download_invalid_version(self):
self.repo_instance.commit.side_effect = lambda ref: None
action = self.get_action_instance()
- expected_msg = ('is not a valid version, hash, tag or branch.*?'
- 'Available versions are: 1.0.0, 2.0.0.')
- self.assertRaisesRegexp(ValueError, expected_msg, action.run,
- packs=['test=2.2.3'], abs_repo_base=self.repo_base)
+ expected_msg = (
+ "is not a valid version, hash, tag or branch.*?"
+ "Available versions are: 1.0.0, 2.0.0."
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action.run,
+ packs=["test=2.2.3"],
+ abs_repo_base=self.repo_base,
+ )
def test_download_pack_stackstorm_version_identifier_check(self):
action = self.get_action_instance()
# Version is satisfied
- st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.0.0'
+ st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.0.0"
- result = action.run(packs=['test3'], abs_repo_base=self.repo_base)
- self.assertEqual(result['test3'], 'Success.')
+ result = action.run(packs=["test3"], abs_repo_base=self.repo_base)
+ self.assertEqual(result["test3"], "Success.")
# Pack requires a version which is not satisfied by current StackStorm version
- st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.2.0'
- expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but '
- 'current version is "2.2.0"')
- self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'],
- abs_repo_base=self.repo_base)
-
- st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.3.0'
- expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but '
- 'current version is "2.3.0"')
- self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'],
- abs_repo_base=self.repo_base)
-
- st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.9'
- expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but '
- 'current version is "1.5.9"')
- self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'],
- abs_repo_base=self.repo_base)
-
- st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.0'
- expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but '
- 'current version is "1.5.0"')
- self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'],
- abs_repo_base=self.repo_base)
+ st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.2.0"
+ expected_msg = (
+ 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but '
+ 'current version is "2.2.0"'
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action.run,
+ packs=["test3"],
+ abs_repo_base=self.repo_base,
+ )
+
+ st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.3.0"
+ expected_msg = (
+ 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but '
+ 'current version is "2.3.0"'
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action.run,
+ packs=["test3"],
+ abs_repo_base=self.repo_base,
+ )
+
+ st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.9"
+ expected_msg = (
+ 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but '
+ 'current version is "1.5.9"'
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action.run,
+ packs=["test3"],
+ abs_repo_base=self.repo_base,
+ )
+
+ st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.0"
+ expected_msg = (
+ 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but '
+ 'current version is "1.5.0"'
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action.run,
+ packs=["test3"],
+ abs_repo_base=self.repo_base,
+ )
# Version is not met, but force=true parameter is provided
- st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.0'
- result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=True)
- self.assertEqual(result['test3'], 'Success.')
+ st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.0"
+ result = action.run(packs=["test3"], abs_repo_base=self.repo_base, force=True)
+ self.assertEqual(result["test3"], "Success.")
def test_download_pack_python_version_check(self):
action = self.get_action_instance()
# No python_versions attribute specified in the metadata file
- with mock.patch('st2common.util.pack_management.get_pack_metadata') as \
- mock_get_pack_metadata:
+ with mock.patch(
+ "st2common.util.pack_management.get_pack_metadata"
+ ) as mock_get_pack_metadata:
mock_get_pack_metadata.return_value = {
- 'name': 'test3',
- 'stackstorm_version': '',
- 'python_versions': []
+ "name": "test3",
+ "stackstorm_version": "",
+ "python_versions": [],
}
st2common.util.pack_management.six.PY2 = True
st2common.util.pack_management.six.PY3 = False
- st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.11'
+ st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.11"
- result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False)
- self.assertEqual(result['test3'], 'Success.')
+ result = action.run(
+ packs=["test3"], abs_repo_base=self.repo_base, force=False
+ )
+ self.assertEqual(result["test3"], "Success.")
        # Pack works with Python 2.x, installation is running 2.7
- with mock.patch('st2common.util.pack_management.get_pack_metadata') as \
- mock_get_pack_metadata:
+ with mock.patch(
+ "st2common.util.pack_management.get_pack_metadata"
+ ) as mock_get_pack_metadata:
mock_get_pack_metadata.return_value = {
- 'name': 'test3',
- 'stackstorm_version': '',
- 'python_versions': ['2']
+ "name": "test3",
+ "stackstorm_version": "",
+ "python_versions": ["2"],
}
st2common.util.pack_management.six.PY2 = True
st2common.util.pack_management.six.PY3 = False
- st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.5'
+ st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.5"
- result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False)
- self.assertEqual(result['test3'], 'Success.')
+ result = action.run(
+ packs=["test3"], abs_repo_base=self.repo_base, force=False
+ )
+ self.assertEqual(result["test3"], "Success.")
- st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.12'
+ st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.12"
- result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False)
- self.assertEqual(result['test3'], 'Success.')
+ result = action.run(
+ packs=["test3"], abs_repo_base=self.repo_base, force=False
+ )
+ self.assertEqual(result["test3"], "Success.")
        # Pack works with Python 2.x, installation is running 3.5
- with mock.patch('st2common.util.pack_management.get_pack_metadata') as \
- mock_get_pack_metadata:
+ with mock.patch(
+ "st2common.util.pack_management.get_pack_metadata"
+ ) as mock_get_pack_metadata:
mock_get_pack_metadata.return_value = {
- 'name': 'test3',
- 'stackstorm_version': '',
- 'python_versions': ['2']
+ "name": "test3",
+ "stackstorm_version": "",
+ "python_versions": ["2"],
}
st2common.util.pack_management.six.PY2 = False
st2common.util.pack_management.six.PY3 = True
- st2common.util.pack_management.CURRENT_PYTHON_VERSION = '3.5.2'
+ st2common.util.pack_management.CURRENT_PYTHON_VERSION = "3.5.2"
- expected_msg = (r'Pack "test3" requires Python 2.x, but current Python version is '
- '"3.5.2"')
- self.assertRaisesRegexp(ValueError, expected_msg, action.run,
- packs=['test3'], abs_repo_base=self.repo_base, force=False)
+ expected_msg = (
+ r'Pack "test3" requires Python 2.x, but current Python version is '
+ '"3.5.2"'
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action.run,
+ packs=["test3"],
+ abs_repo_base=self.repo_base,
+ force=False,
+ )
        # Pack works with Python 3.x, installation is running 2.7
- with mock.patch('st2common.util.pack_management.get_pack_metadata') as \
- mock_get_pack_metadata:
+ with mock.patch(
+ "st2common.util.pack_management.get_pack_metadata"
+ ) as mock_get_pack_metadata:
mock_get_pack_metadata.return_value = {
- 'name': 'test3',
- 'stackstorm_version': '',
- 'python_versions': ['3']
+ "name": "test3",
+ "stackstorm_version": "",
+ "python_versions": ["3"],
}
st2common.util.pack_management.six.PY2 = True
st2common.util.pack_management.six.PY3 = False
- st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.2'
+ st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.2"
- expected_msg = (r'Pack "test3" requires Python 3.x, but current Python version is '
- '"2.7.2"')
- self.assertRaisesRegexp(ValueError, expected_msg, action.run,
- packs=['test3'], abs_repo_base=self.repo_base, force=False)
+ expected_msg = (
+ r'Pack "test3" requires Python 3.x, but current Python version is '
+ '"2.7.2"'
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action.run,
+ packs=["test3"],
+ abs_repo_base=self.repo_base,
+ force=False,
+ )
        # Pack works with Python 2.x and 3.x, installation is running 2.7 and 3.6.1
- with mock.patch('st2common.util.pack_management.get_pack_metadata') as \
- mock_get_pack_metadata:
+ with mock.patch(
+ "st2common.util.pack_management.get_pack_metadata"
+ ) as mock_get_pack_metadata:
mock_get_pack_metadata.return_value = {
- 'name': 'test3',
- 'stackstorm_version': '',
- 'python_versions': ['2', '3']
+ "name": "test3",
+ "stackstorm_version": "",
+ "python_versions": ["2", "3"],
}
st2common.util.pack_management.six.PY2 = True
st2common.util.pack_management.six.PY3 = False
- st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.5'
+ st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.5"
- result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False)
- self.assertEqual(result['test3'], 'Success.')
+ result = action.run(
+ packs=["test3"], abs_repo_base=self.repo_base, force=False
+ )
+ self.assertEqual(result["test3"], "Success.")
st2common.util.pack_management.six.PY2 = False
st2common.util.pack_management.six.PY3 = True
- st2common.util.pack_management.CURRENT_PYTHON_VERSION = '3.6.1'
+ st2common.util.pack_management.CURRENT_PYTHON_VERSION = "3.6.1"
- result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False)
- self.assertEqual(result['test3'], 'Success.')
+ result = action.run(
+ packs=["test3"], abs_repo_base=self.repo_base, force=False
+ )
+ self.assertEqual(result["test3"], "Success.")
def test_resolve_urls(self):
- url = eval_repo_url(
- "https://github.com/StackStorm-Exchange/stackstorm-test")
+ url = eval_repo_url("https://github.com/StackStorm-Exchange/stackstorm-test")
self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test")
url = eval_repo_url(
- "https://github.com/StackStorm-Exchange/stackstorm-test.git")
- self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test.git")
+ "https://github.com/StackStorm-Exchange/stackstorm-test.git"
+ )
+ self.assertEqual(
+ url, "https://github.com/StackStorm-Exchange/stackstorm-test.git"
+ )
url = eval_repo_url("StackStorm-Exchange/stackstorm-test")
self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test")
@@ -460,11 +562,11 @@ def test_resolve_urls(self):
url = eval_repo_url("file://localhost/home/vagrant/stackstorm-test")
self.assertEqual(url, "file://localhost/home/vagrant/stackstorm-test")
- url = eval_repo_url('ssh:///AutomationStackStorm')
- self.assertEqual(url, 'ssh:///AutomationStackStorm')
+ url = eval_repo_url("ssh:///AutomationStackStorm")
+ self.assertEqual(url, "ssh:///AutomationStackStorm")
- url = eval_repo_url('ssh://joe@local/AutomationStackStorm')
- self.assertEqual(url, 'ssh://joe@local/AutomationStackStorm')
+ url = eval_repo_url("ssh://joe@local/AutomationStackStorm")
+ self.assertEqual(url, "ssh://joe@local/AutomationStackStorm")
def test_run_pack_download_edge_cases(self):
"""
@@ -479,36 +581,35 @@ def test_run_pack_download_edge_cases(self):
"""
def side_effect(ref):
- if ref[0] != 'v':
+ if ref[0] != "v":
raise BadName()
- return mock.MagicMock(hexsha='abcdeF')
+ return mock.MagicMock(hexsha="abcdeF")
self.repo_instance.commit.side_effect = side_effect
edge_cases = [
- ('master', '1.2.3'),
- ('master', 'some-branch'),
- ('master', 'default-branch'),
- ('master', None),
- ('default-branch', '1.2.3'),
- ('default-branch', 'some-branch'),
- ('default-branch', 'default-branch'),
- ('default-branch', None)
+ ("master", "1.2.3"),
+ ("master", "some-branch"),
+ ("master", "default-branch"),
+ ("master", None),
+ ("default-branch", "1.2.3"),
+ ("default-branch", "some-branch"),
+ ("default-branch", "default-branch"),
+ ("default-branch", None),
]
for default_branch, ref in edge_cases:
self.repo_instance.git = mock.MagicMock(
- branch=(lambda *args: default_branch),
- checkout=(lambda *args: True)
+ branch=(lambda *args: default_branch), checkout=(lambda *args: True)
)
# Set default branch
self.repo_instance.active_branch.name = default_branch
- self.repo_instance.active_branch.object = 'aBcdef'
- self.repo_instance.head.commit = 'aBcdef'
+ self.repo_instance.active_branch.object = "aBcdef"
+ self.repo_instance.head.commit = "aBcdef"
# Fake gitref object
- gitref = mock.MagicMock(hexsha='abcDef')
+ gitref = mock.MagicMock(hexsha="abcDef")
# Fool _get_gitref into working when its ref == our ref
def fake_commit(arg_ref):
@@ -516,30 +617,34 @@ def fake_commit(arg_ref):
return gitref
else:
raise BadName()
+
self.repo_instance.commit = fake_commit
self.repo_instance.active_branch.object = gitref
action = self.get_action_instance()
if ref:
- packs = ['test=%s' % (ref)]
+ packs = ["test=%s" % (ref)]
else:
- packs = ['test']
+ packs = ["test"]
result = action.run(packs=packs, abs_repo_base=self.repo_base)
- self.assertEqual(result, {'test': 'Success.'})
+ self.assertEqual(result, {"test": "Success."})
- @mock.patch('os.path.isdir', mock_is_dir_func)
+ @mock.patch("os.path.isdir", mock_is_dir_func)
def test_run_pack_dowload_local_git_repo_detached_head_state(self):
action = self.get_action_instance()
- type(self.repo_instance).active_branch = \
- mock.PropertyMock(side_effect=TypeError('detached head'))
+ type(self.repo_instance).active_branch = mock.PropertyMock(
+ side_effect=TypeError("detached head")
+ )
- pack_path = os.path.join(BASE_DIR, 'fixtures/stackstorm-test')
+ pack_path = os.path.join(BASE_DIR, "fixtures/stackstorm-test")
- result = action.run(packs=['file://%s' % (pack_path)], abs_repo_base=self.repo_base)
- self.assertEqual(result, {'test': 'Success.'})
+ result = action.run(
+ packs=["file://%s" % (pack_path)], abs_repo_base=self.repo_base
+ )
+ self.assertEqual(result, {"test": "Success."})
# Verify function has bailed out early
self.repo_instance.git.checkout.assert_not_called()
@@ -551,41 +656,55 @@ def test_run_pack_download_local_directory(self):
# 1. Local directory doesn't exist
expected_msg = r'Local pack directory ".*" doesn\'t exist'
- self.assertRaisesRegexp(ValueError, expected_msg, action.run,
- packs=['file://doesnt_exist'], abs_repo_base=self.repo_base)
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action.run,
+ packs=["file://doesnt_exist"],
+ abs_repo_base=self.repo_base,
+ )
# 2. Local pack which is not a git repository
- pack_path = os.path.join(BASE_DIR, 'fixtures/stackstorm-test4')
+ pack_path = os.path.join(BASE_DIR, "fixtures/stackstorm-test4")
- result = action.run(packs=['file://%s' % (pack_path)], abs_repo_base=self.repo_base)
- self.assertEqual(result, {'test4': 'Success.'})
+ result = action.run(
+ packs=["file://%s" % (pack_path)], abs_repo_base=self.repo_base
+ )
+ self.assertEqual(result, {"test4": "Success."})
# Verify pack contents have been copied over
- destination_path = os.path.join(self.repo_base, 'test4')
+ destination_path = os.path.join(self.repo_base, "test4")
self.assertTrue(os.path.exists(destination_path))
- self.assertTrue(os.path.exists(os.path.join(destination_path, 'pack.yaml')))
+ self.assertTrue(os.path.exists(os.path.join(destination_path, "pack.yaml")))
- @mock.patch('st2common.util.pack_management.get_gitref', mock_get_gitref)
+ @mock.patch("st2common.util.pack_management.get_gitref", mock_get_gitref)
def test_run_pack_download_with_tag(self):
action = self.get_action_instance()
- result = action.run(packs=['test'], abs_repo_base=self.repo_base)
- temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest()
+ result = action.run(packs=["test"], abs_repo_base=self.repo_base)
+ temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest()
- self.assertEqual(result, {'test': 'Success.'})
- self.clone_from.assert_called_once_with(PACK_INDEX['test']['repo_url'],
- os.path.join(os.path.expanduser('~'), temp_dir))
- self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml')))
+ self.assertEqual(result, {"test": "Success."})
+ self.clone_from.assert_called_once_with(
+ PACK_INDEX["test"]["repo_url"],
+ os.path.join(os.path.expanduser("~"), temp_dir),
+ )
+ self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml")))
# Check repo.git.checkout is called three times
self.assertEqual(self.repo_instance.git.checkout.call_count, 3)
# Check repo.git.checkout called with latest tag or branch
- self.assertEqual(PACK_INDEX['test']['version'],
- self.repo_instance.git.checkout.call_args_list[1][0][0])
+ self.assertEqual(
+ PACK_INDEX["test"]["version"],
+ self.repo_instance.git.checkout.call_args_list[1][0][0],
+ )
# Check repo.git.checkout called with head
- self.assertEqual(self.repo_instance.head.reference,
- self.repo_instance.git.checkout.call_args_list[2][0][0])
+ self.assertEqual(
+ self.repo_instance.head.reference,
+ self.repo_instance.git.checkout.call_args_list[2][0][0],
+ )
self.repo_instance.git.branch.assert_called_with(
- '-f', self.repo_instance.head.reference, PACK_INDEX['test']['version'])
+ "-f", self.repo_instance.head.reference, PACK_INDEX["test"]["version"]
+ )
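Note: several assertions above encode the clone-path convention these tests rely on. A minimal sketch of that convention, with an illustrative repo_url taken from PACK_INDEX:

    import hashlib
    import os

    # The tests expect Repo.clone_from to target ~/<md5 hex of the repo URL>.
    repo_url = "https://github.com/StackStorm-Exchange/stackstorm-test"
    temp_dir = hashlib.md5(repo_url.encode()).hexdigest()
    clone_path = os.path.join(os.path.expanduser("~"), temp_dir)
    print(clone_path)  # ~/<32 hex characters>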
diff --git a/contrib/packs/tests/test_action_unload.py b/contrib/packs/tests/test_action_unload.py
index 5e642483d4..fc07ff87c3 100644
--- a/contrib/packs/tests/test_action_unload.py
+++ b/contrib/packs/tests/test_action_unload.py
@@ -20,6 +20,7 @@
from oslo_config import cfg
from st2common.util.monkey_patch import use_select_poll_workaround
+
use_select_poll_workaround()
from st2common.content.bootstrap import register_content
@@ -39,11 +40,11 @@
from pack_mgmt.unload import UnregisterPackAction
-__all__ = [
- 'UnloadActionTestCase'
-]
+__all__ = ["UnloadActionTestCase"]
-PACK_PATH_1 = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_1')
+PACK_PATH_1 = os.path.join(
+ fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1"
+)
class UnloadActionTestCase(BaseActionTestCase, CleanDbTestCase):
@@ -64,13 +65,15 @@ def setUp(self):
# Register the pack with all the content
# TODO: Don't use pack cache
- cfg.CONF.set_override(name='all', override=True, group='register')
- cfg.CONF.set_override(name='pack', override=PACK_PATH_1, group='register')
- cfg.CONF.set_override(name='no_fail_on_failure', override=True, group='register')
+ cfg.CONF.set_override(name="all", override=True, group="register")
+ cfg.CONF.set_override(name="pack", override=PACK_PATH_1, group="register")
+ cfg.CONF.set_override(
+ name="no_fail_on_failure", override=True, group="register"
+ )
register_content()
def test_run(self):
- pack = 'dummy_pack_1'
+ pack = "dummy_pack_1"
# Verify all the resources are there
pack_dbs = Pack.query(ref=pack)
diff --git a/contrib/packs/tests/test_get_pack_dependencies.py b/contrib/packs/tests/test_get_pack_dependencies.py
index e047d7fca4..a90f940638 100644
--- a/contrib/packs/tests/test_get_pack_dependencies.py
+++ b/contrib/packs/tests/test_get_pack_dependencies.py
@@ -21,21 +21,20 @@
from pack_mgmt.get_pack_dependencies import GetPackDependencies
-UNINSTALLED_PACK = 'uninstalled_pack'
+UNINSTALLED_PACK = "uninstalled_pack"
UNINSTALLED_PACKS = [
UNINSTALLED_PACK,
- 'https://github.com/StackStorm-Exchange/stackstorm-pack1',
- 'https://github.com/StackStorm-Exchange/stackstorm-pack2.git',
- 'https://github.com/StackStorm-Exchange/stackstorm-pack3.git=v2.1.1',
- 'StackStorm-Exchange/stackstorm-pack4',
- 'git://StackStorm-Exchange/stackstorm-pack5=v2.1.1',
- 'git://StackStorm-Exchange/stackstorm-pack6.git',
- 'git@github.com:foo/pack7.git'
- 'git@github.com:foo/pack8.git=v3.2.1',
- 'file:///home/vagrant/stackstorm-pack9',
- 'file://localhost/home/vagrant/stackstorm-pack10',
- 'ssh:///AutomationStackStorm11',
- 'ssh://joe@local/AutomationStackStorm12'
+ "https://github.com/StackStorm-Exchange/stackstorm-pack1",
+ "https://github.com/StackStorm-Exchange/stackstorm-pack2.git",
+ "https://github.com/StackStorm-Exchange/stackstorm-pack3.git=v2.1.1",
+ "StackStorm-Exchange/stackstorm-pack4",
+ "git://StackStorm-Exchange/stackstorm-pack5=v2.1.1",
+ "git://StackStorm-Exchange/stackstorm-pack6.git",
+ "git@github.com:foo/pack7.git" "git@github.com:foo/pack8.git=v3.2.1",
+ "file:///home/vagrant/stackstorm-pack9",
+ "file://localhost/home/vagrant/stackstorm-pack10",
+ "ssh:///AutomationStackStorm11",
+ "ssh://joe@local/AutomationStackStorm12",
]
DOWNLOADED_OR_INSTALLED_PACK_METAdATA = {
@@ -58,7 +57,7 @@
"keywords": ["some", "special", "terms"],
"email": "info@stackstorm.com",
"description": "another st2 pack to test package management pipeline",
- "dependencies": ['uninstalled_pack', 'no_dependencies']
+ "dependencies": ["uninstalled_pack", "no_dependencies"],
},
# List of uninstalled dependency packs.
"test3": {
@@ -70,7 +69,7 @@
"keywords": ["some", "special", "terms"],
"email": "info@stackstorm.com",
"description": "another st2 pack to test package management pipeline",
- "dependencies": UNINSTALLED_PACKS
+ "dependencies": UNINSTALLED_PACKS,
},
    # One dependency pack that conflicts with an existing pack.
"test4": {
@@ -82,9 +81,7 @@
"keywords": ["some", "special", "terms"],
"email": "info@stackstorm.com",
"description": "another st2 pack to test package management pipeline",
- "dependencies": [
- "test2=v0.4.0"
- ]
+ "dependencies": ["test2=v0.4.0"],
},
    # One uninstalled dependency pack with a version conflict.
"test5": {
@@ -93,9 +90,10 @@
"name": "test4",
"repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4",
"author": "stanley",
- "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com",
+ "keywords": ["some", "special", "terms"],
+ "email": "info@stackstorm.com",
"description": "another st2 pack to test package management pipeline",
- "dependencies": ["uninstalled_pack=v0.4.0"]
+ "dependencies": ["uninstalled_pack=v0.4.0"],
},
    # One dependency pack without a version. It is not checked for conflicts.
"test6": {
@@ -104,10 +102,11 @@
"name": "test4",
"repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4",
"author": "stanley",
- "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com",
+ "keywords": ["some", "special", "terms"],
+ "email": "info@stackstorm.com",
"description": "another st2 pack to test package management pipeline",
- "dependencies": ["test2"]
- }
+ "dependencies": ["test2"],
+ },
}
@@ -119,7 +118,7 @@ def mock_get_dependency_list(pack):
if pack in DOWNLOADED_OR_INSTALLED_PACK_METAdATA:
metadata = DOWNLOADED_OR_INSTALLED_PACK_METAdATA[pack]
- dependencies = metadata.get('dependencies', None)
+ dependencies = metadata.get("dependencies", None)
return dependencies
@@ -132,13 +131,15 @@ def mock_get_pack_version(pack):
if pack in DOWNLOADED_OR_INSTALLED_PACK_METAdATA:
metadata = DOWNLOADED_OR_INSTALLED_PACK_METAdATA[pack]
- version = metadata.get('version', None)
+ version = metadata.get("version", None)
return version
-@mock.patch('pack_mgmt.get_pack_dependencies.get_dependency_list', mock_get_dependency_list)
-@mock.patch('pack_mgmt.get_pack_dependencies.get_pack_version', mock_get_pack_version)
+@mock.patch(
+ "pack_mgmt.get_pack_dependencies.get_dependency_list", mock_get_dependency_list
+)
+@mock.patch("pack_mgmt.get_pack_dependencies.get_pack_version", mock_get_pack_version)
class GetPackDependenciesTestCase(BaseActionTestCase):
action_cls = GetPackDependencies
@@ -167,9 +168,9 @@ def test_run_get_pack_dependencies_with_failed_packs_status(self):
nested = 2
result = action.run(packs_status=packs_status, nested=nested)
- self.assertEqual(result['dependency_list'], [])
- self.assertEqual(result['conflict_list'], [])
- self.assertEqual(result['nested'], nested - 1)
+ self.assertEqual(result["dependency_list"], [])
+ self.assertEqual(result["conflict_list"], [])
+ self.assertEqual(result["nested"], nested - 1)
def test_run_get_pack_dependencies_with_failed_and_succeeded_packs_status(self):
action = self.get_action_instance()
@@ -177,9 +178,9 @@ def test_run_get_pack_dependencies_with_failed_and_succeeded_packs_status(self):
nested = 2
result = action.run(packs_status=packs_status, nested=nested)
- self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK])
- self.assertEqual(result['conflict_list'], [])
- self.assertEqual(result['nested'], nested - 1)
+ self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK])
+ self.assertEqual(result["conflict_list"], [])
+ self.assertEqual(result["nested"], nested - 1)
def test_run_get_pack_dependencies_with_no_dependency(self):
action = self.get_action_instance()
@@ -187,9 +188,9 @@ def test_run_get_pack_dependencies_with_no_dependency(self):
nested = 3
result = action.run(packs_status=packs_status, nested=nested)
- self.assertEqual(result['dependency_list'], [])
- self.assertEqual(result['conflict_list'], [])
- self.assertEqual(result['nested'], nested - 1)
+ self.assertEqual(result["dependency_list"], [])
+ self.assertEqual(result["conflict_list"], [])
+ self.assertEqual(result["nested"], nested - 1)
def test_run_get_pack_dependencies_with_dependency(self):
action = self.get_action_instance()
@@ -197,9 +198,9 @@ def test_run_get_pack_dependencies_with_dependency(self):
nested = 1
result = action.run(packs_status=packs_status, nested=nested)
- self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK])
- self.assertEqual(result['conflict_list'], [])
- self.assertEqual(result['nested'], nested - 1)
+ self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK])
+ self.assertEqual(result["conflict_list"], [])
+ self.assertEqual(result["nested"], nested - 1)
def test_run_get_pack_dependencies_with_dependencies(self):
action = self.get_action_instance()
@@ -207,9 +208,9 @@ def test_run_get_pack_dependencies_with_dependencies(self):
nested = 1
result = action.run(packs_status=packs_status, nested=nested)
- self.assertEqual(result['dependency_list'], UNINSTALLED_PACKS)
- self.assertEqual(result['conflict_list'], [])
- self.assertEqual(result['nested'], nested - 1)
+ self.assertEqual(result["dependency_list"], UNINSTALLED_PACKS)
+ self.assertEqual(result["conflict_list"], [])
+ self.assertEqual(result["nested"], nested - 1)
def test_run_get_pack_dependencies_with_existing_pack_conflict(self):
action = self.get_action_instance()
@@ -217,9 +218,9 @@ def test_run_get_pack_dependencies_with_existing_pack_conflict(self):
nested = 1
result = action.run(packs_status=packs_status, nested=nested)
- self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK])
- self.assertEqual(result['conflict_list'], ['test2=v0.4.0'])
- self.assertEqual(result['nested'], nested - 1)
+ self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK])
+ self.assertEqual(result["conflict_list"], ["test2=v0.4.0"])
+ self.assertEqual(result["nested"], nested - 1)
def test_run_get_pack_dependencies_with_dependency_conflict(self):
action = self.get_action_instance()
@@ -227,9 +228,9 @@ def test_run_get_pack_dependencies_with_dependency_conflict(self):
nested = 1
result = action.run(packs_status=packs_status, nested=nested)
- self.assertEqual(result['dependency_list'], ['uninstalled_pack'])
- self.assertEqual(result['conflict_list'], ['uninstalled_pack=v0.4.0'])
- self.assertEqual(result['nested'], nested - 1)
+ self.assertEqual(result["dependency_list"], ["uninstalled_pack"])
+ self.assertEqual(result["conflict_list"], ["uninstalled_pack=v0.4.0"])
+ self.assertEqual(result["nested"], nested - 1)
def test_run_get_pack_dependencies_with_no_version(self):
action = self.get_action_instance()
@@ -237,6 +238,6 @@ def test_run_get_pack_dependencies_with_no_version(self):
nested = 1
result = action.run(packs_status=packs_status, nested=nested)
- self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK])
- self.assertEqual(result['conflict_list'], [])
- self.assertEqual(result['nested'], nested - 1)
+ self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK])
+ self.assertEqual(result["conflict_list"], [])
+ self.assertEqual(result["nested"], nested - 1)
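Note: in the UNINSTALLED_PACKS hunk above, the comma between the pack7 and pack8 literals was already missing before the reformat; black merely joins the two adjacent strings onto one line, which makes the likely-unintended implicit concatenation easier to spot. A minimal sketch of the pitfall:

    packs = [
        "git@github.com:foo/pack7.git"  # <- no comma here ...
        "git@github.com:foo/pack8.git=v3.2.1",  # ... so these two literals merge
        "file:///home/vagrant/stackstorm-pack9",
    ]
    assert len(packs) == 2  # not 3
    assert packs[0].endswith("pack8.git=v3.2.1")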
diff --git a/contrib/packs/tests/test_get_pack_warnings.py b/contrib/packs/tests/test_get_pack_warnings.py
index 49e2d920a8..3eac7ba356 100644
--- a/contrib/packs/tests/test_get_pack_warnings.py
+++ b/contrib/packs/tests/test_get_pack_warnings.py
@@ -29,7 +29,7 @@
"keywords": ["some", "search", "another", "terms"],
"email": "info@stackstorm.com",
"description": "st2 pack to test package management pipeline",
- "python_versions": ["2","3"],
+ "python_versions": ["2", "3"],
},
# Python 3
"py3": {
@@ -72,10 +72,11 @@
"keywords": ["some", "special", "terms"],
"email": "info@stackstorm.com",
"description": "another st2 pack to test package management pipeline",
- "python_versions": ["2"]
- }
+ "python_versions": ["2"],
+ },
}
+
def mock_get_pack_basepath(pack):
"""
    Mock get_pack_basepath function which just returns the pack name
@@ -94,8 +95,8 @@ def mock_get_pack_metadata(pack_dir):
return metadata
-@mock.patch('pack_mgmt.get_pack_warnings.get_pack_base_path', mock_get_pack_basepath)
-@mock.patch('pack_mgmt.get_pack_warnings.get_pack_metadata', mock_get_pack_metadata)
+@mock.patch("pack_mgmt.get_pack_warnings.get_pack_base_path", mock_get_pack_basepath)
+@mock.patch("pack_mgmt.get_pack_warnings.get_pack_metadata", mock_get_pack_metadata)
class GetPackWarningsTestCase(BaseActionTestCase):
action_cls = GetPackWarnings
@@ -107,15 +108,15 @@ def test_run_get_pack_warnings_py3_pack(self):
packs_status = {"py3": "Success."}
result = action.run(packs_status=packs_status)
- self.assertEqual(result['warning_list'], [])
+ self.assertEqual(result["warning_list"], [])
def test_run_get_pack_warnings_py2_pack(self):
action = self.get_action_instance()
packs_status = {"py2": "Success."}
result = action.run(packs_status=packs_status)
- self.assertEqual(len(result['warning_list']), 1)
- warning = result['warning_list'][0]
+ self.assertEqual(len(result["warning_list"]), 1)
+ warning = result["warning_list"][0]
self.assertTrue("DEPRECATION WARNING" in warning)
self.assertTrue("Pack py2 only supports Python 2" in warning)
@@ -124,28 +125,32 @@ def test_run_get_pack_warnings_py23_pack(self):
packs_status = {"py23": "Success."}
result = action.run(packs_status=packs_status)
- self.assertEqual(result['warning_list'], [])
+ self.assertEqual(result["warning_list"], [])
def test_run_get_pack_warnings_pynone_pack(self):
action = self.get_action_instance()
packs_status = {"pynone": "Success."}
result = action.run(packs_status=packs_status)
- self.assertEqual(result['warning_list'], [])
+ self.assertEqual(result["warning_list"], [])
def test_run_get_pack_warnings_multiple_pack(self):
action = self.get_action_instance()
- packs_status = {"py2": "Success.",
- "py23": "Success.",
- "py22": "Success."}
+ packs_status = {"py2": "Success.", "py23": "Success.", "py22": "Success."}
result = action.run(packs_status=packs_status)
- self.assertEqual(len(result['warning_list']), 2)
- warning0 = result['warning_list'][0]
- warning1 = result['warning_list'][1]
+ self.assertEqual(len(result["warning_list"]), 2)
+ warning0 = result["warning_list"][0]
+ warning1 = result["warning_list"][1]
self.assertTrue("DEPRECATION WARNING" in warning0)
self.assertTrue("DEPRECATION WARNING" in warning1)
- self.assertTrue(("Pack py2 only supports Python 2" in warning0 and
- "Pack py22 only supports Python 2" in warning1) or
- ("Pack py22 only supports Python 2" in warning0 and
- "Pack py2 only supports Python 2" in warning1))
+ self.assertTrue(
+ (
+ "Pack py2 only supports Python 2" in warning0
+ and "Pack py22 only supports Python 2" in warning1
+ )
+ or (
+ "Pack py22 only supports Python 2" in warning0
+ and "Pack py2 only supports Python 2" in warning1
+ )
+ )
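Note: the final assertion above shows how black lays out a long boolean condition: the expression is parenthesized and each operand gets its own line, with and/or leading the continuations. A minimal sketch with the warning strings inlined:

    warning0 = "DEPRECATION WARNING: Pack py2 only supports Python 2"
    warning1 = "DEPRECATION WARNING: Pack py22 only supports Python 2"
    matched = (
        "Pack py2 only supports Python 2" in warning0
        and "Pack py22 only supports Python 2" in warning1
    ) or (
        "Pack py22 only supports Python 2" in warning0
        and "Pack py2 only supports Python 2" in warning1
    )
    assert matched  # the two warnings may arrive in either order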
diff --git a/contrib/packs/tests/test_virtualenv_setup_prerun.py b/contrib/packs/tests/test_virtualenv_setup_prerun.py
index 63b27410f6..0097ecd8fe 100644
--- a/contrib/packs/tests/test_virtualenv_setup_prerun.py
+++ b/contrib/packs/tests/test_virtualenv_setup_prerun.py
@@ -28,21 +28,26 @@ def setUp(self):
def test_run_with_pack_list(self):
action = self.get_action_instance()
- result = action.run(packs_status={'test1': 'Success.', 'test2': 'Success.'},
- packs_list=['test3', 'test4'])
+ result = action.run(
+ packs_status={"test1": "Success.", "test2": "Success."},
+ packs_list=["test3", "test4"],
+ )
- self.assertEqual(result, ['test3', 'test4', 'test1', 'test2'])
+ self.assertEqual(result, ["test3", "test4", "test1", "test2"])
def test_run_with_none_pack_list(self):
action = self.get_action_instance()
- result = action.run(packs_status={'test1': 'Success.', 'test2': 'Success.'},
- packs_list=None)
+ result = action.run(
+ packs_status={"test1": "Success.", "test2": "Success."}, packs_list=None
+ )
- self.assertEqual(result, ['test1', 'test2'])
+ self.assertEqual(result, ["test1", "test2"])
def test_run_with_failed_status(self):
action = self.get_action_instance()
- result = action.run(packs_status={'test1': 'Failed.', 'test2': 'Success.'},
- packs_list=['test3', 'test4'])
+ result = action.run(
+ packs_status={"test1": "Failed.", "test2": "Success."},
+ packs_list=["test3", "test4"],
+ )
- self.assertEqual(result, ['test3', 'test4', 'test2'])
+ self.assertEqual(result, ["test3", "test4", "test2"])
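Note: a minimal sketch (with a hypothetical helper, not the action's real code) of the ordering these three tests assert: the explicit packs_list comes first, followed by the packs whose install status is "Success."; failed packs are dropped.

    def merge(packs_status, packs_list):
        # packs_list may be None; successful packs keep their insertion order.
        packs_list = packs_list or []
        ok = [p for p, status in packs_status.items() if status == "Success."]
        return packs_list + ok

    print(merge({"test1": "Failed.", "test2": "Success."}, ["test3", "test4"]))
    # -> ['test3', 'test4', 'test2']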
diff --git a/contrib/runners/action_chain_runner/action_chain_runner/__init__.py b/contrib/runners/action_chain_runner/action_chain_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/action_chain_runner/action_chain_runner/__init__.py
+++ b/contrib/runners/action_chain_runner/action_chain_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py b/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py
index 39cb873136..e71c12d004 100644
--- a/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py
+++ b/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py
@@ -50,26 +50,16 @@
from st2common.util.config_loader import get_config
from st2common.util.ujson import fast_deepcopy
-__all__ = [
- 'ActionChainRunner',
- 'ChainHolder',
-
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["ActionChainRunner", "ChainHolder", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
-RESULTS_KEY = '__results'
-JINJA_START_MARKERS = [
- '{{',
- '{%'
-]
-PUBLISHED_VARS_KEY = 'published'
+RESULTS_KEY = "__results"
+JINJA_START_MARKERS = ["{{", "{%"]
+PUBLISHED_VARS_KEY = "published"
class ChainHolder(object):
-
def __init__(self, chainspec, chainname):
self.actionchain = actionchain.ActionChain(**chainspec)
self.chainname = chainname
@@ -78,17 +68,21 @@ def __init__(self, chainspec, chainname):
default = self._get_default(self.actionchain)
self.actionchain.default = default
- LOG.debug('Using %s as default for %s.', self.actionchain.default, self.chainname)
+ LOG.debug(
+ "Using %s as default for %s.", self.actionchain.default, self.chainname
+ )
if not self.actionchain.default:
- raise Exception('Failed to find default node in %s.' % (self.chainname))
+ raise Exception("Failed to find default node in %s." % (self.chainname))
self.vars = {}
def init_vars(self, action_parameters, action_context=None):
if self.actionchain.vars:
- self.vars = self._get_rendered_vars(self.actionchain.vars,
- action_parameters=action_parameters,
- action_context=action_context)
+ self.vars = self._get_rendered_vars(
+ self.actionchain.vars,
+ action_parameters=action_parameters,
+ action_context=action_context,
+ )
def restore_vars(self, ctx_vars):
self.vars.update(fast_deepcopy(ctx_vars))
@@ -107,28 +101,37 @@ def validate(self):
on_failure_node_name = node.on_failure
# Check "on-success" path
- valid_name = self._is_valid_node_name(all_node_names=all_nodes,
- node_name=on_success_node_name)
+ valid_name = self._is_valid_node_name(
+ all_node_names=all_nodes, node_name=on_success_node_name
+ )
if not valid_name:
- msg = ('Unable to find node with name "%s" referenced in "on-success" in '
- 'task "%s".' % (on_success_node_name, node.name))
+ msg = (
+ 'Unable to find node with name "%s" referenced in "on-success" in '
+ 'task "%s".' % (on_success_node_name, node.name)
+ )
raise ValueError(msg)
# Check "on-failure" path
- valid_name = self._is_valid_node_name(all_node_names=all_nodes,
- node_name=on_failure_node_name)
+ valid_name = self._is_valid_node_name(
+ all_node_names=all_nodes, node_name=on_failure_node_name
+ )
if not valid_name:
- msg = ('Unable to find node with name "%s" referenced in "on-failure" in '
- 'task "%s".' % (on_failure_node_name, node.name))
+ msg = (
+ 'Unable to find node with name "%s" referenced in "on-failure" in '
+ 'task "%s".' % (on_failure_node_name, node.name)
+ )
raise ValueError(msg)
# check if node specified in default is valid.
if self.actionchain.default:
- valid_name = self._is_valid_node_name(all_node_names=all_nodes,
- node_name=self.actionchain.default)
+ valid_name = self._is_valid_node_name(
+ all_node_names=all_nodes, node_name=self.actionchain.default
+ )
if not valid_name:
- msg = ('Unable to find node with name "%s" referenced in "default".' %
- self.actionchain.default)
+ msg = (
+ 'Unable to find node with name "%s" referenced in "default".'
+ % self.actionchain.default
+ )
raise ValueError(msg)
return True
@@ -147,8 +150,12 @@ def _get_default(action_chain):
# 2. There are no fragments in the chain.
all_nodes = ChainHolder._get_all_nodes(action_chain=action_chain)
node_names = set(all_nodes)
- on_success_nodes = ChainHolder._get_all_on_success_nodes(action_chain=action_chain)
- on_failure_nodes = ChainHolder._get_all_on_failure_nodes(action_chain=action_chain)
+ on_success_nodes = ChainHolder._get_all_on_success_nodes(
+ action_chain=action_chain
+ )
+ on_failure_nodes = ChainHolder._get_all_on_failure_nodes(
+ action_chain=action_chain
+ )
referenced_nodes = on_success_nodes | on_failure_nodes
possible_default_nodes = node_names - referenced_nodes
if possible_default_nodes:
@@ -210,19 +217,25 @@ def _get_rendered_vars(vars, action_parameters, action_context):
return {}
action_context = action_context or {}
- user = action_context.get('user', cfg.CONF.system_user.user)
+ user = action_context.get("user", cfg.CONF.system_user.user)
context = {}
- context.update({
- kv_constants.DATASTORE_PARENT_SCOPE: {
- kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
- scope=kv_constants.FULL_SYSTEM_SCOPE),
- kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup(
- scope=kv_constants.FULL_USER_SCOPE, user=user)
+ context.update(
+ {
+ kv_constants.DATASTORE_PARENT_SCOPE: {
+ kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
+ scope=kv_constants.FULL_SYSTEM_SCOPE
+ ),
+ kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup(
+ scope=kv_constants.FULL_USER_SCOPE, user=user
+ ),
+ }
}
- })
+ )
context.update(action_parameters)
- LOG.info('Rendering action chain vars. Mapping = %s; Context = %s', vars, context)
+ LOG.info(
+ "Rendering action chain vars. Mapping = %s; Context = %s", vars, context
+ )
return jinja_utils.render_values(mapping=vars, context=context)
def get_node(self, node_name=None, raise_on_failure=False):
@@ -233,22 +246,22 @@ def get_node(self, node_name=None, raise_on_failure=False):
return node
if raise_on_failure:
raise runner_exc.ActionRunnerException(
- 'Unable to find node with name "%s".' % (node_name))
+ 'Unable to find node with name "%s".' % (node_name)
+ )
return None
- def get_next_node(self, curr_node_name=None, condition='on-success'):
+ def get_next_node(self, curr_node_name=None, condition="on-success"):
if not curr_node_name:
return self.get_node(self.actionchain.default)
current_node = self.get_node(curr_node_name)
- if condition == 'on-success':
+ if condition == "on-success":
return self.get_node(current_node.on_success, raise_on_failure=True)
- elif condition == 'on-failure':
+ elif condition == "on-failure":
return self.get_node(current_node.on_failure, raise_on_failure=True)
- raise runner_exc.ActionRunnerException('Unknown condition %s.' % condition)
+ raise runner_exc.ActionRunnerException("Unknown condition %s." % condition)
class ActionChainRunner(ActionRunner):
-
def __init__(self, runner_id):
super(ActionChainRunner, self).__init__(runner_id=runner_id)
self.chain_holder = None
@@ -261,16 +274,20 @@ def pre_run(self):
super(ActionChainRunner, self).pre_run()
chainspec_file = self.entry_point
- LOG.debug('Reading action chain from %s for action %s.', chainspec_file,
- self.action)
+ LOG.debug(
+ "Reading action chain from %s for action %s.", chainspec_file, self.action
+ )
try:
- chainspec = self._meta_loader.load(file_path=chainspec_file,
- expected_type=dict)
+ chainspec = self._meta_loader.load(
+ file_path=chainspec_file, expected_type=dict
+ )
except Exception as e:
- message = ('Failed to parse action chain definition from "%s": %s' %
- (chainspec_file, six.text_type(e)))
- LOG.exception('Failed to load action chain definition.')
+ message = 'Failed to parse action chain definition from "%s": %s' % (
+ chainspec_file,
+ six.text_type(e),
+ )
+ LOG.exception("Failed to load action chain definition.")
raise runner_exc.ActionRunnerPreRunError(message)
try:
@@ -279,20 +296,22 @@ def pre_run(self):
# preserve the whole nasty jsonschema message as that is better to get to the
# root cause
message = six.text_type(e)
- LOG.exception('Failed to instantiate ActionChain.')
+ LOG.exception("Failed to instantiate ActionChain.")
raise runner_exc.ActionRunnerPreRunError(message)
except Exception as e:
message = six.text_type(e)
- LOG.exception('Failed to instantiate ActionChain.')
+ LOG.exception("Failed to instantiate ActionChain.")
raise runner_exc.ActionRunnerPreRunError(message)
# Runner attributes are set lazily. So these steps
# should happen outside the constructor.
- if getattr(self, 'liveaction', None):
- self._chain_notify = getattr(self.liveaction, 'notify', None)
+ if getattr(self, "liveaction", None):
+ self._chain_notify = getattr(self.liveaction, "notify", None)
if self.runner_parameters:
- self._skip_notify_tasks = self.runner_parameters.get('skip_notify', [])
- self._display_published = self.runner_parameters.get('display_published', True)
+ self._skip_notify_tasks = self.runner_parameters.get("skip_notify", [])
+ self._display_published = self.runner_parameters.get(
+ "display_published", True
+ )
# Perform some pre-run chain validation
try:
@@ -308,34 +327,38 @@ def cancel(self):
        # Identify the list of action executions that are workflows and cascade the cancellation.
for child_exec_id in self.execution.children:
child_exec = ActionExecution.get(id=child_exec_id, raise_exception=True)
- if (child_exec.runner['name'] in action_constants.WORKFLOW_RUNNER_TYPES and
- child_exec.status in action_constants.LIVEACTION_CANCELABLE_STATES):
+ if (
+ child_exec.runner["name"] in action_constants.WORKFLOW_RUNNER_TYPES
+ and child_exec.status in action_constants.LIVEACTION_CANCELABLE_STATES
+ ):
action_service.request_cancellation(
- LiveAction.get(id=child_exec.liveaction['id']),
- self.context.get('user', None)
+ LiveAction.get(id=child_exec.liveaction["id"]),
+ self.context.get("user", None),
)
return (
action_constants.LIVEACTION_STATUS_CANCELING,
self.liveaction.result,
- self.liveaction.context
+ self.liveaction.context,
)
def pause(self):
# Identify the list of action executions that are workflows and cascade pause.
for child_exec_id in self.execution.children:
child_exec = ActionExecution.get(id=child_exec_id, raise_exception=True)
- if (child_exec.runner['name'] in action_constants.WORKFLOW_RUNNER_TYPES and
- child_exec.status == action_constants.LIVEACTION_STATUS_RUNNING):
+ if (
+ child_exec.runner["name"] in action_constants.WORKFLOW_RUNNER_TYPES
+ and child_exec.status == action_constants.LIVEACTION_STATUS_RUNNING
+ ):
action_service.request_pause(
- LiveAction.get(id=child_exec.liveaction['id']),
- self.context.get('user', None)
+ LiveAction.get(id=child_exec.liveaction["id"]),
+ self.context.get("user", None),
)
return (
action_constants.LIVEACTION_STATUS_PAUSING,
self.liveaction.result,
- self.liveaction.context
+ self.liveaction.context,
)
def resume(self):
@@ -344,7 +367,7 @@ def resume(self):
self.runner_type.runner_parameters,
self.action.parameters,
self.liveaction.parameters,
- self.liveaction.context
+ self.liveaction.context,
)
# Assign runner parameters needed for pre-run.
@@ -357,9 +380,7 @@ def resume(self):
# Change the status of the liveaction from resuming to running.
self.liveaction = action_service.update_status(
- self.liveaction,
- action_constants.LIVEACTION_STATUS_RUNNING,
- publish=False
+ self.liveaction, action_constants.LIVEACTION_STATUS_RUNNING, publish=False
)
# Run the action chain.
@@ -370,13 +391,15 @@ def _run_chain(self, action_parameters, resuming=False):
chain_status = action_constants.LIVEACTION_STATUS_FAILED
# Result holds the final result that the chain stores in the database.
- result = {'tasks': []}
+ result = {"tasks": []}
# Save published variables into the result if specified.
if self._display_published:
result[PUBLISHED_VARS_KEY] = {}
- context_result = {} # Holds result which is used for the template context purposes
+ context_result = (
+ {}
+ ) # Holds result which is used for the template context purposes
top_level_error = None # Stores a reference to a top level error
action_node = None
last_task = None
@@ -384,11 +407,12 @@ def _run_chain(self, action_parameters, resuming=False):
try:
# Initialize vars with the action parameters.
# This allows action parameters to be referenced from vars.
- self.chain_holder.init_vars(action_parameters=action_parameters,
- action_context=self.context)
+ self.chain_holder.init_vars(
+ action_parameters=action_parameters, action_context=self.context
+ )
except Exception as e:
chain_status = action_constants.LIVEACTION_STATUS_FAILED
- m = 'Failed initializing ``vars`` in chain.'
+ m = "Failed initializing ``vars`` in chain."
LOG.exception(m)
top_level_error = self._format_error(e, m)
result.update(top_level_error)
@@ -397,28 +421,32 @@ def _run_chain(self, action_parameters, resuming=False):
# Restore state on resuming an existing chain execution.
if resuming:
# Restore vars, if any, from the liveaction.
- ctx_vars = self.liveaction.context.pop('vars', {})
+ ctx_vars = self.liveaction.context.pop("vars", {})
self.chain_holder.restore_vars(ctx_vars)
# Restore result if any from the liveaction.
- if self.liveaction and hasattr(self.liveaction, 'result') and self.liveaction.result:
+ if (
+ self.liveaction
+ and hasattr(self.liveaction, "result")
+ and self.liveaction.result
+ ):
result = self.liveaction.result
# Initialize or rebuild existing context_result from liveaction
# which holds the result used for resolving context in Jinja template.
- for task in result.get('tasks', []):
- context_result[task['name']] = task['result']
+ for task in result.get("tasks", []):
+ context_result[task["name"]] = task["result"]
# Restore or initialize the top_level_error
# that stores a reference to a top level error.
- if 'error' in result or 'traceback' in result:
+ if "error" in result or "traceback" in result:
top_level_error = {
- 'error': result.get('error'),
- 'traceback': result.get('traceback')
+ "error": result.get("error"),
+ "traceback": result.get("traceback"),
}
# If there are no executed tasks in the chain, then get the first node.
- if len(result['tasks']) <= 0:
+ if len(result["tasks"]) <= 0:
try:
action_node = self.chain_holder.get_next_node()
except Exception as e:
@@ -433,21 +461,24 @@ def _run_chain(self, action_parameters, resuming=False):
# Otherwise, figure out the last task executed and
# its state to determine where to begin executing.
else:
- last_task = result['tasks'][-1]
- action_node = self.chain_holder.get_node(last_task['name'])
- liveaction = action_db_util.get_liveaction_by_id(last_task['liveaction_id'])
+ last_task = result["tasks"][-1]
+ action_node = self.chain_holder.get_node(last_task["name"])
+ liveaction = action_db_util.get_liveaction_by_id(last_task["liveaction_id"])
# If the liveaction of the last task has changed, update the result entry.
- if liveaction.status != last_task['state']:
+ if liveaction.status != last_task["state"]:
updated_task_result = self._get_updated_action_exec_result(
- action_node, liveaction, last_task)
- del result['tasks'][-1]
- result['tasks'].append(updated_task_result)
+ action_node, liveaction, last_task
+ )
+ del result["tasks"][-1]
+ result["tasks"].append(updated_task_result)
# Also need to update context_result so the updated result
# is available to Jinja expressions
- updated_task_name = updated_task_result['name']
- context_result[updated_task_name]['result'] = updated_task_result['result']
+ updated_task_name = updated_task_result["name"]
+ context_result[updated_task_name]["result"] = updated_task_result[
+ "result"
+ ]
# If the last task was canceled, then cancel the chain altogether.
if liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED:
@@ -463,42 +494,52 @@ def _run_chain(self, action_parameters, resuming=False):
if liveaction.status == action_constants.LIVEACTION_STATUS_SUCCEEDED:
chain_status = action_constants.LIVEACTION_STATUS_SUCCEEDED
action_node = self.chain_holder.get_next_node(
- last_task['name'], condition='on-success')
+ last_task["name"], condition="on-success"
+ )
# If the last task failed, then get the next on-failure action node.
if liveaction.status in action_constants.LIVEACTION_FAILED_STATES:
chain_status = action_constants.LIVEACTION_STATUS_FAILED
action_node = self.chain_holder.get_next_node(
- last_task['name'], condition='on-failure')
+ last_task["name"], condition="on-failure"
+ )
# Setup parent context.
- parent_context = {
- 'execution_id': self.execution_id
- }
+ parent_context = {"execution_id": self.execution_id}
- if getattr(self.liveaction, 'context', None):
+ if getattr(self.liveaction, "context", None):
parent_context.update(self.liveaction.context)
# Run the action chain until there are no more tasks.
while action_node:
error = None
liveaction = None
- last_task = result['tasks'][-1] if len(result['tasks']) > 0 else None
+ last_task = result["tasks"][-1] if len(result["tasks"]) > 0 else None
created_at = date_utils.get_datetime_utc_now()
try:
# If last task was paused, then fetch the liveaction and resume it first.
- if last_task and last_task['state'] == action_constants.LIVEACTION_STATUS_PAUSED:
- liveaction = action_db_util.get_liveaction_by_id(last_task['liveaction_id'])
- del result['tasks'][-1]
+ if (
+ last_task
+ and last_task["state"] == action_constants.LIVEACTION_STATUS_PAUSED
+ ):
+ liveaction = action_db_util.get_liveaction_by_id(
+ last_task["liveaction_id"]
+ )
+ del result["tasks"][-1]
else:
liveaction = self._get_next_action(
- action_node=action_node, parent_context=parent_context,
- action_params=action_parameters, context_result=context_result)
+ action_node=action_node,
+ parent_context=parent_context,
+ action_params=action_parameters,
+ context_result=context_result,
+ )
except action_exc.InvalidActionReferencedException as e:
chain_status = action_constants.LIVEACTION_STATUS_FAILED
- m = ('Failed to run task "%s". Action with reference "%s" doesn\'t exist.' %
- (action_node.name, action_node.ref))
+ m = (
+ 'Failed to run task "%s". Action with reference "%s" doesn\'t exist.'
+ % (action_node.name, action_node.ref)
+ )
LOG.exception(m)
top_level_error = self._format_error(e, m)
break
@@ -506,24 +547,41 @@ def _run_chain(self, action_parameters, resuming=False):
# Rendering parameters failed before we even got to running this action,
# so abort and fail the whole action chain
chain_status = action_constants.LIVEACTION_STATUS_FAILED
- m = 'Failed to run task "%s". Parameter rendering failed.' % action_node.name
+ m = (
+ 'Failed to run task "%s". Parameter rendering failed.'
+ % action_node.name
+ )
LOG.exception(m)
top_level_error = self._format_error(e, m)
break
except db_exc.StackStormDBObjectNotFoundError as e:
chain_status = action_constants.LIVEACTION_STATUS_FAILED
- m = 'Failed to resume task "%s". Unable to find liveaction.' % action_node.name
+ m = (
+ 'Failed to resume task "%s". Unable to find liveaction.'
+ % action_node.name
+ )
LOG.exception(m)
top_level_error = self._format_error(e, m)
break
try:
# If last task was paused, then fetch the liveaction and resume it first.
- if last_task and last_task['state'] == action_constants.LIVEACTION_STATUS_PAUSED:
- LOG.info('Resume task %s for chain %s.', action_node.name, self.liveaction.id)
+ if (
+ last_task
+ and last_task["state"] == action_constants.LIVEACTION_STATUS_PAUSED
+ ):
+ LOG.info(
+ "Resume task %s for chain %s.",
+ action_node.name,
+ self.liveaction.id,
+ )
liveaction = self._resume_action(liveaction)
else:
- LOG.info('Run task %s for chain %s.', action_node.name, self.liveaction.id)
+ LOG.info(
+ "Run task %s for chain %s.",
+ action_node.name,
+ self.liveaction.id,
+ )
liveaction = self._run_action(liveaction)
except Exception as e:
# Save the traceback and error message
@@ -537,9 +595,12 @@ def _run_chain(self, action_parameters, resuming=False):
# Render and publish variables
rendered_publish_vars = ActionChainRunner._render_publish_vars(
- action_node=action_node, action_parameters=action_parameters,
- execution_result=liveaction.result, previous_execution_results=context_result,
- chain_vars=self.chain_holder.vars)
+ action_node=action_node,
+ action_parameters=action_parameters,
+ execution_result=liveaction.result,
+ previous_execution_results=context_result,
+ chain_vars=self.chain_holder.vars,
+ )
if rendered_publish_vars:
self.chain_holder.vars.update(rendered_publish_vars)
@@ -550,49 +611,68 @@ def _run_chain(self, action_parameters, resuming=False):
updated_at = date_utils.get_datetime_utc_now()
task_result = self._format_action_exec_result(
- action_node,
- liveaction,
- created_at,
- updated_at,
- error=error
+ action_node, liveaction, created_at, updated_at, error=error
)
- result['tasks'].append(task_result)
+ result["tasks"].append(task_result)
try:
if not liveaction:
chain_status = action_constants.LIVEACTION_STATUS_FAILED
action_node = self.chain_holder.get_next_node(
- action_node.name, condition='on-failure')
- elif liveaction.status == action_constants.LIVEACTION_STATUS_TIMED_OUT:
+ action_node.name, condition="on-failure"
+ )
+ elif (
+ liveaction.status
+ == action_constants.LIVEACTION_STATUS_TIMED_OUT
+ ):
chain_status = action_constants.LIVEACTION_STATUS_TIMED_OUT
action_node = self.chain_holder.get_next_node(
- action_node.name, condition='on-failure')
- elif liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED:
- LOG.info('Chain execution (%s) canceled because task "%s" is canceled.',
- self.liveaction_id, action_node.name)
+ action_node.name, condition="on-failure"
+ )
+ elif (
+ liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED
+ ):
+ LOG.info(
+ 'Chain execution (%s) canceled because task "%s" is canceled.',
+ self.liveaction_id,
+ action_node.name,
+ )
chain_status = action_constants.LIVEACTION_STATUS_CANCELED
action_node = None
elif liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED:
- LOG.info('Chain execution (%s) paused because task "%s" is paused.',
- self.liveaction_id, action_node.name)
+ LOG.info(
+ 'Chain execution (%s) paused because task "%s" is paused.',
+ self.liveaction_id,
+ action_node.name,
+ )
chain_status = action_constants.LIVEACTION_STATUS_PAUSED
self._save_vars()
action_node = None
- elif liveaction.status == action_constants.LIVEACTION_STATUS_PENDING:
- LOG.info('Chain execution (%s) paused because task "%s" is pending.',
- self.liveaction_id, action_node.name)
+ elif (
+ liveaction.status == action_constants.LIVEACTION_STATUS_PENDING
+ ):
+ LOG.info(
+ 'Chain execution (%s) paused because task "%s" is pending.',
+ self.liveaction_id,
+ action_node.name,
+ )
chain_status = action_constants.LIVEACTION_STATUS_PAUSED
self._save_vars()
action_node = None
elif liveaction.status in action_constants.LIVEACTION_FAILED_STATES:
chain_status = action_constants.LIVEACTION_STATUS_FAILED
action_node = self.chain_holder.get_next_node(
- action_node.name, condition='on-failure')
- elif liveaction.status == action_constants.LIVEACTION_STATUS_SUCCEEDED:
+ action_node.name, condition="on-failure"
+ )
+ elif (
+ liveaction.status
+ == action_constants.LIVEACTION_STATUS_SUCCEEDED
+ ):
chain_status = action_constants.LIVEACTION_STATUS_SUCCEEDED
action_node = self.chain_holder.get_next_node(
- action_node.name, condition='on-success')
+ action_node.name, condition="on-success"
+ )
else:
action_node = None
except Exception as e:
@@ -604,12 +684,12 @@ def _run_chain(self, action_parameters, resuming=False):
break
if action_service.is_action_canceled_or_canceling(self.liveaction.id):
- LOG.info('Chain execution (%s) canceled by user.', self.liveaction.id)
+ LOG.info("Chain execution (%s) canceled by user.", self.liveaction.id)
chain_status = action_constants.LIVEACTION_STATUS_CANCELED
return (chain_status, result, None)
if action_service.is_action_paused_or_pausing(self.liveaction.id):
- LOG.info('Chain execution (%s) paused by user.', self.liveaction.id)
+ LOG.info("Chain execution (%s) paused by user.", self.liveaction.id)
chain_status = action_constants.LIVEACTION_STATUS_PAUSED
self._save_vars()
return (chain_status, result, self.liveaction.context)
@@ -621,17 +701,22 @@ def _run_chain(self, action_parameters, resuming=False):
def _format_error(self, e, msg):
return {
- 'error': '%s. %s' % (msg, six.text_type(e)),
- 'traceback': traceback.format_exc(10)
+ "error": "%s. %s" % (msg, six.text_type(e)),
+ "traceback": traceback.format_exc(10),
}
def _save_vars(self):
# Save the context vars in the liveaction context.
- self.liveaction.context['vars'] = self.chain_holder.vars
+ self.liveaction.context["vars"] = self.chain_holder.vars
@staticmethod
- def _render_publish_vars(action_node, action_parameters, execution_result,
- previous_execution_results, chain_vars):
+ def _render_publish_vars(
+ action_node,
+ action_parameters,
+ execution_result,
+ previous_execution_results,
+ chain_vars,
+ ):
"""
If no output is specified on the action_node the output is the entire execution_result.
If any output is specified then only those variables are published as output of an
@@ -649,36 +734,48 @@ def _render_publish_vars(action_node, action_parameters, execution_result,
context.update(chain_vars)
context.update({RESULTS_KEY: previous_execution_results})
- context.update({
- kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
- scope=kv_constants.SYSTEM_SCOPE)
- })
-
- context.update({
- kv_constants.DATASTORE_PARENT_SCOPE: {
+ context.update(
+ {
kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
- scope=kv_constants.FULL_SYSTEM_SCOPE)
+ scope=kv_constants.SYSTEM_SCOPE
+ )
}
- })
+ )
+
+ context.update(
+ {
+ kv_constants.DATASTORE_PARENT_SCOPE: {
+ kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
+ scope=kv_constants.FULL_SYSTEM_SCOPE
+ )
+ }
+ }
+ )
try:
- rendered_result = jinja_utils.render_values(mapping=action_node.publish,
- context=context)
+ rendered_result = jinja_utils.render_values(
+ mapping=action_node.publish, context=context
+ )
except Exception as e:
- key = getattr(e, 'key', None)
- value = getattr(e, 'value', None)
- msg = ('Failed rendering value for publish parameter "%s" in task "%s" '
- '(template string=%s): %s' % (key, action_node.name, value, six.text_type(e)))
+ key = getattr(e, "key", None)
+ value = getattr(e, "value", None)
+ msg = (
+ 'Failed rendering value for publish parameter "%s" in task "%s" '
+ "(template string=%s): %s"
+ % (key, action_node.name, value, six.text_type(e))
+ )
raise action_exc.ParameterRenderingFailedException(msg)
return rendered_result
@staticmethod
- def _resolve_params(action_node, original_parameters, results, chain_vars, chain_context):
+ def _resolve_params(
+ action_node, original_parameters, results, chain_vars, chain_context
+ ):
# setup context with original parameters and the intermediate results.
- chain_parent = chain_context.get('parent', {})
- pack = chain_parent.get('pack')
- user = chain_parent.get('user')
+ chain_parent = chain_context.get("parent", {})
+ pack = chain_parent.get("pack")
+ user = chain_parent.get("user")
config = get_config(pack, user)
@@ -688,34 +785,47 @@ def _resolve_params(action_node, original_parameters, results, chain_vars, chain
context.update(chain_vars)
context.update({RESULTS_KEY: results})
- context.update({
- kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
- scope=kv_constants.SYSTEM_SCOPE)
- })
-
- context.update({
- kv_constants.DATASTORE_PARENT_SCOPE: {
+ context.update(
+ {
kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
- scope=kv_constants.FULL_SYSTEM_SCOPE)
+ scope=kv_constants.SYSTEM_SCOPE
+ )
}
- })
+ )
+
+ context.update(
+ {
+ kv_constants.DATASTORE_PARENT_SCOPE: {
+ kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
+ scope=kv_constants.FULL_SYSTEM_SCOPE
+ )
+ }
+ }
+ )
context.update({action_constants.ACTION_CONTEXT_KV_PREFIX: chain_context})
context.update({pack_constants.PACK_CONFIG_CONTEXT_KV_PREFIX: config})
try:
- rendered_params = jinja_utils.render_values(mapping=action_node.get_parameters(),
- context=context)
+ rendered_params = jinja_utils.render_values(
+ mapping=action_node.get_parameters(), context=context
+ )
except Exception as e:
LOG.exception('Jinja rendering for parameter "%s" failed.' % (e.key))
- key = getattr(e, 'key', None)
- value = getattr(e, 'value', None)
- msg = ('Failed rendering value for action parameter "%s" in task "%s" '
- '(template string=%s): %s') % (key, action_node.name, value, six.text_type(e))
+ key = getattr(e, "key", None)
+ value = getattr(e, "value", None)
+ msg = (
+ 'Failed rendering value for action parameter "%s" in task "%s" '
+ "(template string=%s): %s"
+ ) % (key, action_node.name, value, six.text_type(e))
raise action_exc.ParameterRenderingFailedException(msg)
- LOG.debug('Rendered params: %s: Type: %s', rendered_params, type(rendered_params))
+ LOG.debug(
+ "Rendered params: %s: Type: %s", rendered_params, type(rendered_params)
+ )
return rendered_params
- def _get_next_action(self, action_node, parent_context, action_params, context_result):
+ def _get_next_action(
+ self, action_node, parent_context, action_params, context_result
+ ):
# Verify that the referenced action exists
# TODO: We do another lookup in cast_param, refactor to reduce number of lookups
task_name = action_node.name
@@ -723,18 +833,25 @@ def _get_next_action(self, action_node, parent_context, action_params, context_r
action_db = action_db_util.get_action_by_ref(ref=action_ref)
if not action_db:
- error = 'Task :: %s - Action with ref %s not registered.' % (task_name, action_ref)
+ error = "Task :: %s - Action with ref %s not registered." % (
+ task_name,
+ action_ref,
+ )
raise action_exc.InvalidActionReferencedException(error)
resolved_params = ActionChainRunner._resolve_params(
- action_node=action_node, original_parameters=action_params,
- results=context_result, chain_vars=self.chain_holder.vars,
- chain_context={'parent': parent_context})
+ action_node=action_node,
+ original_parameters=action_params,
+ results=context_result,
+ chain_vars=self.chain_holder.vars,
+ chain_context={"parent": parent_context},
+ )
liveaction = self._build_liveaction_object(
action_node=action_node,
resolved_params=resolved_params,
- parent_context=parent_context)
+ parent_context=parent_context,
+ )
return liveaction
@@ -747,13 +864,16 @@ def _run_action(self, liveaction, wait_for_completion=True, sleep_delay=1.0):
liveaction, _ = action_service.request(liveaction)
except Exception as e:
liveaction.status = action_constants.LIVEACTION_STATUS_FAILED
- LOG.exception('Failed to schedule liveaction.')
+ LOG.exception("Failed to schedule liveaction.")
raise e
- while (wait_for_completion and liveaction.status not in (
- action_constants.LIVEACTION_COMPLETED_STATES +
- [action_constants.LIVEACTION_STATUS_PAUSED,
- action_constants.LIVEACTION_STATUS_PENDING])):
+ while wait_for_completion and liveaction.status not in (
+ action_constants.LIVEACTION_COMPLETED_STATES
+ + [
+ action_constants.LIVEACTION_STATUS_PAUSED,
+ action_constants.LIVEACTION_STATUS_PENDING,
+ ]
+ ):
eventlet.sleep(sleep_delay)
liveaction = action_db_util.get_liveaction_by_id(liveaction.id)
@@ -765,16 +885,17 @@ def _resume_action(self, liveaction, wait_for_completion=True, sleep_delay=1.0):
:type sleep_delay: ``float``
"""
try:
- user = self.context.get('user', None)
+ user = self.context.get("user", None)
liveaction, _ = action_service.request_resume(liveaction, user)
except Exception as e:
liveaction.status = action_constants.LIVEACTION_STATUS_FAILED
- LOG.exception('Failed to schedule liveaction.')
+ LOG.exception("Failed to schedule liveaction.")
raise e
- while (wait_for_completion and liveaction.status not in (
- action_constants.LIVEACTION_COMPLETED_STATES +
- [action_constants.LIVEACTION_STATUS_PAUSED])):
+ while wait_for_completion and liveaction.status not in (
+ action_constants.LIVEACTION_COMPLETED_STATES
+ + [action_constants.LIVEACTION_STATUS_PAUSED]
+ ):
eventlet.sleep(sleep_delay)
liveaction = action_db_util.get_liveaction_by_id(liveaction.id)
@@ -787,14 +908,12 @@ def _build_liveaction_object(self, action_node, resolved_params, parent_context)
notify = self._get_notify(action_node)
if notify:
liveaction.notify = notify
- LOG.debug('%s: Task notify set to: %s', action_node.name, liveaction.notify)
+ LOG.debug("%s: Task notify set to: %s", action_node.name, liveaction.notify)
- liveaction.context = {
- 'parent': parent_context,
- 'chain': vars(action_node)
- }
- liveaction.parameters = action_param_utils.cast_params(action_ref=action_node.ref,
- params=resolved_params)
+ liveaction.context = {"parent": parent_context, "chain": vars(action_node)}
+ liveaction.parameters = action_param_utils.cast_params(
+ action_ref=action_node.ref, params=resolved_params
+ )
return liveaction
def _get_notify(self, action_node):
@@ -807,18 +926,23 @@ def _get_notify(self, action_node):
return None
- def _get_updated_action_exec_result(self, action_node, liveaction, prev_task_result):
+ def _get_updated_action_exec_result(
+ self, action_node, liveaction, prev_task_result
+ ):
if liveaction.status in action_constants.LIVEACTION_COMPLETED_STATES:
- created_at = isotime.parse(prev_task_result['created_at'])
+ created_at = isotime.parse(prev_task_result["created_at"])
updated_at = liveaction.end_timestamp
else:
- created_at = isotime.parse(prev_task_result['created_at'])
- updated_at = isotime.parse(prev_task_result['updated_at'])
+ created_at = isotime.parse(prev_task_result["created_at"])
+ updated_at = isotime.parse(prev_task_result["updated_at"])
- return self._format_action_exec_result(action_node, liveaction, created_at, updated_at)
+ return self._format_action_exec_result(
+ action_node, liveaction, created_at, updated_at
+ )
- def _format_action_exec_result(self, action_node, liveaction_db, created_at, updated_at,
- error=None):
+ def _format_action_exec_result(
+ self, action_node, liveaction_db, created_at, updated_at, error=None
+ ):
"""
Format ActionExecution result so it can be used in the final action result output.
@@ -833,24 +957,24 @@ def _format_action_exec_result(self, action_node, liveaction_db, created_at, upd
if liveaction_db:
execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id))
- result['id'] = action_node.name
- result['name'] = action_node.name
- result['execution_id'] = str(execution_db.id) if execution_db else None
- result['liveaction_id'] = str(liveaction_db.id) if liveaction_db else None
- result['workflow'] = None
+ result["id"] = action_node.name
+ result["name"] = action_node.name
+ result["execution_id"] = str(execution_db.id) if execution_db else None
+ result["liveaction_id"] = str(liveaction_db.id) if liveaction_db else None
+ result["workflow"] = None
- result['created_at'] = isotime.format(dt=created_at)
- result['updated_at'] = isotime.format(dt=updated_at)
+ result["created_at"] = isotime.format(dt=created_at)
+ result["updated_at"] = isotime.format(dt=updated_at)
if error or not liveaction_db:
- result['state'] = action_constants.LIVEACTION_STATUS_FAILED
+ result["state"] = action_constants.LIVEACTION_STATUS_FAILED
else:
- result['state'] = liveaction_db.status
+ result["state"] = liveaction_db.status
if error:
- result['result'] = error
+ result["result"] = error
else:
- result['result'] = liveaction_db.result
+ result["result"] = liveaction_db.result
return result
@@ -860,4 +984,4 @@ def get_runner():
def get_metadata():
- return get_runner_metadata('action_chain_runner')[0]
+ return get_runner_metadata("action_chain_runner")[0]
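
Every change in this runner module is the same mechanical rewrite: black normalizes string literals to double quotes, explodes calls that overflow the line limit into one argument per line, and appends a trailing comma to the exploded argument list. A minimal sketch of that behavior using black's public Python API; the snippet and its sample input line are illustrative, not part of this patch:

    # Requires `pip install black`; format_str and FileMode are black's documented API.
    import black

    # A deliberately over-long call with single-quoted strings, similar to the
    # pre-patch code above (names are placeholders, not imported here).
    src = "LOG.debug('Reading action chain from %s for action %s and parameters %s.', chainspec_file, self.action, params)\n"

    # line_length=88 is black's default and what this patch appears to use;
    # the output shows the quote normalization and one-argument-per-line split.
    print(black.format_str(src, mode=black.FileMode(line_length=88)))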
diff --git a/contrib/runners/action_chain_runner/dist_utils.py b/contrib/runners/action_chain_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/action_chain_runner/dist_utils.py
+++ b/contrib/runners/action_chain_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
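
dist_utils.py is vendored into each runner package, and the reformatting above is behavior-preserving. A hedged usage sketch of the two helpers that setup.py (below) relies on; the file paths here are illustrative:

    # Assumes dist_utils.py is importable and the referenced files exist on disk.
    from dist_utils import fetch_requirements, get_version_string

    # Returns (requirement names, VCS dependency links) parsed from the file,
    # matching the "install_reqs, dep_links = ..." call in setup.py below.
    install_reqs, dep_links = fetch_requirements("requirements.txt")

    # Extracts the __version__ = "x.y.z" value via the regex shown above.
    version = get_version_string("action_chain_runner/__init__.py")
    print(version, install_reqs, dep_links)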
diff --git a/contrib/runners/action_chain_runner/setup.py b/contrib/runners/action_chain_runner/setup.py
index 6c2043505c..7c96e1e1d1 100644
--- a/contrib/runners/action_chain_runner/setup.py
+++ b/contrib/runners/action_chain_runner/setup.py
@@ -26,31 +26,33 @@
from action_chain_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-action-chain',
+ name="stackstorm-runner-action-chain",
version=__version__,
- description=('Action-Chain workflow action runner for StackStorm event-driven '
- 'automation platform'),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description=(
+ "Action-Chain workflow action runner for StackStorm event-driven "
+ "automation platform"
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'action_chain_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"action_chain_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'action-chain = action_chain_runner.action_chain_runner',
+ "st2common.runners.runner": [
+ "action-chain = action_chain_runner.action_chain_runner",
],
- }
+ },
)
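
The entry_points stanza above is how the runner is advertised to st2: the package registers itself under the "st2common.runners.runner" group. An illustrative lookup using setuptools' standard entry-point discovery API (not code from this patch):

    # Lists every runner registered under the group named in setup.py above.
    import pkg_resources

    for ep in pkg_resources.iter_entry_points("st2common.runners.runner"):
        print(ep.name, "->", ep.module_name)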
diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py
index 32bb5c9249..9daed4fa90 100644
--- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py
+++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py
@@ -39,99 +39,135 @@
class DummyActionExecution(object):
- def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=''):
+ def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=""):
self.id = None
self.status = status
self.result = result
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
TEST_MODELS = {
- 'actions': ['a1.yaml', 'a2.yaml', 'action_4_action_context_param.yaml'],
- 'runners': ['testrunner1.yaml']
+ "actions": ["a1.yaml", "a2.yaml", "action_4_action_context_param.yaml"],
+ "runners": ["testrunner1.yaml"],
}
-MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
-ACTION_1 = MODELS['actions']['a1.yaml']
-ACTION_2 = MODELS['actions']['a2.yaml']
-ACTION_3 = MODELS['actions']['action_4_action_context_param.yaml']
-RUNNER = MODELS['runners']['testrunner1.yaml']
+MODELS = FixturesLoader().load_models(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+)
+ACTION_1 = MODELS["actions"]["a1.yaml"]
+ACTION_2 = MODELS["actions"]["a2.yaml"]
+ACTION_3 = MODELS["actions"]["action_4_action_context_param.yaml"]
+RUNNER = MODELS["runners"]["testrunner1.yaml"]
CHAIN_1_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain1.yaml')
+ FIXTURES_PACK, "actionchains", "chain1.yaml"
+)
CHAIN_2_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain2.yaml')
+ FIXTURES_PACK, "actionchains", "chain2.yaml"
+)
CHAIN_ACTION_CALL_NO_PARAMS_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_action_call_no_params.yaml')
+ FIXTURES_PACK, "actionchains", "chain_action_call_no_params.yaml"
+)
CHAIN_NO_DEFAULT = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'no_default_chain.yaml')
+ FIXTURES_PACK, "actionchains", "no_default_chain.yaml"
+)
CHAIN_NO_DEFAULT_2 = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'no_default_chain_2.yaml')
+ FIXTURES_PACK, "actionchains", "no_default_chain_2.yaml"
+)
CHAIN_BAD_DEFAULT = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'bad_default_chain.yaml')
-CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_broken_on_success_path_static_task_name.yaml')
-CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_broken_on_failure_path_static_task_name.yaml')
+ FIXTURES_PACK, "actionchains", "bad_default_chain.yaml"
+)
+CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME = (
+ FixturesLoader().get_fixture_file_path_abs(
+ FIXTURES_PACK,
+ "actionchains",
+ "chain_broken_on_success_path_static_task_name.yaml",
+ )
+)
+CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME = (
+ FixturesLoader().get_fixture_file_path_abs(
+ FIXTURES_PACK,
+ "actionchains",
+ "chain_broken_on_failure_path_static_task_name.yaml",
+ )
+)
CHAIN_FIRST_TASK_RENDER_FAIL_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_first_task_parameter_render_fail.yaml')
+ FIXTURES_PACK, "actionchains", "chain_first_task_parameter_render_fail.yaml"
+)
CHAIN_SECOND_TASK_RENDER_FAIL_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_second_task_parameter_render_fail.yaml')
+ FIXTURES_PACK, "actionchains", "chain_second_task_parameter_render_fail.yaml"
+)
CHAIN_LIST_TEMP_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_list_template.yaml')
+ FIXTURES_PACK, "actionchains", "chain_list_template.yaml"
+)
CHAIN_DICT_TEMP_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_dict_template.yaml')
+ FIXTURES_PACK, "actionchains", "chain_dict_template.yaml"
+)
CHAIN_DEP_INPUT = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_dependent_input.yaml')
+ FIXTURES_PACK, "actionchains", "chain_dependent_input.yaml"
+)
CHAIN_DEP_RESULTS_INPUT = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_dep_result_input.yaml')
+ FIXTURES_PACK, "actionchains", "chain_dep_result_input.yaml"
+)
MALFORMED_CHAIN_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'malformedchain.yaml')
+ FIXTURES_PACK, "actionchains", "malformedchain.yaml"
+)
CHAIN_TYPED_PARAMS = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_typed_params.yaml')
+ FIXTURES_PACK, "actionchains", "chain_typed_params.yaml"
+)
CHAIN_SYSTEM_PARAMS = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_typed_system_params.yaml')
+ FIXTURES_PACK, "actionchains", "chain_typed_system_params.yaml"
+)
CHAIN_WITH_ACTIONPARAM_VARS = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_with_actionparam_vars.yaml')
+ FIXTURES_PACK, "actionchains", "chain_with_actionparam_vars.yaml"
+)
CHAIN_WITH_SYSTEM_VARS = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_with_system_vars.yaml')
+ FIXTURES_PACK, "actionchains", "chain_with_system_vars.yaml"
+)
CHAIN_WITH_PUBLISH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_with_publish.yaml')
+ FIXTURES_PACK, "actionchains", "chain_with_publish.yaml"
+)
CHAIN_WITH_PUBLISH_2 = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_with_publish_2.yaml')
+ FIXTURES_PACK, "actionchains", "chain_with_publish_2.yaml"
+)
CHAIN_WITH_PUBLISH_PARAM_RENDERING_FAILURE = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_publish_params_rendering_failure.yaml')
+ FIXTURES_PACK, "actionchains", "chain_publish_params_rendering_failure.yaml"
+)
CHAIN_WITH_INVALID_ACTION = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_with_invalid_action.yaml')
-CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_action_params_and_parameters.yaml')
+ FIXTURES_PACK, "actionchains", "chain_with_invalid_action.yaml"
+)
+CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE = (
+ FixturesLoader().get_fixture_file_path_abs(
+ FIXTURES_PACK, "actionchains", "chain_action_params_and_parameters.yaml"
+ )
+)
CHAIN_ACTION_PARAMS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_action_params_attribute.yaml')
+ FIXTURES_PACK, "actionchains", "chain_action_params_attribute.yaml"
+)
CHAIN_ACTION_PARAMETERS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_action_parameters_attribute.yaml')
+ FIXTURES_PACK, "actionchains", "chain_action_parameters_attribute.yaml"
+)
CHAIN_ACTION_INVALID_PARAMETER_TYPE = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_invalid_parameter_type_passed_to_action.yaml')
+ FIXTURES_PACK, "actionchains", "chain_invalid_parameter_type_passed_to_action.yaml"
+)
-CHAIN_NOTIFY_API = {'notify': {'on-complete': {'message': 'foo happened.'}}}
+CHAIN_NOTIFY_API = {"notify": {"on-complete": {"message": "foo happened."}}}
CHAIN_NOTIFY_DB = NotificationsHelper.to_model(CHAIN_NOTIFY_API)
@mock.patch.object(
- action_db_util,
- 'get_runnertype_by_name',
- mock.MagicMock(return_value=RUNNER))
+ action_db_util, "get_runnertype_by_name", mock.MagicMock(return_value=RUNNER)
+)
@mock.patch.object(
action_service,
- 'is_action_canceled_or_canceling',
- mock.MagicMock(return_value=False))
+ "is_action_canceled_or_canceling",
+ mock.MagicMock(return_value=False),
+)
@mock.patch.object(
- action_service,
- 'is_action_paused_or_pausing',
- mock.MagicMock(return_value=False))
+ action_service, "is_action_paused_or_pausing", mock.MagicMock(return_value=False)
+)
class TestActionChainRunner(ExecutionDbTestCase):
-
def test_runner_creation(self):
runner = acr.get_runner()
self.assertTrue(runner)
@@ -143,18 +179,23 @@ def test_malformed_chain(self):
chain_runner.entry_point = MALFORMED_CHAIN_PATH
chain_runner.action = ACTION_1
chain_runner.pre_run()
- self.assertTrue(False, 'Expected pre_run to fail.')
+ self.assertTrue(False, "Expected pre_run to fail.")
except runnerexceptions.ActionRunnerPreRunError:
self.assertTrue(True)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_success_path(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_1_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.liveaction.notify = CHAIN_NOTIFY_DB
chain_runner.pre_run()
@@ -163,9 +204,12 @@ def test_chain_runner_success_path(self, request):
# based on the chain the callcount is known to be 3. Not great but works.
self.assertEqual(request.call_count, 3)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_chain_second_task_times_out(self, request):
# Second task in the chain times out so the action chain status should be timeout
chain_runner = acr.get_runner()
@@ -177,13 +221,15 @@ def test_chain_runner_chain_second_task_times_out(self, request):
def mock_run_action(*args, **kwargs):
original_live_action = args[0]
liveaction = original_run_action(*args, **kwargs)
- if original_live_action.action == 'wolfpack.a2':
+ if original_live_action.action == "wolfpack.a2":
# Mock a timeout for second task
liveaction.status = LIVEACTION_STATUS_TIMED_OUT
return liveaction
chain_runner._run_action = mock_run_action
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
status, _, _ = chain_runner.run({})
@@ -193,9 +239,12 @@ def mock_run_action(*args, **kwargs):
# based on the chain the callcount is known to be 3. Not great but works.
self.assertEqual(request.call_count, 3)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_task_is_canceled_while_running(self, request):
# Second task in the action is CANCELED, make sure runner doesn't get stuck in an infinite
# loop
@@ -207,7 +256,7 @@ def test_chain_runner_task_is_canceled_while_running(self, request):
def mock_run_action(*args, **kwargs):
original_live_action = args[0]
- if original_live_action.action == 'wolfpack.a2':
+ if original_live_action.action == "wolfpack.a2":
status = LIVEACTION_STATUS_CANCELED
else:
status = LIVEACTION_STATUS_SUCCEEDED
@@ -216,7 +265,9 @@ def mock_run_action(*args, **kwargs):
return liveaction
chain_runner._run_action = mock_run_action
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
status, _, _ = chain_runner.run({})
@@ -227,16 +278,21 @@ def mock_run_action(*args, **kwargs):
# canceled
self.assertEqual(request.call_count, 2)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_success_task_action_call_with_no_params(self, request):
# Make sure that the runner doesn't explode if the task definition contains
# no "params" section
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_ACTION_CALL_NO_PARAMS_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.liveaction.notify = CHAIN_NOTIFY_DB
chain_runner.pre_run()
@@ -245,14 +301,19 @@ def test_chain_runner_success_task_action_call_with_no_params(self, request):
# based on the chain the callcount is known to be 3. Not great but works.
self.assertEqual(request.call_count, 3)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_no_default(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_NO_DEFAULT
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
chain_runner.run({})
@@ -264,9 +325,12 @@ def test_chain_runner_no_default(self, request):
# based on the chain the callcount is known to be 3. Not great but works.
self.assertEqual(request.call_count, 3)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_no_default_multiple_options(self, request):
# subtle difference is that when there are multiple possible default nodes
# the order per chain definition may not be preserved. This is really a
@@ -274,7 +338,9 @@ def test_chain_runner_no_default_multiple_options(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_NO_DEFAULT_2
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
chain_runner.run({})
@@ -286,29 +352,44 @@ def test_chain_runner_no_default_multiple_options(self, request):
# based on the chain the callcount is known to be 2.
self.assertEqual(request.call_count, 2)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_bad_default(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_BAD_DEFAULT
chain_runner.action = ACTION_1
- expected_msg = 'Unable to find node with name "bad_default" referenced in "default".'
- self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError,
- expected_msg, chain_runner.pre_run)
-
- @mock.patch('eventlet.sleep', mock.MagicMock())
- @mock.patch.object(action_db_util, 'get_liveaction_by_id', mock.MagicMock(
- return_value=DummyActionExecution()))
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(status=LIVEACTION_STATUS_RUNNING), None))
+ expected_msg = (
+ 'Unable to find node with name "bad_default" referenced in "default".'
+ )
+ self.assertRaisesRegexp(
+ runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run
+ )
+
+ @mock.patch("eventlet.sleep", mock.MagicMock())
+ @mock.patch.object(
+ action_db_util,
+ "get_liveaction_by_id",
+ mock.MagicMock(return_value=DummyActionExecution()),
+ )
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service,
+ "request",
+ return_value=(DummyActionExecution(status=LIVEACTION_STATUS_RUNNING), None),
+ )
def test_chain_runner_success_path_with_wait(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_1_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
chain_runner.run({})
@@ -316,15 +397,21 @@ def test_chain_runner_success_path_with_wait(self, request):
# based on the chain the callcount is known to be 3. Not great but works.
self.assertEqual(request.call_count, 3)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(status=LIVEACTION_STATUS_FAILED), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service,
+ "request",
+ return_value=(DummyActionExecution(status=LIVEACTION_STATUS_FAILED), None),
+ )
def test_chain_runner_failure_path(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_1_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
status, _, _ = chain_runner.run({})
@@ -333,42 +420,57 @@ def test_chain_runner_failure_path(self, request):
# based on the chain the callcount is known to be 2. Not great but works.
self.assertEqual(request.call_count, 2)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_broken_on_success_path_static_task_name(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME
chain_runner.action = ACTION_1
- expected_msg = ('Unable to find node with name "c5" referenced in "on-success" '
- 'in task "c2"')
- self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError,
- expected_msg, chain_runner.pre_run)
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(), None))
+ expected_msg = (
+ 'Unable to find node with name "c5" referenced in "on-success" '
+ 'in task "c2"'
+ )
+ self.assertRaisesRegexp(
+ runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run
+ )
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_broken_on_failure_path_static_task_name(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME
chain_runner.action = ACTION_1
- expected_msg = ('Unable to find node with name "c6" referenced in "on-failure" '
- 'in task "c2"')
- self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError,
- expected_msg, chain_runner.pre_run)
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', side_effect=RuntimeError('Test Failure.'))
+ expected_msg = (
+ 'Unable to find node with name "c6" referenced in "on-failure" '
+ 'in task "c2"'
+ )
+ self.assertRaisesRegexp(
+ runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run
+ )
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", side_effect=RuntimeError("Test Failure.")
+ )
def test_chain_runner_action_exception(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_1_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
status, results, _ = chain_runner.run({})
@@ -379,102 +481,131 @@ def test_chain_runner_action_exception(self, request):
self.assertEqual(request.call_count, 2)
error_count = 0
- for task_result in results['tasks']:
- if task_result['result'].get('error', None):
+ for task_result in results["tasks"]:
+ if task_result["result"].get("error", None):
error_count += 1
self.assertEqual(error_count, 2)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_str_param_temp(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
- chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4})
+ chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
mock_args, _ = request.call_args
self.assertEqual(mock_args[0].parameters, {"p1": "1"})
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_list_param_temp(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_LIST_TEMP_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
- chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4})
+ chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
mock_args, _ = request.call_args
self.assertEqual(mock_args[0].parameters, {"p1": "[2, 3, 4]"})
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_dict_param_temp(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_DICT_TEMP_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
- chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4})
+ chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
expected_value = {"p1": {"p1.3": "[3, 4]", "p1.2": "2", "p1.1": "1"}}
mock_args, _ = request.call_args
self.assertEqual(mock_args[0].parameters, expected_value)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(result={'o1': '1'}), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service,
+ "request",
+ return_value=(DummyActionExecution(result={"o1": "1"}), None),
+ )
def test_chain_runner_dependent_param_temp(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_DEP_INPUT
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
- chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4})
+ chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
- expected_values = [{u'p1': u'1'},
- {u'p1': u'1'},
- {u'p2': u'1', u'p3': u'1', u'p1': u'1'}]
+ expected_values = [{"p1": "1"}, {"p1": "1"}, {"p2": "1", "p3": "1", "p1": "1"}]
# Each of the call_args must be one of the expected values
for call_args in request.call_args_list:
self.assertIn(call_args[0][0].parameters, expected_values)
expected_values.remove(call_args[0][0].parameters)
- self.assertEqual(len(expected_values), 0, 'Not all expected values received.')
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(result={'o1': '1'}), None))
+ self.assertEqual(len(expected_values), 0, "Not all expected values received.")
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service,
+ "request",
+ return_value=(DummyActionExecution(result={"o1": "1"}), None),
+ )
def test_chain_runner_dependent_results_param(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_DEP_RESULTS_INPUT
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
- chain_runner.run({'s1': 1})
+ chain_runner.run({"s1": 1})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
if six.PY2:
- expected_values = [{u'p1': u'1'},
- {u'p1': u'1'},
- {u'out': u"{'c2': {'o1': '1'}, 'c1': {'o1': '1'}}"}]
+ expected_values = [
+ {"p1": "1"},
+ {"p1": "1"},
+ {"out": "{'c2': {'o1': '1'}, 'c1': {'o1': '1'}}"},
+ ]
else:
- expected_values = [{'p1': '1'},
- {'p1': '1'},
- {'out': "{'c1': {'o1': '1'}, 'c2': {'o1': '1'}}"}]
+ expected_values = [
+ {"p1": "1"},
+ {"p1": "1"},
+ {"out": "{'c1': {'o1': '1'}, 'c2': {'o1': '1'}}"},
+ ]
# Each of the call_args must be one of the expected values
self.assertEqual(request.call_count, 3)
@@ -482,104 +613,137 @@ def test_chain_runner_dependent_results_param(self, request):
self.assertIn(call_args[0][0].parameters, expected_values)
expected_values.remove(call_args[0][0].parameters)
- self.assertEqual(len(expected_values), 0, 'Not all expected values received.')
+ self.assertEqual(len(expected_values), 0, "Not all expected values received.")
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(RunnerType, 'get_by_name',
- mock.MagicMock(return_value=RUNNER))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(RunnerType, "get_by_name", mock.MagicMock(return_value=RUNNER))
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_missing_param_temp(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
chain_runner.run({})
- self.assertEqual(request.call_count, 0, 'No call expected.')
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ self.assertEqual(request.call_count, 0, "No call expected.")
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_failure_during_param_rendering_single_task(self, request):
# Parameter rendering should result in a top level error which aborts
# the whole chain
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
status, result, _ = chain_runner.run({})
# No tasks ran because rendering of parameters for the first task failed
self.assertEqual(status, LIVEACTION_STATUS_FAILED)
- self.assertEqual(result['tasks'], [])
- self.assertIn('error', result)
- self.assertIn('traceback', result)
- self.assertIn('Failed to run task "c1". Parameter rendering failed', result['error'])
- self.assertIn('Traceback', result['traceback'])
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ self.assertEqual(result["tasks"], [])
+ self.assertIn("error", result)
+ self.assertIn("traceback", result)
+ self.assertIn(
+ 'Failed to run task "c1". Parameter rendering failed', result["error"]
+ )
+ self.assertIn("Traceback", result["traceback"])
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_failure_during_param_rendering_multiple_tasks(self, request):
# Parameter rendering should result in a top level error which aborts
# the whole chain
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_SECOND_TASK_RENDER_FAIL_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
status, result, _ = chain_runner.run({})
# Verify that only the first task has run
self.assertEqual(status, LIVEACTION_STATUS_FAILED)
- self.assertEqual(len(result['tasks']), 1)
- self.assertEqual(result['tasks'][0]['name'], 'c1')
-
- expected_error = ('Failed rendering value for action parameter "p1" in '
- 'task "c2" (template string={{s1}}):')
-
- self.assertIn('error', result)
- self.assertIn('traceback', result)
- self.assertIn('Failed to run task "c2". Parameter rendering failed', result['error'])
- self.assertIn(expected_error, result['error'])
- self.assertIn('Traceback', result['traceback'])
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_2))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ self.assertEqual(len(result["tasks"]), 1)
+ self.assertEqual(result["tasks"][0]["name"], "c1")
+
+ expected_error = (
+ 'Failed rendering value for action parameter "p1" in '
+ 'task "c2" (template string={{s1}}):'
+ )
+
+ self.assertIn("error", result)
+ self.assertIn("traceback", result)
+ self.assertIn(
+ 'Failed to run task "c2". Parameter rendering failed', result["error"]
+ )
+ self.assertIn(expected_error, result["error"])
+ self.assertIn("Traceback", result["traceback"])
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_typed_params(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_TYPED_PARAMS
chain_runner.action = ACTION_2
- action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_2.name, pack=ACTION_2.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
- chain_runner.run({'s1': 1, 's2': 'two', 's3': 3.14})
+ chain_runner.run({"s1": 1, "s2": "two", "s3": 3.14})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
- expected_value = {'booltype': True,
- 'inttype': 1,
- 'numbertype': 3.14,
- 'strtype': 'two',
- 'arrtype': ['1', 'two'],
- 'objtype': {'s2': 'two',
- 'k1': '1'}}
+ expected_value = {
+ "booltype": True,
+ "inttype": 1,
+ "numbertype": 3.14,
+ "strtype": "two",
+ "arrtype": ["1", "two"],
+ "objtype": {"s2": "two", "k1": "1"},
+ }
mock_args, _ = request.call_args
self.assertEqual(mock_args[0].parameters, expected_value)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_2))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_typed_system_params(self, request):
- action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_2.name, pack=ACTION_2.pack
+ )
kvps = []
try:
- kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a', value='1')))
- kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='two')))
+ kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name="a", value="1")))
+ kvps.append(
+ KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="two"))
+ )
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_SYSTEM_PARAMS
chain_runner.action = ACTION_2
@@ -587,22 +751,28 @@ def test_chain_runner_typed_system_params(self, request):
chain_runner.pre_run()
chain_runner.run({})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
- expected_value = {'inttype': 1,
- 'strtype': 'two'}
+ expected_value = {"inttype": 1, "strtype": "two"}
mock_args, _ = request.call_args
self.assertEqual(mock_args[0].parameters, expected_value)
finally:
for kvp in kvps:
KeyValuePair.delete(kvp)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_2))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_vars_system_params(self, request):
- action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_2.name, pack=ACTION_2.pack
+ )
kvps = []
try:
- kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a', value='two')))
+ kvps.append(
+ KeyValuePair.add_or_update(KeyValuePairDB(name="a", value="two"))
+ )
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_WITH_SYSTEM_VARS
chain_runner.action = ACTION_2
@@ -610,72 +780,88 @@ def test_chain_runner_vars_system_params(self, request):
chain_runner.pre_run()
chain_runner.run({})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
- expected_value = {'inttype': 1,
- 'strtype': 'two',
- 'booltype': True}
+ expected_value = {"inttype": 1, "strtype": "two", "booltype": True}
mock_args, _ = request.call_args
self.assertEqual(mock_args[0].parameters, expected_value)
finally:
for kvp in kvps:
KeyValuePair.delete(kvp)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_2))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_vars_action_params(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_WITH_ACTIONPARAM_VARS
chain_runner.action = ACTION_2
- action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_2.name, pack=ACTION_2.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
- chain_runner.run({'input_a': 'two'})
+ chain_runner.run({"input_a": "two"})
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
- expected_value = {'inttype': 1,
- 'strtype': 'two',
- 'booltype': True}
+ expected_value = {"inttype": 1, "strtype": "two", "booltype": True}
mock_args, _ = request.call_args
self.assertEqual(mock_args[0].parameters, expected_value)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_2))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(result={'raw_out': 'published'}), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2)
+ )
+ @mock.patch.object(
+ action_service,
+ "request",
+ return_value=(DummyActionExecution(result={"raw_out": "published"}), None),
+ )
def test_chain_runner_publish(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_WITH_PUBLISH
chain_runner.action = ACTION_2
- action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_2.name, pack=ACTION_2.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
- chain_runner.runner_parameters = {'display_published': True}
+ chain_runner.runner_parameters = {"display_published": True}
chain_runner.pre_run()
- action_parameters = {'action_param_1': 'test value 1'}
+ action_parameters = {"action_param_1": "test value 1"}
_, result, _ = chain_runner.run(action_parameters=action_parameters)
# We also assert that the action parameters are available in the
# "publish" scope
self.assertNotEqual(chain_runner.chain_holder.actionchain, None)
- expected_value = {'inttype': 1,
- 'strtype': 'published',
- 'booltype': True,
- 'published_action_param': action_parameters['action_param_1']}
+ expected_value = {
+ "inttype": 1,
+ "strtype": "published",
+ "booltype": True,
+ "published_action_param": action_parameters["action_param_1"],
+ }
mock_args, _ = request.call_args
self.assertEqual(mock_args[0].parameters, expected_value)
# Assert that the variables are correctly published
- self.assertEqual(result['published'],
- {'published_action_param': u'test value 1', 'o1': u'published'})
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ self.assertEqual(
+ result["published"],
+ {"published_action_param": "test value 1", "o1": "published"},
+ )
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_publish_param_rendering_failure(self, request):
# Parameter rendering should result in a top level error which aborts
# the whole chain
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_WITH_PUBLISH_PARAM_RENDERING_FAILURE
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
@@ -685,16 +871,21 @@ def test_chain_runner_publish_param_rendering_failure(self, request):
# TODO: Should we treat this as task error? Right now it bubbles all
# the way up and it's not really consistent with action param
# rendering failure
- expected_error = ('Failed rendering value for publish parameter "p1" in '
- 'task "c2" (template string={{ not_defined }}):')
+ expected_error = (
+ 'Failed rendering value for publish parameter "p1" in '
+ 'task "c2" (template string={{ not_defined }}):'
+ )
self.assertIn(expected_error, six.text_type(e))
pass
else:
- self.fail('Exception was not thrown')
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_2))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ self.fail("Exception was not thrown")
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_task_passes_invalid_parameter_type_to_action(self, mock_request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_ACTION_INVALID_PARAMETER_TYPE
@@ -702,48 +893,72 @@ def test_chain_task_passes_invalid_parameter_type_to_action(self, mock_request):
chain_runner.pre_run()
action_parameters = {}
- expected_msg = (r'Failed to cast value "stringnotanarray" \(type: str\) for parameter '
- r'"arrtype" of type "array"')
- self.assertRaisesRegexp(ValueError, expected_msg, chain_runner.run,
- action_parameters=action_parameters)
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=None))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(result={'raw_out': 'published'}), None))
+ expected_msg = (
+ r'Failed to cast value "stringnotanarray" \(type: str\) for parameter '
+ r'"arrtype" of type "array"'
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ chain_runner.run,
+ action_parameters=action_parameters,
+ )
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=None)
+ )
+ @mock.patch.object(
+ action_service,
+ "request",
+ return_value=(DummyActionExecution(result={"raw_out": "published"}), None),
+ )
def test_action_chain_runner_referenced_action_doesnt_exist(self, mock_request):
# Action referenced by a task doesn't exist, which should result in a top level error
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_WITH_INVALID_ACTION
chain_runner.action = ACTION_2
- action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_2.name, pack=ACTION_2.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
action_parameters = {}
status, output, _ = chain_runner.run(action_parameters=action_parameters)
- expected_error = ('Failed to run task "c1". Action with reference "wolfpack.a2" '
- 'doesn\'t exist.')
+ expected_error = (
+ 'Failed to run task "c1". Action with reference "wolfpack.a2" '
+ "doesn't exist."
+ )
self.assertEqual(status, LIVEACTION_STATUS_FAILED)
- self.assertIn(expected_error, output['error'])
- self.assertIn('Traceback', output['traceback'])
+ self.assertIn(expected_error, output["error"])
+ self.assertIn("Traceback", output["traceback"])
- def test_exception_is_thrown_if_both_params_and_parameters_attributes_are_provided(self):
+ def test_exception_is_thrown_if_both_params_and_parameters_attributes_are_provided(
+ self,
+ ):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE
chain_runner.action = ACTION_2
- expected_msg = ('Either "params" or "parameters" attribute needs to be provided, but '
- 'not both')
- self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, expected_msg,
- chain_runner.pre_run)
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_2))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ expected_msg = (
+ 'Either "params" or "parameters" attribute needs to be provided, but '
+ "not both"
+ )
+ self.assertRaisesRegexp(
+ runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run
+ )
+
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_params_and_parameters_attributes_both_work(self, _):
- action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_2.name, pack=ACTION_2.pack
+ )
# "params" attribute used
chain_runner = acr.get_runner()
@@ -756,10 +971,12 @@ def test_params_and_parameters_attributes_both_work(self, _):
def mock_build_liveaction_object(action_node, resolved_params, parent_context):
# Verify parameters are correctly passed to the action
- self.assertEqual(resolved_params, {'pparams': 'v1'})
- original_build_liveaction_object(action_node=action_node,
- resolved_params=resolved_params,
- parent_context=parent_context)
+ self.assertEqual(resolved_params, {"pparams": "v1"})
+ original_build_liveaction_object(
+ action_node=action_node,
+ resolved_params=resolved_params,
+ parent_context=parent_context,
+ )
chain_runner._build_liveaction_object = mock_build_liveaction_object
@@ -776,10 +993,12 @@ def mock_build_liveaction_object(action_node, resolved_params, parent_context):
def mock_build_liveaction_object(action_node, resolved_params, parent_context):
# Verify parameters are correctly passed to the action
- self.assertEqual(resolved_params, {'pparameters': 'v1'})
- original_build_liveaction_object(action_node=action_node,
- resolved_params=resolved_params,
- parent_context=parent_context)
+ self.assertEqual(resolved_params, {"pparameters": "v1"})
+ original_build_liveaction_object(
+ action_node=action_node,
+ resolved_params=resolved_params,
+ parent_context=parent_context,
+ )
chain_runner._build_liveaction_object = mock_build_liveaction_object
@@ -787,21 +1006,27 @@ def mock_build_liveaction_object(action_node, resolved_params, parent_context):
status, output, _ = chain_runner.run(action_parameters=action_parameters)
self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED)
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_2))
- @mock.patch.object(action_service, 'request',
- return_value=(DummyActionExecution(result={'raw_out': 'published'}), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2)
+ )
+ @mock.patch.object(
+ action_service,
+ "request",
+ return_value=(DummyActionExecution(result={"raw_out": "published"}), None),
+ )
def test_display_published_is_true_by_default(self, _):
- action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_2.name, pack=ACTION_2.pack
+ )
expected_published_values = {
- 't1_publish_param_1': 'foo1',
- 't1_publish_param_2': 'foo2',
- 't1_publish_param_3': 'foo3',
- 't2_publish_param_1': 'foo4',
- 't2_publish_param_2': 'foo5',
- 't2_publish_param_3': 'foo6',
- 'publish_last_wins': 'bar_last',
+ "t1_publish_param_1": "foo1",
+ "t1_publish_param_2": "foo2",
+ "t1_publish_param_3": "foo3",
+ "t2_publish_param_1": "foo4",
+ "t2_publish_param_2": "foo5",
+ "t2_publish_param_3": "foo6",
+ "publish_last_wins": "bar_last",
}
# 1. display_published is True by default
@@ -816,35 +1041,35 @@ def test_display_published_is_true_by_default(self, _):
_, result, _ = chain_runner.run(action_parameters=action_parameters)
# Assert that the variables are correctly published
- self.assertEqual(result['published'], expected_published_values)
+ self.assertEqual(result["published"], expected_published_values)
# 2. display_published is True by default, so the end result should be the same
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_WITH_PUBLISH_2
chain_runner.action = ACTION_2
chain_runner.liveaction = LiveActionDB(action=action_ref)
- chain_runner.runner_parameters = {'display_published': True}
+ chain_runner.runner_parameters = {"display_published": True}
chain_runner.pre_run()
action_parameters = {}
_, result, _ = chain_runner.run(action_parameters=action_parameters)
# Assert that the variables are correctly published
- self.assertEqual(result['published'], expected_published_values)
+ self.assertEqual(result["published"], expected_published_values)
# 3. display_published is disabled
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_WITH_PUBLISH_2
chain_runner.action = ACTION_2
chain_runner.liveaction = LiveActionDB(action=action_ref)
- chain_runner.runner_parameters = {'display_published': False}
+ chain_runner.runner_parameters = {"display_published": False}
chain_runner.pre_run()
action_parameters = {}
_, result, _ = chain_runner.run(action_parameters=action_parameters)
- self.assertNotIn('published', result)
- self.assertEqual(result.get('published', {}), {})
+ self.assertNotIn("published", result)
+ self.assertEqual(result.get("published", {}), {})
@classmethod
def tearDownClass(cls):
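The hunks above are representative of the mechanical black conversion applied across the whole PR: single-quoted literals become double-quoted, and any @mock.patch.object(...) decorator that no longer fits in black's 88-column limit is exploded with one argument group per line. A minimal, self-contained sketch of the resulting decorator style (SomeService and its request method are stand-ins for the fixtures above, not StackStorm APIs):

    import unittest
    from unittest import mock


    class SomeService:
        @staticmethod
        def request(liveaction):
            raise NotImplementedError


    class BlackStyleExampleTest(unittest.TestCase):
        # When the decorator exceeds 88 columns, black splits it across
        # lines exactly as in the hunks above; passing the mock as the
        # "new" argument means it is not injected into the test method.
        @mock.patch.object(
            SomeService, "request", mock.MagicMock(return_value=("dummy", None))
        )
        def test_request_is_mocked(self):
            execution, error = SomeService.request(None)
            self.assertEqual(execution, "dummy")
            self.assertIsNone(error)


    if __name__ == "__main__":
        unittest.main()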
diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py
index 7bba3606d8..dca88cf803 100644
--- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py
+++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py
@@ -20,6 +20,7 @@
import tempfile
from st2tests import config as test_config
+
test_config.parse_args()
from st2common.bootstrap import actionsregistrar
@@ -40,39 +41,25 @@
TEST_FIXTURES = {
- 'chains': [
- 'test_cancel.yaml',
- 'test_cancel_with_subworkflow.yaml'
- ],
- 'actions': [
- 'test_cancel.yaml',
- 'test_cancel_with_subworkflow.yaml'
- ]
+ "chains": ["test_cancel.yaml", "test_cancel_with_subworkflow.yaml"],
+ "actions": ["test_cancel.yaml", "test_cancel_with_subworkflow.yaml"],
}
-TEST_PACK = 'action_chain_tests'
-TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "action_chain_tests"
+TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
-PACKS = [
- TEST_PACK_PATH,
- fixturesloader.get_fixtures_packs_base_path() + '/core'
-]
+PACKS = [TEST_PACK_PATH, fixturesloader.get_fixtures_packs_base_path() + "/core"]
-USERNAME = 'stanley'
+USERNAME = "stanley"
-@mock.patch.object(
- CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
-@mock.patch.object(
- CUDPublisher,
- 'publish_create',
- mock.MagicMock(return_value=None))
+@mock.patch.object(CUDPublisher, "publish_update", mock.MagicMock(return_value=None))
+@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None))
@mock.patch.object(
LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state),
+)
class ActionChainRunnerPauseResumeTest(ExecutionDbTestCase):
temp_file_path = None
@@ -86,8 +73,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -98,7 +84,7 @@ def setUp(self):
# Create temporary file used by the tests
_, self.temp_file_path = tempfile.mkstemp()
- os.chmod(self.temp_file_path, 0o755) # nosec
+ os.chmod(self.temp_file_path, 0o755) # nosec
def tearDown(self):
if self.temp_file_path and os.path.exists(self.temp_file_path):
@@ -110,7 +96,7 @@ def _wait_for_children(self, execution, interval=0.1, retries=100):
# Wait until the execution has children.
for i in range(0, retries):
execution = ActionExecution.get_by_id(str(execution.id))
- if len(getattr(execution, 'children', [])) <= 0:
+ if len(getattr(execution, "children", [])) <= 0:
eventlet.sleep(interval)
continue
@@ -123,34 +109,42 @@ def test_chain_cancel(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_cancel'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_cancel"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
# Request action chain to cancel.
- liveaction, execution = action_service.request_cancellation(liveaction, USERNAME)
+ liveaction, execution = action_service.request_cancellation(
+ liveaction, USERNAME
+ )
# Wait until the liveaction is canceling.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELING)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELING
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is canceled.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 1)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 1)
def test_chain_cancel_cascade_to_subworkflow(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -159,14 +153,16 @@ def test_chain_cancel_cascade_to_subworkflow(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_cancel_with_subworkflow'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_cancel_with_subworkflow"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
# Wait for subworkflow to register.
execution = self._wait_for_children(execution)
@@ -174,44 +170,58 @@ def test_chain_cancel_cascade_to_subworkflow(self):
# Wait until the subworkflow is running.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_on_status(
+ task1_live, action_constants.LIVEACTION_STATUS_RUNNING
+ )
# Request action chain to cancel.
- liveaction, execution = action_service.request_cancellation(liveaction, USERNAME)
+ liveaction, execution = action_service.request_cancellation(
+ liveaction, USERNAME
+ )
# Wait until the liveaction is canceling.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELING)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELING
+ )
self.assertEqual(len(execution.children), 1)
# Wait until the subworkflow is canceling.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELING)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_on_status(
+ task1_live, action_constants.LIVEACTION_STATUS_CANCELING
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is canceled.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
+ )
self.assertEqual(len(execution.children), 1)
# Wait until the subworkflow is canceled.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELED)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_on_status(
+ task1_live, action_constants.LIVEACTION_STATUS_CANCELED
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 1)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 1)
- subworkflow = liveaction.result['tasks'][0]
- self.assertEqual(len(subworkflow['result']['tasks']), 1)
- self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_CANCELED)
+ subworkflow = liveaction.result["tasks"][0]
+ self.assertEqual(len(subworkflow["result"]["tasks"]), 1)
+ self.assertEqual(
+ subworkflow["state"], action_constants.LIVEACTION_STATUS_CANCELED
+ )
def test_chain_cancel_cascade_to_parent_workflow(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -220,14 +230,16 @@ def test_chain_cancel_cascade_to_parent_workflow(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_cancel_with_subworkflow'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_cancel_with_subworkflow"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
# Wait for subworkflow to register.
execution = self._wait_for_children(execution)
@@ -235,16 +247,22 @@ def test_chain_cancel_cascade_to_parent_workflow(self):
# Wait until the subworkflow is running.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_on_status(
+ task1_live, action_constants.LIVEACTION_STATUS_RUNNING
+ )
# Request subworkflow to cancel.
- task1_live, task1_exec = action_service.request_cancellation(task1_live, USERNAME)
+ task1_live, task1_exec = action_service.request_cancellation(
+ task1_live, USERNAME
+ )
# Wait until the subworkflow is canceling.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELING)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_on_status(
+ task1_live, action_constants.LIVEACTION_STATUS_CANCELING
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
@@ -252,20 +270,26 @@ def test_chain_cancel_cascade_to_parent_workflow(self):
# Wait until the subworkflow is canceled.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELED)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_on_status(
+ task1_live, action_constants.LIVEACTION_STATUS_CANCELED
+ )
# Wait until the parent liveaction is canceled.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
+ )
self.assertEqual(len(execution.children), 1)
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 1)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 1)
- subworkflow = liveaction.result['tasks'][0]
- self.assertEqual(len(subworkflow['result']['tasks']), 1)
- self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_CANCELED)
+ subworkflow = liveaction.result["tasks"][0]
+ self.assertEqual(len(subworkflow["result"]["tasks"]), 1)
+ self.assertEqual(
+ subworkflow["state"], action_constants.LIVEACTION_STATUS_CANCELED
+ )
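The cancel tests above repeatedly re-read the liveaction from storage and block until it reaches RUNNING, then CANCELING, then CANCELED. A rough sketch of that polling pattern, using a plain time.sleep in place of the eventlet.sleep the real helpers rely on (wait_for_status here is illustrative, not the st2tests _wait_on_status API):

    import time


    def wait_for_status(get_record, target_status, interval=0.1, retries=100):
        # Re-read the record on every iteration and return it once it
        # reaches the target status; fail loudly if it never does.
        for _ in range(retries):
            record = get_record()
            if record.status == target_status:
                return record
            time.sleep(interval)
        raise AssertionError("timed out waiting for status %s" % target_status)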
diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py
index 193d6064a1..7997869b13 100644
--- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py
+++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py
@@ -27,51 +27,53 @@
class DummyActionExecution(object):
- def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=''):
+ def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=""):
self.id = None
self.status = status
self.result = result
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
-TEST_MODELS = {
- 'actions': ['a1.yaml', 'a2.yaml'],
- 'runners': ['testrunner1.yaml']
-}
+TEST_MODELS = {"actions": ["a1.yaml", "a2.yaml"], "runners": ["testrunner1.yaml"]}
-MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
-ACTION_1 = MODELS['actions']['a1.yaml']
-ACTION_2 = MODELS['actions']['a2.yaml']
-RUNNER = MODELS['runners']['testrunner1.yaml']
+MODELS = FixturesLoader().load_models(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+)
+ACTION_1 = MODELS["actions"]["a1.yaml"]
+ACTION_2 = MODELS["actions"]["a2.yaml"]
+RUNNER = MODELS["runners"]["testrunner1.yaml"]
CHAIN_1_PATH = FixturesLoader().get_fixture_file_path_abs(
- FIXTURES_PACK, 'actionchains', 'chain_with_notifications.yaml')
+ FIXTURES_PACK, "actionchains", "chain_with_notifications.yaml"
+)
@mock.patch.object(
- action_db_util,
- 'get_runnertype_by_name',
- mock.MagicMock(return_value=RUNNER))
+ action_db_util, "get_runnertype_by_name", mock.MagicMock(return_value=RUNNER)
+)
@mock.patch.object(
action_service,
- 'is_action_canceled_or_canceling',
- mock.MagicMock(return_value=False))
+ "is_action_canceled_or_canceling",
+ mock.MagicMock(return_value=False),
+)
@mock.patch.object(
- action_service,
- 'is_action_paused_or_pausing',
- mock.MagicMock(return_value=False))
+ action_service, "is_action_paused_or_pausing", mock.MagicMock(return_value=False)
+)
class TestActionChainNotifications(ExecutionDbTestCase):
-
- @mock.patch.object(action_db_util, 'get_action_by_ref',
- mock.MagicMock(return_value=ACTION_1))
- @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None))
+ @mock.patch.object(
+ action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1)
+ )
+ @mock.patch.object(
+ action_service, "request", return_value=(DummyActionExecution(), None)
+ )
def test_chain_runner_success_path(self, request):
chain_runner = acr.get_runner()
chain_runner.entry_point = CHAIN_1_PATH
chain_runner.action = ACTION_1
- action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack)
+ action_ref = ResourceReference.to_string_reference(
+ name=ACTION_1.name, pack=ACTION_1.pack
+ )
chain_runner.liveaction = LiveActionDB(action=action_ref)
chain_runner.pre_run()
chain_runner.run({})
@@ -79,8 +81,8 @@ def test_chain_runner_success_path(self, request):
self.assertEqual(request.call_count, 2)
first_call_args = request.call_args_list[0][0]
liveaction_db = first_call_args[0]
- self.assertTrue(liveaction_db.notify, 'Notify property expected.')
+ self.assertTrue(liveaction_db.notify, "Notify property expected.")
second_call_args = request.call_args_list[1][0]
liveaction_db = second_call_args[0]
- self.assertFalse(liveaction_db.notify, 'Notify property not expected.')
+ self.assertFalse(liveaction_db.notify, "Notify property not expected.")
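The TEST_MODELS/FixturesLoader pattern above resolves YAML fixture files into model objects keyed first by category ("actions", "runners") and then by file name, which is why the tests index MODELS["actions"]["a1.yaml"]. A rough sketch of such a loader under those assumptions (an illustration only, not the st2tests FixturesLoader implementation):

    import os

    import yaml


    def load_models(fixtures_pack_path, fixtures_dict):
        # Returns {"actions": {"a1.yaml": {...}}, "runners": {...}},
        # mirroring the MODELS["actions"]["a1.yaml"] lookups above.
        models = {}
        for category, file_names in fixtures_dict.items():
            models[category] = {}
            for file_name in file_names:
                path = os.path.join(fixtures_pack_path, category, file_name)
                with open(path) as f:
                    models[category][file_name] = yaml.safe_load(f)
        return models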
diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py
index 6fa4c6b456..d6278ca61a 100644
--- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py
+++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py
@@ -25,96 +25,96 @@
class ActionChainRunnerResolveParamsTests(unittest2.TestCase):
-
def test_render_params_action_context(self):
runner = acr.get_runner()
chain_context = {
- 'parent': {
- 'execution_id': 'some_awesome_exec_id',
- 'user': 'dad'
- },
- 'user': 'son',
- 'k1': 'v1'
+ "parent": {"execution_id": "some_awesome_exec_id", "user": "dad"},
+ "user": "son",
+ "k1": "v1",
}
task_params = {
- 'exec_id': {'default': '{{action_context.parent.execution_id}}'},
- 'k2': {},
- 'foo': {'default': 1}
+ "exec_id": {"default": "{{action_context.parent.execution_id}}"},
+ "k2": {},
+ "foo": {"default": 1},
}
- action_node = Node(name='test_action_context_params', ref='core.local', params=task_params)
+ action_node = Node(
+ name="test_action_context_params", ref="core.local", params=task_params
+ )
rendered_params = runner._resolve_params(action_node, {}, {}, {}, chain_context)
- self.assertEqual(rendered_params['exec_id']['default'], 'some_awesome_exec_id')
+ self.assertEqual(rendered_params["exec_id"]["default"], "some_awesome_exec_id")
def test_render_params_action_context_non_existent_member(self):
runner = acr.get_runner()
chain_context = {
- 'parent': {
- 'execution_id': 'some_awesome_exec_id',
- 'user': 'dad'
- },
- 'user': 'son',
- 'k1': 'v1'
+ "parent": {"execution_id": "some_awesome_exec_id", "user": "dad"},
+ "user": "son",
+ "k1": "v1",
}
task_params = {
- 'exec_id': {'default': '{{action_context.parent.yo_gimme_tha_key}}'},
- 'k2': {},
- 'foo': {'default': 1}
+ "exec_id": {"default": "{{action_context.parent.yo_gimme_tha_key}}"},
+ "k2": {},
+ "foo": {"default": 1},
}
- action_node = Node(name='test_action_context_params', ref='core.local', params=task_params)
+ action_node = Node(
+ name="test_action_context_params", ref="core.local", params=task_params
+ )
try:
runner._resolve_params(action_node, {}, {}, {}, chain_context)
- self.fail('Should have thrown an instance of %s' % ParameterRenderingFailedException)
+ self.fail(
+ "Should have thrown an instance of %s"
+ % ParameterRenderingFailedException
+ )
except ParameterRenderingFailedException:
pass
def test_render_params_with_config(self):
- with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader:
+ with mock.patch(
+ "st2common.util.config_loader.ContentPackConfigLoader"
+ ) as config_loader:
config_loader().get_config.return_value = {
- 'amazing_config_value_fo_lyfe': 'no'
+ "amazing_config_value_fo_lyfe": "no"
}
runner = acr.get_runner()
chain_context = {
- 'parent': {
- 'execution_id': 'some_awesome_exec_id',
- 'user': 'dad',
- 'pack': 'mom'
+ "parent": {
+ "execution_id": "some_awesome_exec_id",
+ "user": "dad",
+ "pack": "mom",
},
- 'user': 'son',
+ "user": "son",
}
task_params = {
- 'config_val': '{{config_context.amazing_config_value_fo_lyfe}}'
+ "config_val": "{{config_context.amazing_config_value_fo_lyfe}}"
}
action_node = Node(
- name='test_action_context_params',
- ref='core.local',
- params=task_params
+ name="test_action_context_params", ref="core.local", params=task_params
+ )
+ rendered_params = runner._resolve_params(
+ action_node, {}, {}, {}, chain_context
)
- rendered_params = runner._resolve_params(action_node, {}, {}, {}, chain_context)
- self.assertEqual(rendered_params['config_val'], 'no')
+ self.assertEqual(rendered_params["config_val"], "no")
def test_init_params_vars_with_unicode_value(self):
chain_spec = {
- 'vars': {
- 'unicode_var': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž',
- 'unicode_var_param': u'{{ param }}'
+ "vars": {
+ "unicode_var": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž",
+ "unicode_var_param": "{{ param }}",
},
- 'chain': [
+ "chain": [
{
- 'name': 'c1',
- 'ref': 'core.local',
- 'parameters': {
- 'cmd': 'echo {{ unicode_var }}'
- }
+ "name": "c1",
+ "ref": "core.local",
+ "parameters": {"cmd": "echo {{ unicode_var }}"},
}
- ]
+ ],
}
- chain_holder = acr.ChainHolder(chainspec=chain_spec, chainname='foo')
- chain_holder.init_vars(action_parameters={'param': u'٩(̾●̮̮̃̾•̃̾)۶'})
+ chain_holder = acr.ChainHolder(chainspec=chain_spec, chainname="foo")
+ chain_holder.init_vars(action_parameters={"param": "٩(̾●̮̮̃̾•̃̾)۶"})
expected = {
- 'unicode_var': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž',
- 'unicode_var_param': u'٩(̾●̮̮̃̾•̃̾)۶'
+ "unicode_var": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž",
+ "unicode_var_param": "٩(̾●̮̮̃̾•̃̾)۶",
}
self.assertEqual(chain_holder.vars, expected)
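The _resolve_params tests above reduce to rendering Jinja templates such as {{action_context.parent.execution_id}} against the chain context dict. A standalone illustration with plain jinja2, bypassing the runner's _resolve_params machinery entirely:

    from jinja2 import Template

    chain_context = {
        "parent": {"execution_id": "some_awesome_exec_id", "user": "dad"},
        "user": "son",
    }
    # The "default" for the exec_id parameter in the test renders like so:
    rendered = Template("{{ action_context.parent.execution_id }}").render(
        action_context=chain_context
    )
    assert rendered == "some_awesome_exec_id"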
diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py
index c093c2061c..46f948d73a 100644
--- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py
+++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py
@@ -20,6 +20,7 @@
import tempfile
from st2tests import config as test_config
+
test_config.parse_args()
from st2common.bootstrap import actionsregistrar
@@ -42,53 +43,45 @@
TEST_FIXTURES = {
- 'chains': [
- 'test_pause_resume.yaml',
- 'test_pause_resume_context_result',
- 'test_pause_resume_with_published_vars.yaml',
- 'test_pause_resume_with_error.yaml',
- 'test_pause_resume_with_subworkflow.yaml',
- 'test_pause_resume_with_context_access.yaml',
- 'test_pause_resume_with_init_vars.yaml',
- 'test_pause_resume_with_no_more_task.yaml',
- 'test_pause_resume_last_task_failed_with_no_next_task.yaml'
+ "chains": [
+ "test_pause_resume.yaml",
+ "test_pause_resume_context_result",
+ "test_pause_resume_with_published_vars.yaml",
+ "test_pause_resume_with_error.yaml",
+ "test_pause_resume_with_subworkflow.yaml",
+ "test_pause_resume_with_context_access.yaml",
+ "test_pause_resume_with_init_vars.yaml",
+ "test_pause_resume_with_no_more_task.yaml",
+ "test_pause_resume_last_task_failed_with_no_next_task.yaml",
+ ],
+ "actions": [
+ "test_pause_resume.yaml",
+ "test_pause_resume_context_result",
+ "test_pause_resume_with_published_vars.yaml",
+ "test_pause_resume_with_error.yaml",
+ "test_pause_resume_with_subworkflow.yaml",
+ "test_pause_resume_with_context_access.yaml",
+ "test_pause_resume_with_init_vars.yaml",
+ "test_pause_resume_with_no_more_task.yaml",
+ "test_pause_resume_last_task_failed_with_no_next_task.yaml",
],
- 'actions': [
- 'test_pause_resume.yaml',
- 'test_pause_resume_context_result',
- 'test_pause_resume_with_published_vars.yaml',
- 'test_pause_resume_with_error.yaml',
- 'test_pause_resume_with_subworkflow.yaml',
- 'test_pause_resume_with_context_access.yaml',
- 'test_pause_resume_with_init_vars.yaml',
- 'test_pause_resume_with_no_more_task.yaml',
- 'test_pause_resume_last_task_failed_with_no_next_task.yaml'
- ]
}
-TEST_PACK = 'action_chain_tests'
-TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "action_chain_tests"
+TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
-PACKS = [
- TEST_PACK_PATH,
- fixturesloader.get_fixtures_packs_base_path() + '/core'
-]
+PACKS = [TEST_PACK_PATH, fixturesloader.get_fixtures_packs_base_path() + "/core"]
-USERNAME = 'stanley'
+USERNAME = "stanley"
-@mock.patch.object(
- CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
-@mock.patch.object(
- CUDPublisher,
- 'publish_create',
- mock.MagicMock(return_value=None))
+@mock.patch.object(CUDPublisher, "publish_update", mock.MagicMock(return_value=None))
+@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None))
@mock.patch.object(
LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state),
+)
class ActionChainRunnerPauseResumeTest(ExecutionDbTestCase):
temp_file_path = None
@@ -102,8 +95,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -114,7 +106,7 @@ def setUp(self):
# Create temporary file used by the tests
_, self.temp_file_path = tempfile.mkstemp()
- os.chmod(self.temp_file_path, 0o755) # nosec
+ os.chmod(self.temp_file_path, 0o755) # nosec
def tearDown(self):
if self.temp_file_path and os.path.exists(self.temp_file_path):
@@ -138,7 +130,7 @@ def _wait_for_children(self, execution, interval=0.1, retries=100):
# Wait until the execution has children.
for i in range(0, retries):
execution = ActionExecution.get_by_id(str(execution.id))
- if len(getattr(execution, 'children', [])) <= 0:
+ if len(getattr(execution, "children", [])) <= 0:
eventlet.sleep(interval)
continue
@@ -151,32 +143,42 @@ def test_chain_pause_resume(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_pause_resume"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
@@ -185,15 +187,19 @@ def test_chain_pause_resume(self):
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 2)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 2)
def test_chain_pause_resume_with_published_vars(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -202,32 +208,42 @@ def test_chain_pause_resume_with_published_vars(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_with_published_vars'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_pause_resume_with_published_vars"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
@@ -236,17 +252,23 @@ def test_chain_pause_resume_with_published_vars(self):
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 2)
- self.assertIn('published', liveaction.result)
- self.assertDictEqual({'var1': 'foobar', 'var2': 'fubar'}, liveaction.result['published'])
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 2)
+ self.assertIn("published", liveaction.result)
+ self.assertDictEqual(
+ {"var1": "foobar", "var2": "fubar"}, liveaction.result["published"]
+ )
def test_chain_pause_resume_with_published_vars_display_false(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -255,32 +277,42 @@ def test_chain_pause_resume_with_published_vars_display_false(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_with_published_vars'
- params = {'tempfile': path, 'message': 'foobar', 'display_published': False}
+ action = TEST_PACK + "." + "test_pause_resume_with_published_vars"
+ params = {"tempfile": path, "message": "foobar", "display_published": False}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
@@ -289,16 +321,20 @@ def test_chain_pause_resume_with_published_vars_display_false(self):
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 2)
- self.assertNotIn('published', liveaction.result)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 2)
+ self.assertNotIn("published", liveaction.result)
def test_chain_pause_resume_with_error(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -307,32 +343,42 @@ def test_chain_pause_resume_with_error(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_with_error'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_pause_resume_with_error"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
@@ -341,19 +387,23 @@ def test_chain_pause_resume_with_error(self):
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 2)
- self.assertTrue(liveaction.result['tasks'][0]['result']['failed'])
- self.assertEqual(1, liveaction.result['tasks'][0]['result']['return_code'])
- self.assertTrue(liveaction.result['tasks'][1]['result']['succeeded'])
- self.assertEqual(0, liveaction.result['tasks'][1]['result']['return_code'])
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 2)
+ self.assertTrue(liveaction.result["tasks"][0]["result"]["failed"])
+ self.assertEqual(1, liveaction.result["tasks"][0]["result"]["return_code"])
+ self.assertTrue(liveaction.result["tasks"][1]["result"]["succeeded"])
+ self.assertEqual(0, liveaction.result["tasks"][1]["result"]["return_code"])
def test_chain_pause_resume_cascade_to_subworkflow(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -362,14 +412,16 @@ def test_chain_pause_resume_cascade_to_subworkflow(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_with_subworkflow'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_pause_resume_with_subworkflow"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Wait for subworkflow to register.
@@ -378,71 +430,97 @@ def test_chain_pause_resume_cascade_to_subworkflow(self):
# Wait until the subworkflow is running.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_for_status(
+ task1_live, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
self.assertEqual(len(execution.children), 1)
# Wait until the subworkflow is pausing.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSING)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_for_status(
+ task1_live, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(task1_live)
- self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
self.assertEqual(len(execution.children), 1)
# Wait until the subworkflow is paused.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSED)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_for_status(
+ task1_live, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(task1_live)
- self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 1)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 1)
- subworkflow = liveaction.result['tasks'][0]
- self.assertEqual(len(subworkflow['result']['tasks']), 1)
- self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED)
+ subworkflow = liveaction.result["tasks"][0]
+ self.assertEqual(len(subworkflow["result"]["tasks"]), 1)
+ self.assertEqual(
+ subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED
+ )
# Request action chain to resume.
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 2)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 2)
- subworkflow = liveaction.result['tasks'][0]
- self.assertEqual(len(subworkflow['result']['tasks']), 2)
- self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ subworkflow = liveaction.result["tasks"][0]
+ self.assertEqual(len(subworkflow["result"]["tasks"]), 2)
+ self.assertEqual(
+ subworkflow["state"], action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
def test_chain_pause_resume_cascade_to_parent_workflow(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -451,14 +529,16 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_with_subworkflow'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_pause_resume_with_subworkflow"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Wait for subworkflow to register.
@@ -467,8 +547,10 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self):
# Wait until the subworkflow is running.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_for_status(
+ task1_live, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request subworkflow to pause.
@@ -476,10 +558,14 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self):
# Wait until the subworkflow is pausing.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSING)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_for_status(
+ task1_live, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(task1_live)
- self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
@@ -487,39 +573,55 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self):
# Wait until the subworkflow is paused.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSED)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_for_status(
+ task1_live, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(task1_live)
- self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait until the parent liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
self.assertEqual(len(execution.children), 1)
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 1)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 1)
- subworkflow = liveaction.result['tasks'][0]
- self.assertEqual(len(subworkflow['result']['tasks']), 1)
- self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED)
+ subworkflow = liveaction.result["tasks"][0]
+ self.assertEqual(len(subworkflow["result"]["tasks"]), 1)
+ self.assertEqual(
+ subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED
+ )
# Request subworkflow to resume.
task1_live, task1_exec = action_service.request_resume(task1_live, USERNAME)
# Wait until the subworkflow is completed.
task1_exec = ActionExecution.get_by_id(execution.children[0])
- task1_live = LiveAction.get_by_id(task1_exec.liveaction['id'])
- task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"])
+ task1_live = self._wait_for_status(
+ task1_live, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ task1_live.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# The parent workflow will stay paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED)
# Wait for non-blocking threads to complete.
@@ -527,30 +629,38 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self):
# Check liveaction result of the parent, which should stay the same
# because only the subworkflow was resumed.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 1)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 1)
- subworkflow = liveaction.result['tasks'][0]
- self.assertEqual(len(subworkflow['result']['tasks']), 1)
- self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED)
+ subworkflow = liveaction.result["tasks"][0]
+ self.assertEqual(len(subworkflow["result"]["tasks"]), 1)
+ self.assertEqual(
+ subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED
+ )
# Request parent workflow to resume.
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 2)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 2)
- subworkflow = liveaction.result['tasks'][0]
- self.assertEqual(len(subworkflow['result']['tasks']), 2)
- self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ subworkflow = liveaction.result["tasks"][0]
+ self.assertEqual(len(subworkflow["result"]["tasks"]), 2)
+ self.assertEqual(
+ subworkflow["state"], action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
def test_chain_pause_resume_with_context_access(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -559,32 +669,42 @@ def test_chain_pause_resume_with_context_access(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_with_context_access'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_pause_resume_with_context_access"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
@@ -593,16 +713,20 @@ def test_chain_pause_resume_with_context_access(self):
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 3)
- self.assertEqual(liveaction.result['tasks'][2]['result']['stdout'], 'foobar')
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 3)
+ self.assertEqual(liveaction.result["tasks"][2]["result"]["stdout"], "foobar")
def test_chain_pause_resume_with_init_vars(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -611,32 +735,42 @@ def test_chain_pause_resume_with_init_vars(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_with_init_vars'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_pause_resume_with_init_vars"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
@@ -645,16 +779,20 @@ def test_chain_pause_resume_with_init_vars(self):
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 2)
- self.assertEqual(liveaction.result['tasks'][1]['result']['stdout'], 'FOOBAR')
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 2)
+ self.assertEqual(liveaction.result["tasks"][1]["result"]["stdout"], "FOOBAR")
def test_chain_pause_resume_with_no_more_task(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -663,32 +801,42 @@ def test_chain_pause_resume_with_no_more_task(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_with_no_more_task'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = TEST_PACK + "." + "test_pause_resume_with_no_more_task"
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
@@ -697,15 +845,19 @@ def test_chain_pause_resume_with_no_more_task(self):
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 1)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 1)
def test_chain_pause_resume_last_task_failed_with_no_next_task(self):
# A temp file is created during test setup. Ensure the temp file exists.
@@ -714,32 +866,44 @@ def test_chain_pause_resume_last_task_failed_with_no_next_task(self):
path = self.temp_file_path
self.assertTrue(os.path.exists(path))
- action = TEST_PACK + '.' + 'test_pause_resume_last_task_failed_with_no_next_task'
- params = {'tempfile': path, 'message': 'foobar'}
+ action = (
+ TEST_PACK + "." + "test_pause_resume_last_task_failed_with_no_next_task"
+ )
+ params = {"tempfile": path, "message": "foobar"}
liveaction = LiveActionDB(action=action, parameters=params)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is running.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request action chain to pause.
liveaction, execution = action_service.request_pause(liveaction, USERNAME)
# Wait until the liveaction is pausing.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSING
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info
+ )
# Delete the temporary file that the action chain is waiting on.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
@@ -748,62 +912,70 @@ def test_chain_pause_resume_last_task_failed_with_no_next_task(self):
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_FAILED
+ )
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_FAILED)
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 1)
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 1)
self.assertEqual(
- liveaction.result['tasks'][0]['state'],
- action_constants.LIVEACTION_STATUS_FAILED
+ liveaction.result["tasks"][0]["state"],
+ action_constants.LIVEACTION_STATUS_FAILED,
)
def test_chain_pause_resume_status_change(self):
# Tests that context_result is updated when the last task's status changes between pause and resume
- action = TEST_PACK + '.' + 'test_pause_resume_context_result'
+ action = TEST_PACK + "." + "test_pause_resume_context_result"
liveaction = LiveActionDB(action=action)
liveaction, execution = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
# Wait until the liveaction is paused.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
extra_info = str(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info
+ )
# Wait for non-blocking threads to complete. Ensure runner is not running.
MockLiveActionPublisherNonBlocking.wait_all()
- last_task_liveaction_id = liveaction.result['tasks'][-1]['liveaction_id']
+ last_task_liveaction_id = liveaction.result["tasks"][-1]["liveaction_id"]
action_utils.update_liveaction_status(
status=action_constants.LIVEACTION_STATUS_SUCCEEDED,
end_timestamp=date_utils.get_datetime_utc_now(),
- result={'foo': 'bar'},
- liveaction_id=last_task_liveaction_id
+ result={"foo": "bar"},
+ liveaction_id=last_task_liveaction_id,
)
# Request action chain to resume.
liveaction, execution = action_service.request_resume(liveaction, USERNAME)
# Wait until the liveaction is completed.
- liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_for_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
self.assertEqual(
liveaction.status,
action_constants.LIVEACTION_STATUS_SUCCEEDED,
- str(liveaction)
+ str(liveaction),
)
# Wait for non-blocking threads to complete.
MockLiveActionPublisherNonBlocking.wait_all()
# Check liveaction result.
- self.assertIn('tasks', liveaction.result)
- self.assertEqual(len(liveaction.result['tasks']), 2)
- self.assertEqual(liveaction.result['tasks'][0]['result']['foo'], 'bar')
+ self.assertIn("tasks", liveaction.result)
+ self.assertEqual(len(liveaction.result["tasks"]), 2)
+ self.assertEqual(liveaction.result["tasks"][0]["result"]["foo"], "bar")
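Aside from quote normalization, the churn above is black splitting every call that overflows its default 88-column limit into the exploded, one-argument-per-line form; the test logic itself is unchanged. All of these tests lean on the polling helper _wait_for_status(), whose implementation lives in the shared test base class and is not part of this diff. A minimal sketch of what such a helper plausibly looks like (hypothetical; the retry budget and sleep interval are assumptions):

    import eventlet

    from st2common.persistence.liveaction import LiveAction

    def wait_for_status(liveaction, status, retries=300, delay=0.1):
        # Re-read the persisted live action until it reaches the expected
        # status or the retry budget runs out (hypothetical stand-in for
        # the test base class helper).
        for _ in range(retries):
            liveaction = LiveAction.get_by_id(str(liveaction.id))
            if liveaction.status == status:
                break
            eventlet.sleep(delay)
        return liveaction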
diff --git a/contrib/runners/announcement_runner/announcement_runner/__init__.py b/contrib/runners/announcement_runner/announcement_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/announcement_runner/announcement_runner/__init__.py
+++ b/contrib/runners/announcement_runner/announcement_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py b/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py
index 6d219f2819..4782544c3c 100644
--- a/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py
+++ b/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py
@@ -24,12 +24,7 @@
from st2common.models.api.trace import TraceContext
from st2common.transport.announcement import AnnouncementDispatcher
-__all__ = [
- 'AnnouncementRunner',
-
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["AnnouncementRunner", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
@@ -42,28 +37,28 @@ def __init__(self, runner_id):
def pre_run(self):
super(AnnouncementRunner, self).pre_run()
- LOG.debug('Entering AnnouncementRunner.pre_run() for liveaction_id="%s"',
- self.liveaction_id)
+ LOG.debug(
+ 'Entering AnnouncementRunner.pre_run() for liveaction_id="%s"',
+ self.liveaction_id,
+ )
- if not self.runner_parameters.get('experimental'):
- message = ('Experimental flag is missing for action %s' % self.action.ref)
- LOG.exception('Experimental runner is called without experimental flag.')
+ if not self.runner_parameters.get("experimental"):
+ message = "Experimental flag is missing for action %s" % self.action.ref
+ LOG.exception("Experimental runner is called without experimental flag.")
raise runnerexceptions.ActionRunnerPreRunError(message)
- self._route = self.runner_parameters.get('route')
+ self._route = self.runner_parameters.get("route")
def run(self, action_parameters):
- trace_context = self.liveaction.context.get('trace_context', None)
+ trace_context = self.liveaction.context.get("trace_context", None)
if trace_context:
trace_context = TraceContext(**trace_context)
- self._dispatcher.dispatch(self._route,
- payload=action_parameters,
- trace_context=trace_context)
+ self._dispatcher.dispatch(
+ self._route, payload=action_parameters, trace_context=trace_context
+ )
- result = {
- "output": action_parameters
- }
+ result = {"output": action_parameters}
result.update(action_parameters)
return (LIVEACTION_STATUS_SUCCEEDED, result, None)
@@ -74,4 +69,4 @@ def get_runner():
def get_metadata():
- return get_runner_metadata('announcement_runner')[0]
+ return get_runner_metadata("announcement_runner")[0]
diff --git a/contrib/runners/announcement_runner/dist_utils.py b/contrib/runners/announcement_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/announcement_runner/dist_utils.py
+++ b/contrib/runners/announcement_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
raise ValueError('Line "%s" is missing "#egg="' % (line))
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
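For context on the hunk above: fetch_requirements() splits a pip requirements file into plain requirement names and VCS dependency links keyed by their #egg= fragment, and drops environment markers (everything after ";"). A hypothetical input/output illustration (the requirement lines are made up for the example):

    lines = [
        "six>=1.13.0",
        "mock; python_version < '3.0'",
        "git+https://github.com/example/somepkg.git#egg=somepkg",
    ]
    # fetch_requirements() would return roughly:
    #   reqs      -> ["six>=1.13.0", "mock", "somepkg"]
    #   dep_links -> ["git+https://github.com/example/somepkg.git#egg=somepkg"]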
diff --git a/contrib/runners/announcement_runner/setup.py b/contrib/runners/announcement_runner/setup.py
index efd60b14af..a72469ffea 100644
--- a/contrib/runners/announcement_runner/setup.py
+++ b/contrib/runners/announcement_runner/setup.py
@@ -26,30 +26,32 @@
from announcement_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-announcement',
+ name="stackstorm-runner-announcement",
version=__version__,
- description=('Announcement action runner for StackStorm event-driven automation platform'),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description=(
+ "Announcement action runner for StackStorm event-driven automation platform"
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'announcement_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"announcement_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'announcement = announcement_runner.announcement_runner',
+ "st2common.runners.runner": [
+ "announcement = announcement_runner.announcement_runner",
],
- }
+ },
)
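The entry_points stanza above is what makes the runner discoverable: each runner package registers its module under the st2common.runners.runner setuptools group, and the modules expose get_runner()/get_metadata() as seen earlier in this diff. A sketch of how such an entry point can be resolved (hypothetical consumer code, not part of this change):

    import pkg_resources

    # Iterate over every installed runner package and load its module.
    for ep in pkg_resources.iter_entry_points(group="st2common.runners.runner"):
        module = ep.load()  # e.g. announcement_runner.announcement_runner
        runner = module.get_runner()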
diff --git a/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py b/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py
index cc9c541015..9ad56a2115 100644
--- a/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py
+++ b/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py
@@ -26,69 +26,63 @@
mock_dispatcher = mock.Mock()
-@mock.patch('st2common.transport.announcement.AnnouncementDispatcher.dispatch')
+@mock.patch("st2common.transport.announcement.AnnouncementDispatcher.dispatch")
class AnnouncementRunnerTestCase(RunnerTestCase):
-
@classmethod
def setUpClass(cls):
tests_config.parse_args()
def test_runner_creation(self, dispatch):
runner = announcement_runner.get_runner()
- self.assertIsNotNone(runner, 'Creation failed. No instance.')
- self.assertEqual(type(runner), announcement_runner.AnnouncementRunner,
- 'Creation failed. No instance.')
+ self.assertIsNotNone(runner, "Creation failed. No instance.")
+ self.assertEqual(
+ type(runner),
+ announcement_runner.AnnouncementRunner,
+ "Creation failed. No instance.",
+ )
self.assertEqual(runner._dispatcher.dispatch, dispatch)
def test_announcement(self, dispatch):
runner = announcement_runner.get_runner()
- runner.runner_parameters = {
- 'experimental': True,
- 'route': 'general'
- }
+ runner.runner_parameters = {"experimental": True, "route": "general"}
runner.liveaction = mock.Mock(context={})
runner.pre_run()
- (status, result, _) = runner.run({'test': 'passed'})
+ (status, result, _) = runner.run({"test": "passed"})
self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED)
self.assertIsNotNone(result)
- self.assertEqual(result['test'], 'passed')
- dispatch.assert_called_once_with('general', payload={'test': 'passed'},
- trace_context=None)
+ self.assertEqual(result["test"], "passed")
+ dispatch.assert_called_once_with(
+ "general", payload={"test": "passed"}, trace_context=None
+ )
def test_announcement_no_experimental(self, dispatch):
runner = announcement_runner.get_runner()
- runner.action = mock.Mock(ref='some.thing')
- runner.runner_parameters = {
- 'route': 'general'
- }
+ runner.action = mock.Mock(ref="some.thing")
+ runner.runner_parameters = {"route": "general"}
runner.liveaction = mock.Mock(context={})
- expected_msg = 'Experimental flag is missing for action some.thing'
+ expected_msg = "Experimental flag is missing for action some.thing"
self.assertRaisesRegexp(Exception, expected_msg, runner.pre_run)
- @mock.patch('st2common.models.api.trace.TraceContext.__new__')
+ @mock.patch("st2common.models.api.trace.TraceContext.__new__")
def test_announcement_with_trace(self, context, dispatch):
runner = announcement_runner.get_runner()
- runner.runner_parameters = {
- 'experimental': True,
- 'route': 'general'
- }
- runner.liveaction = mock.Mock(context={
- 'trace_context': {
- 'id_': 'a',
- 'trace_tag': 'b'
- }
- })
+ runner.runner_parameters = {"experimental": True, "route": "general"}
+ runner.liveaction = mock.Mock(
+ context={"trace_context": {"id_": "a", "trace_tag": "b"}}
+ )
runner.pre_run()
- (status, result, _) = runner.run({'test': 'passed'})
+ (status, result, _) = runner.run({"test": "passed"})
self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED)
self.assertIsNotNone(result)
- self.assertEqual(result['test'], 'passed')
- context.assert_called_once_with(TraceContext,
- **runner.liveaction.context['trace_context'])
- dispatch.assert_called_once_with('general', payload={'test': 'passed'},
- trace_context=context.return_value)
+ self.assertEqual(result["test"], "passed")
+ context.assert_called_once_with(
+ TraceContext, **runner.liveaction.context["trace_context"]
+ )
+ dispatch.assert_called_once_with(
+ "general", payload={"test": "passed"}, trace_context=context.return_value
+ )
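A subtlety worth keeping in mind when reading the reformatted signatures above: mock.patch decorators inject their mocks bottom-up, and a class-level patch wraps each test method outermost, so its mock arrives last. That is why test_announcement_with_trace receives (self, context, dispatch): context comes from the method-level TraceContext.__new__ patch and dispatch from the class-level patch. Schematically (pkg.Inner and pkg.Outer are placeholder targets):

    import mock

    @mock.patch("pkg.Outer")           # class-level patch: injected last
    class ExampleTestCase(object):
        @mock.patch("pkg.Inner")       # method-level patch: injected first
        def test_something(self, inner, outer):
            pass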
diff --git a/contrib/runners/http_runner/dist_utils.py b/contrib/runners/http_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/http_runner/dist_utils.py
+++ b/contrib/runners/http_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
raise ValueError('Line "%s" is missing "#egg="' % (line))
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
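Note that dist_utils.py is vendored into every runner package, so the hunk above is byte-for-byte identical to the announcement_runner copy earlier in this diff (both show the same pre/post blob hashes, a6f62c8cc2..2f2043cf29); the same black reformatting therefore repeats once per runner.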
diff --git a/contrib/runners/http_runner/http_runner/__init__.py b/contrib/runners/http_runner/http_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/http_runner/http_runner/__init__.py
+++ b/contrib/runners/http_runner/http_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/http_runner/http_runner/http_runner.py b/contrib/runners/http_runner/http_runner/http_runner.py
index b2ff115fc6..6a02c809b9 100644
--- a/contrib/runners/http_runner/http_runner/http_runner.py
+++ b/contrib/runners/http_runner/http_runner/http_runner.py
@@ -35,45 +35,36 @@
import six
from six.moves import range
-__all__ = [
- 'HttpRunner',
-
- 'HTTPClient',
-
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["HttpRunner", "HTTPClient", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
SUCCESS_STATUS_CODES = [code for code in range(200, 207)]
# Lookup constants for runner params
-RUNNER_ON_BEHALF_USER = 'user'
-RUNNER_URL = 'url'
-RUNNER_HEADERS = 'headers' # Debatable whether this should be action params.
-RUNNER_COOKIES = 'cookies'
-RUNNER_ALLOW_REDIRECTS = 'allow_redirects'
-RUNNER_HTTP_PROXY = 'http_proxy'
-RUNNER_HTTPS_PROXY = 'https_proxy'
-RUNNER_VERIFY_SSL_CERT = 'verify_ssl_cert'
-RUNNER_USERNAME = 'username'
-RUNNER_PASSWORD = 'password'
-RUNNER_URL_HOSTS_BLACKLIST = 'url_hosts_blacklist'
-RUNNER_URL_HOSTS_WHITELIST = 'url_hosts_whitelist'
+RUNNER_ON_BEHALF_USER = "user"
+RUNNER_URL = "url"
+RUNNER_HEADERS = "headers" # Debatable whether this should be action params.
+RUNNER_COOKIES = "cookies"
+RUNNER_ALLOW_REDIRECTS = "allow_redirects"
+RUNNER_HTTP_PROXY = "http_proxy"
+RUNNER_HTTPS_PROXY = "https_proxy"
+RUNNER_VERIFY_SSL_CERT = "verify_ssl_cert"
+RUNNER_USERNAME = "username"
+RUNNER_PASSWORD = "password"
+RUNNER_URL_HOSTS_BLACKLIST = "url_hosts_blacklist"
+RUNNER_URL_HOSTS_WHITELIST = "url_hosts_whitelist"
# Lookup constants for action params
-ACTION_AUTH = 'auth'
-ACTION_BODY = 'body'
-ACTION_TIMEOUT = 'timeout'
-ACTION_METHOD = 'method'
-ACTION_QUERY_PARAMS = 'params'
-FILE_NAME = 'file_name'
-FILE_CONTENT = 'file_content'
-FILE_CONTENT_TYPE = 'file_content_type'
+ACTION_AUTH = "auth"
+ACTION_BODY = "body"
+ACTION_TIMEOUT = "timeout"
+ACTION_METHOD = "method"
+ACTION_QUERY_PARAMS = "params"
+FILE_NAME = "file_name"
+FILE_CONTENT = "file_content"
+FILE_CONTENT_TYPE = "file_content_type"
-RESPONSE_BODY_PARSE_FUNCTIONS = {
- 'application/json': json.loads
-}
+RESPONSE_BODY_PARSE_FUNCTIONS = {"application/json": json.loads}
class HttpRunner(ActionRunner):
@@ -85,37 +76,48 @@ def __init__(self, runner_id):
def pre_run(self):
super(HttpRunner, self).pre_run()
- LOG.debug('Entering HttpRunner.pre_run() for liveaction_id="%s"', self.liveaction_id)
- self._on_behalf_user = self.runner_parameters.get(RUNNER_ON_BEHALF_USER,
- self._on_behalf_user)
+ LOG.debug(
+ 'Entering HttpRunner.pre_run() for liveaction_id="%s"', self.liveaction_id
+ )
+ self._on_behalf_user = self.runner_parameters.get(
+ RUNNER_ON_BEHALF_USER, self._on_behalf_user
+ )
self._url = self.runner_parameters.get(RUNNER_URL, None)
self._headers = self.runner_parameters.get(RUNNER_HEADERS, {})
self._cookies = self.runner_parameters.get(RUNNER_COOKIES, None)
- self._allow_redirects = self.runner_parameters.get(RUNNER_ALLOW_REDIRECTS, False)
+ self._allow_redirects = self.runner_parameters.get(
+ RUNNER_ALLOW_REDIRECTS, False
+ )
self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
self._http_proxy = self.runner_parameters.get(RUNNER_HTTP_PROXY, None)
self._https_proxy = self.runner_parameters.get(RUNNER_HTTPS_PROXY, None)
self._verify_ssl_cert = self.runner_parameters.get(RUNNER_VERIFY_SSL_CERT, None)
- self._url_hosts_blacklist = self.runner_parameters.get(RUNNER_URL_HOSTS_BLACKLIST, [])
- self._url_hosts_whitelist = self.runner_parameters.get(RUNNER_URL_HOSTS_WHITELIST, [])
+ self._url_hosts_blacklist = self.runner_parameters.get(
+ RUNNER_URL_HOSTS_BLACKLIST, []
+ )
+ self._url_hosts_whitelist = self.runner_parameters.get(
+ RUNNER_URL_HOSTS_WHITELIST, []
+ )
def run(self, action_parameters):
client = self._get_http_client(action_parameters)
if self._url_hosts_blacklist and self._url_hosts_whitelist:
- msg = ('"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually '
- 'exclusive. Only one should be provided.')
+ msg = (
+ '"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually '
+ "exclusive. Only one should be provided."
+ )
raise ValueError(msg)
try:
result = client.run()
except requests.exceptions.Timeout as e:
- result = {'error': six.text_type(e)}
+ result = {"error": six.text_type(e)}
status = LIVEACTION_STATUS_TIMED_OUT
else:
- status = HttpRunner._get_result_status(result.get('status_code', None))
+ status = HttpRunner._get_result_status(result.get("status_code", None))
return (status, result, None)
@@ -132,8 +134,8 @@ def _get_http_client(self, action_parameters):
# Include our user agent and action name so requests can be tracked back
headers = copy.deepcopy(self._headers) if self._headers else {}
- headers['User-Agent'] = 'st2/v%s' % (st2_version)
- headers['X-Stanley-Action'] = self.action_name
+ headers["User-Agent"] = "st2/v%s" % (st2_version)
+ headers["X-Stanley-Action"] = self.action_name
if file_name and file_content:
files = {}
@@ -141,7 +143,7 @@ def _get_http_client(self, action_parameters):
if file_content_type:
value = (file_content, file_content_type)
else:
- value = (file_content)
+ value = file_content
files[file_name] = value
else:
@@ -150,43 +152,72 @@ def _get_http_client(self, action_parameters):
proxies = {}
if self._http_proxy:
- proxies['http'] = self._http_proxy
+ proxies["http"] = self._http_proxy
if self._https_proxy:
- proxies['https'] = self._https_proxy
-
- return HTTPClient(url=self._url, method=method, body=body, params=params,
- headers=headers, cookies=self._cookies, auth=auth,
- timeout=timeout, allow_redirects=self._allow_redirects,
- proxies=proxies, files=files, verify=self._verify_ssl_cert,
- username=self._username, password=self._password,
- url_hosts_blacklist=self._url_hosts_blacklist,
- url_hosts_whitelist=self._url_hosts_whitelist)
+ proxies["https"] = self._https_proxy
+
+ return HTTPClient(
+ url=self._url,
+ method=method,
+ body=body,
+ params=params,
+ headers=headers,
+ cookies=self._cookies,
+ auth=auth,
+ timeout=timeout,
+ allow_redirects=self._allow_redirects,
+ proxies=proxies,
+ files=files,
+ verify=self._verify_ssl_cert,
+ username=self._username,
+ password=self._password,
+ url_hosts_blacklist=self._url_hosts_blacklist,
+ url_hosts_whitelist=self._url_hosts_whitelist,
+ )
@staticmethod
def _get_result_status(status_code):
- return LIVEACTION_STATUS_SUCCEEDED if status_code in SUCCESS_STATUS_CODES \
+ return (
+ LIVEACTION_STATUS_SUCCEEDED
+ if status_code in SUCCESS_STATUS_CODES
else LIVEACTION_STATUS_FAILED
+ )
class HTTPClient(object):
- def __init__(self, url=None, method=None, body='', params=None, headers=None, cookies=None,
- auth=None, timeout=60, allow_redirects=False, proxies=None,
- files=None, verify=False, username=None, password=None,
- url_hosts_blacklist=None, url_hosts_whitelist=None):
+ def __init__(
+ self,
+ url=None,
+ method=None,
+ body="",
+ params=None,
+ headers=None,
+ cookies=None,
+ auth=None,
+ timeout=60,
+ allow_redirects=False,
+ proxies=None,
+ files=None,
+ verify=False,
+ username=None,
+ password=None,
+ url_hosts_blacklist=None,
+ url_hosts_whitelist=None,
+ ):
if url is None:
- raise Exception('URL must be specified.')
+ raise Exception("URL must be specified.")
if method is None:
if files or body:
- method = 'POST'
+ method = "POST"
else:
- method = 'GET'
+ method = "GET"
headers = headers or {}
normalized_headers = self._normalize_headers(headers=headers)
- if body and 'content-length' not in normalized_headers:
- headers['Content-Length'] = str(len(body))
+ if body and "content-length" not in normalized_headers:
+ headers["Content-Length"] = str(len(body))
self.url = url
self.method = method
@@ -207,8 +238,10 @@ def __init__(self, url=None, method=None, body='', params=None, headers=None, co
self.url_hosts_whitelist = url_hosts_whitelist or []
if self.url_hosts_blacklist and self.url_hosts_whitelist:
- msg = ('"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually '
- 'exclusive. Only one should be provided.')
+ msg = (
+ '"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually '
+ "exclusive. Only one should be provided."
+ )
raise ValueError(msg)
def run(self):
@@ -235,7 +268,7 @@ def run(self):
try:
data = json.dumps(data)
except ValueError:
- msg = 'Request body (%s) can\'t be parsed as JSON' % (data)
+ msg = "Request body (%s) can't be parsed as JSON" % (data)
raise ValueError(msg)
else:
data = self.body
@@ -245,7 +278,7 @@ def run(self):
# Ensure data is bytes since that is what requests expects
if isinstance(data, six.text_type):
- data = data.encode('utf-8')
+ data = data.encode("utf-8")
resp = requests.request(
self.method,
@@ -259,19 +292,19 @@ def run(self):
allow_redirects=self.allow_redirects,
proxies=self.proxies,
files=self.files,
- verify=self.verify
+ verify=self.verify,
)
headers = dict(resp.headers)
body, parsed = self._parse_response_body(headers=headers, body=resp.text)
- results['status_code'] = resp.status_code
- results['body'] = body
- results['parsed'] = parsed # flag which indicates if body has been parsed
- results['headers'] = headers
+ results["status_code"] = resp.status_code
+ results["body"] = body
+ results["parsed"] = parsed # flag which indicates if body has been parsed
+ results["headers"] = headers
return results
except Exception as e:
- LOG.exception('Exception making request to remote URL: %s, %s', self.url, e)
+ LOG.exception("Exception making request to remote URL: %s, %s", self.url, e)
raise
finally:
if resp:
@@ -285,27 +318,27 @@ def _parse_response_body(self, headers, body):
:return: (parsed body, flag which indicates if body has been parsed)
:rtype: (``object``, ``bool``)
"""
- body = body or ''
+ body = body or ""
headers = self._normalize_headers(headers=headers)
- content_type = headers.get('content-type', None)
+ content_type = headers.get("content-type", None)
parsed = False
if not content_type:
return (body, parsed)
# The header can also contain charset which we simply discard
- content_type = content_type.split(';')[0]
+ content_type = content_type.split(";")[0]
parse_func = RESPONSE_BODY_PARSE_FUNCTIONS.get(content_type, None)
if not parse_func:
return (body, parsed)
- LOG.debug('Parsing body with content type: %s', content_type)
+ LOG.debug("Parsing body with content type: %s", content_type)
try:
body = parse_func(body)
except Exception:
- LOG.exception('Failed to parse body')
+ LOG.exception("Failed to parse body")
else:
parsed = True
@@ -323,7 +356,7 @@ def _normalize_headers(self, headers):
def _is_json_content(self):
normalized = self._normalize_headers(self.headers)
- return normalized.get('content-type', None) == 'application/json'
+ return normalized.get("content-type", None) == "application/json"
def _cast_object(self, value):
if isinstance(value, str) or isinstance(value, six.text_type):
@@ -370,10 +403,10 @@ def _get_host_from_url(self, url):
parsed = urlparse.urlparse(url)
# Remove port and []
- host = parsed.netloc.replace('[', '').replace(']', '')
+ host = parsed.netloc.replace("[", "").replace("]", "")
if parsed.port is not None:
- host = host.replace(':%s' % (parsed.port), '')
+ host = host.replace(":%s" % (parsed.port), "")
return host
@@ -383,4 +416,4 @@ def get_runner():
def get_metadata():
- return get_runner_metadata('http_runner')[0]
+ return get_runner_metadata("http_runner")[0]
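# Editor's note, not part of the commit: a hedged usage sketch of the
# black-reformatted HTTPClient signature above. The URL and header values are
# placeholders.
from http_runner.http_runner import HTTPClient

client = HTTPClient(
    url="https://example.com/api",  # placeholder URL
    method="GET",  # defaults to GET when no body or files are supplied
    headers={"Accept": "application/json"},
    timeout=60,
    verify=True,
    url_hosts_blacklist=["127.0.0.1"],  # mutually exclusive with the whitelist
)
# result = client.run()  # result carries status_code, body, parsed, headers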
diff --git a/contrib/runners/http_runner/setup.py b/contrib/runners/http_runner/setup.py
index 2b962da599..2a5c9e217b 100644
--- a/contrib/runners/http_runner/setup.py
+++ b/contrib/runners/http_runner/setup.py
@@ -26,30 +26,32 @@
from http_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-http',
+ name="stackstorm-runner-http",
version=__version__,
- description=('HTTP(s) action runner for StackStorm event-driven automation platform'),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description=(
+ "HTTP(s) action runner for StackStorm event-driven automation platform"
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'http_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"http_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'http-request = http_runner.http_runner',
+ "st2common.runners.runner": [
+ "http-request = http_runner.http_runner",
],
- }
+ },
)
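# Editor's note, not part of the commit: runners registered under the
# "st2common.runners.runner" entry point group in the setup() call above can
# be discovered at runtime roughly like this, assuming a pkg_resources-based
# environment.
import pkg_resources

for entry_point in pkg_resources.iter_entry_points(group="st2common.runners.runner"):
    # e.g. prints: http-request http_runner.http_runner
    print(entry_point.name, entry_point.module_name)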
diff --git a/contrib/runners/http_runner/tests/unit/test_http_runner.py b/contrib/runners/http_runner/tests/unit/test_http_runner.py
index be64f6d420..9d2d99a7c1 100644
--- a/contrib/runners/http_runner/tests/unit/test_http_runner.py
+++ b/contrib/runners/http_runner/tests/unit/test_http_runner.py
@@ -28,16 +28,13 @@
import st2tests.config as tests_config
-__all__ = [
- 'HTTPClientTestCase',
- 'HTTPRunnerTestCase'
-]
+__all__ = ["HTTPClientTestCase", "HTTPRunnerTestCase"]
if six.PY2:
- EXPECTED_DATA = ''
+ EXPECTED_DATA = ""
else:
- EXPECTED_DATA = b''
+ EXPECTED_DATA = b""
class MockResult(object):
@@ -49,70 +46,70 @@ class HTTPClientTestCase(unittest2.TestCase):
def setUpClass(cls):
tests_config.parse_args()
- @mock.patch('http_runner.http_runner.requests')
+ @mock.patch("http_runner.http_runner.requests")
def test_parse_response_body(self, mock_requests):
- client = HTTPClient(url='http://127.0.0.1')
+ client = HTTPClient(url="http://127.0.0.1")
mock_result = MockResult()
# Unknown content type, body should be returned raw
- mock_result.text = 'foo bar ponies'
- mock_result.headers = {'Content-Type': 'text/html'}
+ mock_result.text = "foo bar ponies"
+ mock_result.headers = {"Content-Type": "text/html"}
mock_result.status_code = 200
mock_requests.request.return_value = mock_result
result = client.run()
- self.assertEqual(result['body'], mock_result.text)
- self.assertEqual(result['status_code'], mock_result.status_code)
- self.assertEqual(result['headers'], mock_result.headers)
+ self.assertEqual(result["body"], mock_result.text)
+ self.assertEqual(result["status_code"], mock_result.status_code)
+ self.assertEqual(result["headers"], mock_result.headers)
# Unknown content type, JSON body
mock_result.text = '{"test1": "val1"}'
- mock_result.headers = {'Content-Type': 'text/html'}
+ mock_result.headers = {"Content-Type": "text/html"}
mock_requests.request.return_value = mock_result
result = client.run()
- self.assertEqual(result['body'], mock_result.text)
+ self.assertEqual(result["body"], mock_result.text)
# JSON content-type and JSON body
mock_result.text = '{"test1": "val1"}'
- mock_result.headers = {'Content-Type': 'application/json'}
+ mock_result.headers = {"Content-Type": "application/json"}
mock_requests.request.return_value = mock_result
result = client.run()
- self.assertIsInstance(result['body'], dict)
- self.assertEqual(result['body'], {'test1': 'val1'})
+ self.assertIsInstance(result["body"], dict)
+ self.assertEqual(result["body"], {"test1": "val1"})
# JSON content-type with charset and JSON body
mock_result.text = '{"test1": "val1"}'
- mock_result.headers = {'Content-Type': 'application/json; charset=UTF-8'}
+ mock_result.headers = {"Content-Type": "application/json; charset=UTF-8"}
mock_requests.request.return_value = mock_result
result = client.run()
- self.assertIsInstance(result['body'], dict)
- self.assertEqual(result['body'], {'test1': 'val1'})
+ self.assertIsInstance(result["body"], dict)
+ self.assertEqual(result["body"], {"test1": "val1"})
# JSON content-type and invalid json body
- mock_result.text = 'not json'
- mock_result.headers = {'Content-Type': 'application/json'}
+ mock_result.text = "not json"
+ mock_result.headers = {"Content-Type": "application/json"}
mock_requests.request.return_value = mock_result
result = client.run()
- self.assertNotIsInstance(result['body'], dict)
- self.assertEqual(result['body'], mock_result.text)
+ self.assertNotIsInstance(result["body"], dict)
+ self.assertEqual(result["body"], mock_result.text)
- @mock.patch('http_runner.http_runner.requests')
+ @mock.patch("http_runner.http_runner.requests")
def test_https_verify(self, mock_requests):
- url = 'https://127.0.0.1:8888'
+ url = "https://127.0.0.1:8888"
client = HTTPClient(url=url, verify=True)
mock_result = MockResult()
- mock_result.text = 'foo bar ponies'
- mock_result.headers = {'Content-Type': 'text/html'}
+ mock_result.text = "foo bar ponies"
+ mock_result.headers = {"Content-Type": "text/html"}
mock_result.status_code = 200
mock_requests.request.return_value = mock_result
@@ -121,23 +118,33 @@ def test_https_verify(self, mock_requests):
self.assertTrue(client.verify)
if six.PY2:
- data = ''
+ data = ""
else:
- data = b''
+ data = b""
mock_requests.request.assert_called_with(
- 'GET', url, allow_redirects=False, auth=None, cookies=None,
- data=data, files=None, headers={}, params=None, proxies=None,
- timeout=60, verify=True)
-
- @mock.patch('http_runner.http_runner.requests')
+ "GET",
+ url,
+ allow_redirects=False,
+ auth=None,
+ cookies=None,
+ data=data,
+ files=None,
+ headers={},
+ params=None,
+ proxies=None,
+ timeout=60,
+ verify=True,
+ )
+
+ @mock.patch("http_runner.http_runner.requests")
def test_https_verify_false(self, mock_requests):
- url = 'https://127.0.0.1:8888'
+ url = "https://127.0.0.1:8888"
client = HTTPClient(url=url)
mock_result = MockResult()
- mock_result.text = 'foo bar ponies'
- mock_result.headers = {'Content-Type': 'text/html'}
+ mock_result.text = "foo bar ponies"
+ mock_result.headers = {"Content-Type": "text/html"}
mock_result.status_code = 200
mock_requests.request.return_value = mock_result
@@ -146,182 +153,202 @@ def test_https_verify_false(self, mock_requests):
self.assertFalse(client.verify)
mock_requests.request.assert_called_with(
- 'GET', url, allow_redirects=False, auth=None, cookies=None,
- data=EXPECTED_DATA, files=None, headers={}, params=None, proxies=None,
- timeout=60, verify=False)
-
- @mock.patch('http_runner.http_runner.requests')
+ "GET",
+ url,
+ allow_redirects=False,
+ auth=None,
+ cookies=None,
+ data=EXPECTED_DATA,
+ files=None,
+ headers={},
+ params=None,
+ proxies=None,
+ timeout=60,
+ verify=False,
+ )
+
+ @mock.patch("http_runner.http_runner.requests")
def test_https_auth_basic(self, mock_requests):
- url = 'https://127.0.0.1:8888'
- username = 'misspiggy'
- password = 'kermit'
+ url = "https://127.0.0.1:8888"
+ username = "misspiggy"
+ password = "kermit"
client = HTTPClient(url=url, username=username, password=password)
mock_result = MockResult()
- mock_result.text = 'muppet show'
- mock_result.headers = {'Authorization': 'bWlzc3BpZ2d5Omtlcm1pdA=='}
+ mock_result.text = "muppet show"
+ mock_result.headers = {"Authorization": "bWlzc3BpZ2d5Omtlcm1pdA=="}
mock_result.status_code = 200
mock_requests.request.return_value = mock_result
result = client.run()
- self.assertEqual(result['headers'], mock_result.headers)
+ self.assertEqual(result["headers"], mock_result.headers)
mock_requests.request.assert_called_once_with(
- 'GET', url, allow_redirects=False, auth=client.auth, cookies=None,
- data=EXPECTED_DATA, files=None, headers={}, params=None, proxies=None,
- timeout=60, verify=False)
-
- @mock.patch('http_runner.http_runner.requests')
+ "GET",
+ url,
+ allow_redirects=False,
+ auth=client.auth,
+ cookies=None,
+ data=EXPECTED_DATA,
+ files=None,
+ headers={},
+ params=None,
+ proxies=None,
+ timeout=60,
+ verify=False,
+ )
+
+ @mock.patch("http_runner.http_runner.requests")
def test_http_unicode_body_data(self, mock_requests):
- url = 'http://127.0.0.1:8888'
- method = 'POST'
+ url = "http://127.0.0.1:8888"
+ method = "POST"
mock_result = MockResult()
# 1. String data
headers = {}
- body = 'žžžžž'
- client = HTTPClient(url=url, method=method, headers=headers, body=body, timeout=0.1)
+ body = "žžžžž"
+ client = HTTPClient(
+ url=url, method=method, headers=headers, body=body, timeout=0.1
+ )
mock_result.text = '{"foo": "bar"}'
- mock_result.headers = {'Content-Type': 'application/json'}
+ mock_result.headers = {"Content-Type": "application/json"}
mock_result.status_code = 200
mock_requests.request.return_value = mock_result
result = client.run()
- self.assertEqual(result['status_code'], 200)
+ self.assertEqual(result["status_code"], 200)
call_kwargs = mock_requests.request.call_args_list[0][1]
- expected_data = u'žžžžž'.encode('utf-8')
- self.assertEqual(call_kwargs['data'], expected_data)
+ expected_data = "žžžžž".encode("utf-8")
+ self.assertEqual(call_kwargs["data"], expected_data)
# 2. Object / JSON data
- body = {
- 'foo': u'ažž'
- }
- headers = {
- 'Content-Type': 'application/json; charset=utf-8'
- }
- client = HTTPClient(url=url, method=method, headers=headers, body=body, timeout=0.1)
+ body = {"foo": "ažž"}
+ headers = {"Content-Type": "application/json; charset=utf-8"}
+ client = HTTPClient(
+ url=url, method=method, headers=headers, body=body, timeout=0.1
+ )
mock_result.text = '{"foo": "bar"}'
- mock_result.headers = {'Content-Type': 'application/json'}
+ mock_result.headers = {"Content-Type": "application/json"}
mock_result.status_code = 200
mock_requests.request.return_value = mock_result
result = client.run()
- self.assertEqual(result['status_code'], 200)
+ self.assertEqual(result["status_code"], 200)
call_kwargs = mock_requests.request.call_args_list[1][1]
if six.PY2:
- expected_data = {
- 'foo': u'a\u017e\u017e'
- }
+ expected_data = {"foo": "a\u017e\u017e"}
else:
expected_data = body
- self.assertEqual(call_kwargs['data'], expected_data)
+ self.assertEqual(call_kwargs["data"], expected_data)
- @mock.patch('http_runner.http_runner.requests')
+ @mock.patch("http_runner.http_runner.requests")
def test_blacklisted_url_url_hosts_blacklist_runner_parameter(self, mock_requests):
# Blacklist is empty
self.assertEqual(mock_requests.request.call_count, 0)
- url = 'http://www.example.com'
- client = HTTPClient(url=url, method='GET')
+ url = "http://www.example.com"
+ client = HTTPClient(url=url, method="GET")
client.run()
self.assertEqual(mock_requests.request.call_count, 1)
# Blacklist is set
url_hosts_blacklist = [
- 'example.com',
- '127.0.0.1',
- '::1',
- '2001:0db8:85a3:0000:0000:8a2e:0370:7334'
+ "example.com",
+ "127.0.0.1",
+ "::1",
+ "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
]
# Blacklisted URLs
urls = [
- 'https://example.com',
- 'http://example.com',
- 'http://example.com:81',
- 'http://example.com:80',
- 'http://example.com:9000',
- 'http://[::1]:80/',
- 'http://[::1]',
- 'http://[::1]:9000',
- 'http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]',
- 'https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000'
+ "https://example.com",
+ "http://example.com",
+ "http://example.com:81",
+ "http://example.com:80",
+ "http://example.com:9000",
+ "http://[::1]:80/",
+ "http://[::1]",
+ "http://[::1]:9000",
+ "http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
+ "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000",
]
for url in urls:
expected_msg = r'URL "%s" is blacklisted' % (re.escape(url))
- client = HTTPClient(url=url, method='GET', url_hosts_blacklist=url_hosts_blacklist)
+ client = HTTPClient(
+ url=url, method="GET", url_hosts_blacklist=url_hosts_blacklist
+ )
self.assertRaisesRegexp(ValueError, expected_msg, client.run)
# Non blacklisted URLs
- urls = [
- 'https://example2.com',
- 'http://example3.com',
- 'http://example4.com:81'
- ]
+ urls = ["https://example2.com", "http://example3.com", "http://example4.com:81"]
for url in urls:
mock_requests.request.reset_mock()
self.assertEqual(mock_requests.request.call_count, 0)
- client = HTTPClient(url=url, method='GET', url_hosts_blacklist=url_hosts_blacklist)
+ client = HTTPClient(
+ url=url, method="GET", url_hosts_blacklist=url_hosts_blacklist
+ )
client.run()
self.assertEqual(mock_requests.request.call_count, 1)
- @mock.patch('http_runner.http_runner.requests')
+ @mock.patch("http_runner.http_runner.requests")
def test_whitelisted_url_url_hosts_whitelist_runner_parameter(self, mock_requests):
# Whitelist is empty
self.assertEqual(mock_requests.request.call_count, 0)
- url = 'http://www.example.com'
- client = HTTPClient(url=url, method='GET')
+ url = "http://www.example.com"
+ client = HTTPClient(url=url, method="GET")
client.run()
self.assertEqual(mock_requests.request.call_count, 1)
# Whitelist is set
url_hosts_whitelist = [
- 'example.com',
- '127.0.0.1',
- '::1',
- '2001:0db8:85a3:0000:0000:8a2e:0370:7334'
+ "example.com",
+ "127.0.0.1",
+ "::1",
+ "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
]
# Non whitelisted URLs
urls = [
- 'https://www.google.com',
- 'https://www.example2.com',
- 'http://127.0.0.2'
+ "https://www.google.com",
+ "https://www.example2.com",
+ "http://127.0.0.2",
]
for url in urls:
expected_msg = r'URL "%s" is not whitelisted' % (re.escape(url))
- client = HTTPClient(url=url, method='GET', url_hosts_whitelist=url_hosts_whitelist)
+ client = HTTPClient(
+ url=url, method="GET", url_hosts_whitelist=url_hosts_whitelist
+ )
self.assertRaisesRegexp(ValueError, expected_msg, client.run)
# Whitelisted URLs
urls = [
- 'https://example.com',
- 'http://example.com',
- 'http://example.com:81',
- 'http://example.com:80',
- 'http://example.com:9000',
- 'http://[::1]:80/',
- 'http://[::1]',
- 'http://[::1]:9000',
- 'http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]',
- 'https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000'
+ "https://example.com",
+ "http://example.com",
+ "http://example.com:81",
+ "http://example.com:80",
+ "http://example.com:9000",
+ "http://[::1]:80/",
+ "http://[::1]",
+ "http://[::1]:9000",
+ "http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
+ "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000",
]
for url in urls:
@@ -329,57 +356,71 @@ def test_whitelisted_url_url_hosts_whitelist_runner_parameter(self, mock_request
self.assertEqual(mock_requests.request.call_count, 0)
- client = HTTPClient(url=url, method='GET', url_hosts_whitelist=url_hosts_whitelist)
+ client = HTTPClient(
+ url=url, method="GET", url_hosts_whitelist=url_hosts_whitelist
+ )
client.run()
self.assertEqual(mock_requests.request.call_count, 1)
- def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive(self):
- url = 'http://www.example.com'
-
- expected_msg = (r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually '
- 'exclusive.')
- self.assertRaisesRegexp(ValueError, expected_msg, HTTPClient, url=url, method='GET',
- url_hosts_blacklist=[url], url_hosts_whitelist=[url])
+ def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive(
+ self,
+ ):
+ url = "http://www.example.com"
+
+ expected_msg = (
+ r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually '
+ "exclusive."
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ HTTPClient,
+ url=url,
+ method="GET",
+ url_hosts_blacklist=[url],
+ url_hosts_whitelist=[url],
+ )
class HTTPRunnerTestCase(unittest2.TestCase):
- @mock.patch('http_runner.http_runner.requests')
+ @mock.patch("http_runner.http_runner.requests")
def test_get_success(self, mock_requests):
mock_result = MockResult()
# Unknown content type, body should be returned raw
- mock_result.text = 'foo bar ponies'
- mock_result.headers = {'Content-Type': 'text/html'}
+ mock_result.text = "foo bar ponies"
+ mock_result.headers = {"Content-Type": "text/html"}
mock_result.status_code = 200
mock_requests.request.return_value = mock_result
- runner_parameters = {
- 'url': 'http://www.example.com',
- 'method': 'GET'
- }
- runner = HttpRunner('id')
+ runner_parameters = {"url": "http://www.example.com", "method": "GET"}
+ runner = HttpRunner("id")
runner.runner_parameters = runner_parameters
runner.pre_run()
status, result, _ = runner.run({})
self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['body'], 'foo bar ponies')
- self.assertEqual(result['status_code'], 200)
- self.assertEqual(result['parsed'], False)
+ self.assertEqual(result["body"], "foo bar ponies")
+ self.assertEqual(result["status_code"], 200)
+ self.assertEqual(result["parsed"], False)
- def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive(self):
+ def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive(
+ self,
+ ):
runner_parameters = {
- 'url': 'http://www.example.com',
- 'method': 'GET',
- 'url_hosts_blacklist': ['http://127.0.0.1'],
- 'url_hosts_whitelist': ['http://127.0.0.1'],
+ "url": "http://www.example.com",
+ "method": "GET",
+ "url_hosts_blacklist": ["http://127.0.0.1"],
+ "url_hosts_whitelist": ["http://127.0.0.1"],
}
- runner = HttpRunner('id')
+ runner = HttpRunner("id")
runner.runner_parameters = runner_parameters
runner.pre_run()
- expected_msg = (r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually '
- 'exclusive.')
+ expected_msg = (
+ r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually '
+ "exclusive."
+ )
self.assertRaisesRegexp(ValueError, expected_msg, runner.run, {})
diff --git a/contrib/runners/inquirer_runner/dist_utils.py b/contrib/runners/inquirer_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/inquirer_runner/dist_utils.py
+++ b/contrib/runners/inquirer_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
diff --git a/contrib/runners/inquirer_runner/inquirer_runner/__init__.py b/contrib/runners/inquirer_runner/inquirer_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/inquirer_runner/inquirer_runner/__init__.py
+++ b/contrib/runners/inquirer_runner/inquirer_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py b/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py
index 6a5757bc44..af0f0c6f34 100644
--- a/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py
+++ b/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py
@@ -29,20 +29,16 @@
from st2common.util import action_db as action_utils
-__all__ = [
- 'Inquirer',
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["Inquirer", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
# constants to lookup in runner_parameters.
-RUNNER_SCHEMA = 'schema'
-RUNNER_ROLES = 'roles'
-RUNNER_USERS = 'users'
-RUNNER_ROUTE = 'route'
-RUNNER_TTL = 'ttl'
+RUNNER_SCHEMA = "schema"
+RUNNER_ROLES = "roles"
+RUNNER_USERS = "users"
+RUNNER_ROUTE = "route"
+RUNNER_TTL = "ttl"
DEFAULT_SCHEMA = {
"title": "response_data",
@@ -51,15 +47,14 @@
"continue": {
"type": "boolean",
"description": "Would you like to continue the workflow?",
- "required": True
+ "required": True,
}
- }
+ },
}
class Inquirer(runners.ActionRunner):
- """This runner implements the ability to ask for more input during a workflow
- """
+ """This runner implements the ability to ask for more input during a workflow"""
def __init__(self, runner_id):
super(Inquirer, self).__init__(runner_id=runner_id)
@@ -83,14 +78,11 @@ def run(self, action_parameters):
# Assemble and dispatch trigger
trigger_ref = sys_db_models.ResourceReference.to_string_reference(
- pack=trigger_constants.INQUIRY_TRIGGER['pack'],
- name=trigger_constants.INQUIRY_TRIGGER['name']
+ pack=trigger_constants.INQUIRY_TRIGGER["pack"],
+ name=trigger_constants.INQUIRY_TRIGGER["name"],
)
- trigger_payload = {
- "id": str(exc.id),
- "route": self.route
- }
+ trigger_payload = {"id": str(exc.id), "route": self.route}
self.trigger_dispatcher.dispatch(trigger_ref, trigger_payload)
@@ -99,7 +91,7 @@ def run(self, action_parameters):
"roles": self.roles_param,
"users": self.users_param,
"route": self.route,
- "ttl": self.ttl
+ "ttl": self.ttl,
}
return (action_constants.LIVEACTION_STATUS_PENDING, result, None)
@@ -110,9 +102,10 @@ def post_run(self, status, result):
# is made in the run method, but because the liveaction hasn't updated to pending status
# yet, there is a race condition where the pause request is mishandled.
if status == action_constants.LIVEACTION_STATUS_PENDING:
- pause_parent = (
- self.liveaction.context.get("parent") and
- not workflow_service.is_action_execution_under_workflow_context(self.liveaction)
+ pause_parent = self.liveaction.context.get(
+ "parent"
+ ) and not workflow_service.is_action_execution_under_workflow_context(
+ self.liveaction
)
# For action execution under Action Chain workflows, request the entire
@@ -122,7 +115,9 @@ def post_run(self, status, result):
# to pause the workflow.
if pause_parent:
root_liveaction = action_service.get_root_liveaction(self.liveaction)
- action_service.request_pause(root_liveaction, self.context.get('user', None))
+ action_service.request_pause(
+ root_liveaction, self.context.get("user", None)
+ )
# Invoke post run of parent for common post run related work.
super(Inquirer, self).post_run(status, result)
@@ -133,4 +128,4 @@ def get_runner():
def get_metadata():
- return runners.get_metadata('inquirer_runner')[0]
+ return runners.get_metadata("inquirer_runner")[0]
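# Editor's note, not part of the commit: a sketch of the trigger reference and
# payload the Inquirer runner's run() dispatches, with values mirroring the
# unit tests further below.
trigger_ref = "core.st2.generic.inquiry"  # INQUIRY_TRIGGER pack + name
trigger_payload = {"id": "abcdef", "route": "developers"}  # execution id, route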
diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml
index 26711df850..60d79a5b74 100644
--- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml
+++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml
@@ -23,7 +23,7 @@
roles:
default: []
required: false
- description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES
+ description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES
type: array
users:
default: []
diff --git a/contrib/runners/inquirer_runner/setup.py b/contrib/runners/inquirer_runner/setup.py
index 9be54704f9..44d4a4d7f7 100644
--- a/contrib/runners/inquirer_runner/setup.py
+++ b/contrib/runners/inquirer_runner/setup.py
@@ -26,30 +26,32 @@
from inquirer_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-inquirer',
+ name="stackstorm-runner-inquirer",
version=__version__,
- description=('Inquirer action runner for StackStorm event-driven automation platform'),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description=(
+ "Inquirer action runner for StackStorm event-driven automation platform"
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'inquirer_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"inquirer_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'inquirer = inquirer_runner.inquirer_runner',
+ "st2common.runners.runner": [
+ "inquirer = inquirer_runner.inquirer_runner",
],
- }
+ },
)
diff --git a/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py b/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py
index da9c70b78a..caa47bc53a 100644
--- a/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py
+++ b/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py
@@ -28,7 +28,7 @@
mock_exc_get = mock.Mock()
-mock_exc_get.id = 'abcdef'
+mock_exc_get.id = "abcdef"
mock_inquiry_liveaction_db = mock.Mock()
mock_inquiry_liveaction_db.result = {"response": {}}
@@ -37,7 +37,7 @@
mock_action_utils.return_value = mock_inquiry_liveaction_db
test_parent = mock.Mock()
-test_parent.id = '1234567890'
+test_parent.id = "1234567890"
mock_get_root = mock.Mock()
mock_get_root.return_value = test_parent
@@ -45,38 +45,19 @@
mock_trigger_dispatcher = mock.Mock()
mock_request_pause = mock.Mock()
-test_user = 'st2admin'
+test_user = "st2admin"
-runner_params = {
- "users": [],
- "roles": [],
- "route": "developers",
- "schema": {}
-}
+runner_params = {"users": [], "roles": [], "route": "developers", "schema": {}}
+@mock.patch.object(reactor_transport, "TriggerDispatcher", mock_trigger_dispatcher)
+@mock.patch.object(action_utils, "get_liveaction_by_id", mock_action_utils)
+@mock.patch.object(action_service, "request_pause", mock_request_pause)
+@mock.patch.object(action_service, "get_root_liveaction", mock_get_root)
@mock.patch.object(
- reactor_transport,
- 'TriggerDispatcher',
- mock_trigger_dispatcher)
-@mock.patch.object(
- action_utils,
- 'get_liveaction_by_id',
- mock_action_utils)
-@mock.patch.object(
- action_service,
- 'request_pause',
- mock_request_pause)
-@mock.patch.object(
- action_service,
- 'get_root_liveaction',
- mock_get_root)
-@mock.patch.object(
- ex_db_access.ActionExecution,
- 'get',
- mock.MagicMock(return_value=mock_exc_get))
+ ex_db_access.ActionExecution, "get", mock.MagicMock(return_value=mock_exc_get)
+)
class InquiryTestCase(st2tests.RunnerTestCase):
-
def tearDown(self):
mock_trigger_dispatcher.reset_mock()
mock_action_utils.reset_mock()
@@ -85,17 +66,19 @@ def tearDown(self):
def test_runner_creation(self):
runner = inquirer_runner.get_runner()
- self.assertIsNotNone(runner, 'Creation failed. No instance.')
- self.assertEqual(type(runner), inquirer_runner.Inquirer, 'Creation failed. No instance.')
+ self.assertIsNotNone(runner, "Creation failed. No instance.")
+ self.assertEqual(
+ type(runner), inquirer_runner.Inquirer, "Creation failed. No instance."
+ )
def test_simple_inquiry(self):
runner = inquirer_runner.get_runner()
- runner.context = {'user': test_user}
+ runner.context = {"user": test_user}
runner.action = self._get_mock_action_obj()
runner.runner_parameters = runner_params
runner.pre_run()
- mock_inquiry_liveaction_db.context = {'parent': test_parent.id}
+ mock_inquiry_liveaction_db.context = {"parent": test_parent.id}
runner.liveaction = mock_inquiry_liveaction_db
(status, output, _) = runner.run({})
@@ -104,20 +87,16 @@ def test_simple_inquiry(self):
self.assertEqual(
output,
{
- 'users': [],
- 'roles': [],
- 'route': "developers",
- 'schema': {},
- 'ttl': 1440
- }
+ "users": [],
+ "roles": [],
+ "route": "developers",
+ "schema": {},
+ "ttl": 1440,
+ },
)
mock_trigger_dispatcher.return_value.dispatch.assert_called_once_with(
- 'core.st2.generic.inquiry',
- {
- 'id': mock_exc_get.id,
- 'route': "developers"
- }
+ "core.st2.generic.inquiry", {"id": mock_exc_get.id, "route": "developers"}
)
runner.post_run(action_constants.LIVEACTION_STATUS_PENDING, {})
@@ -125,37 +104,28 @@ def test_simple_inquiry(self):
mock_request_pause.assert_called_once_with(test_parent, test_user)
def test_inquiry_no_parent(self):
- """Should behave like a regular execution, but without requesting a pause
- """
+ """Should behave like a regular execution, but without requesting a pause"""
runner = inquirer_runner.get_runner()
- runner.context = {
- 'user': 'st2admin'
- }
+ runner.context = {"user": "st2admin"}
runner.action = self._get_mock_action_obj()
runner.runner_parameters = runner_params
runner.pre_run()
- mock_inquiry_liveaction_db.context = {
- "parent": None
- }
+ mock_inquiry_liveaction_db.context = {"parent": None}
(status, output, _) = runner.run({})
self.assertEqual(status, action_constants.LIVEACTION_STATUS_PENDING)
self.assertEqual(
output,
{
- 'users': [],
- 'roles': [],
- 'route': "developers",
- 'schema': {},
- 'ttl': 1440
- }
+ "users": [],
+ "roles": [],
+ "route": "developers",
+ "schema": {},
+ "ttl": 1440,
+ },
)
mock_trigger_dispatcher.return_value.dispatch.assert_called_once_with(
- 'core.st2.generic.inquiry',
- {
- 'id': mock_exc_get.id,
- 'route': "developers"
- }
+ "core.st2.generic.inquiry", {"id": mock_exc_get.id, "route": "developers"}
)
mock_request_pause.assert_not_called()
diff --git a/contrib/runners/local_runner/dist_utils.py b/contrib/runners/local_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/local_runner/dist_utils.py
+++ b/contrib/runners/local_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
diff --git a/contrib/runners/local_runner/local_runner/__init__.py b/contrib/runners/local_runner/local_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/local_runner/local_runner/__init__.py
+++ b/contrib/runners/local_runner/local_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/local_runner/local_runner/base.py b/contrib/runners/local_runner/local_runner/base.py
index 5bcf20137f..4fda9c1866 100644
--- a/contrib/runners/local_runner/local_runner/base.py
+++ b/contrib/runners/local_runner/local_runner/base.py
@@ -39,32 +39,36 @@
from st2common.services.action import store_execution_output_data
from st2common.runners.utils import make_read_and_store_stream_func
-__all__ = [
- 'BaseLocalShellRunner',
-
- 'RUNNER_COMMAND'
-]
+__all__ = ["BaseLocalShellRunner", "RUNNER_COMMAND"]
LOG = logging.getLogger(__name__)
-DEFAULT_KWARG_OP = '--'
+DEFAULT_KWARG_OP = "--"
LOGGED_USER_USERNAME = pwd.getpwuid(os.getuid())[0]
# constants to lookup in runner_parameters.
-RUNNER_SUDO = 'sudo'
-RUNNER_SUDO_PASSWORD = 'sudo_password'
-RUNNER_ON_BEHALF_USER = 'user'
-RUNNER_COMMAND = 'cmd'
-RUNNER_CWD = 'cwd'
-RUNNER_ENV = 'env'
-RUNNER_KWARG_OP = 'kwarg_op'
-RUNNER_TIMEOUT = 'timeout'
+RUNNER_SUDO = "sudo"
+RUNNER_SUDO_PASSWORD = "sudo_password"
+RUNNER_ON_BEHALF_USER = "user"
+RUNNER_COMMAND = "cmd"
+RUNNER_CWD = "cwd"
+RUNNER_ENV = "env"
+RUNNER_KWARG_OP = "kwarg_op"
+RUNNER_TIMEOUT = "timeout"
PROC_EXIT_CODE_TO_LIVEACTION_STATUS_MAP = {
- str(exit_code_constants.SUCCESS_EXIT_CODE): action_constants.LIVEACTION_STATUS_SUCCEEDED,
- str(exit_code_constants.FAILURE_EXIT_CODE): action_constants.LIVEACTION_STATUS_FAILED,
- str(-1 * exit_code_constants.SIGKILL_EXIT_CODE): action_constants.LIVEACTION_STATUS_TIMED_OUT,
- str(-1 * exit_code_constants.SIGTERM_EXIT_CODE): action_constants.LIVEACTION_STATUS_ABANDONED
+ str(
+ exit_code_constants.SUCCESS_EXIT_CODE
+ ): action_constants.LIVEACTION_STATUS_SUCCEEDED,
+ str(
+ exit_code_constants.FAILURE_EXIT_CODE
+ ): action_constants.LIVEACTION_STATUS_FAILED,
+ str(
+ -1 * exit_code_constants.SIGKILL_EXIT_CODE
+ ): action_constants.LIVEACTION_STATUS_TIMED_OUT,
+ str(
+ -1 * exit_code_constants.SIGTERM_EXIT_CODE
+ ): action_constants.LIVEACTION_STATUS_ABANDONED,
}
@@ -77,7 +81,8 @@ class BaseLocalShellRunner(ActionRunner, ShellRunnerMixin):
Note: The user under which the action runner service is running (stanley user by default) needs
to have passwordless sudo access set up.
"""
- KEYS_TO_TRANSFORM = ['stdout', 'stderr']
+
+ KEYS_TO_TRANSFORM = ["stdout", "stderr"]
def __init__(self, runner_id):
super(BaseLocalShellRunner, self).__init__(runner_id=runner_id)
@@ -87,14 +92,17 @@ def pre_run(self):
self._sudo = self.runner_parameters.get(RUNNER_SUDO, False)
self._sudo_password = self.runner_parameters.get(RUNNER_SUDO_PASSWORD, None)
- self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, LOGGED_USER_USERNAME)
+ self._on_behalf_user = self.context.get(
+ RUNNER_ON_BEHALF_USER, LOGGED_USER_USERNAME
+ )
self._user = cfg.CONF.system_user.user
self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
self._env = self.runner_parameters.get(RUNNER_ENV, {})
self._env = self._env or {}
self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, DEFAULT_KWARG_OP)
self._timeout = self.runner_parameters.get(
- RUNNER_TIMEOUT, runner_constants.LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT)
+ RUNNER_TIMEOUT, runner_constants.LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT
+ )
def _run(self, action):
env_vars = self._env
@@ -110,8 +118,11 @@ def _run(self, action):
# For consistency with the old Fabric based runner, make sure the file is executable
if script_action:
script_local_path_abs = self.entry_point
- args = 'chmod +x %s ; %s' % (script_local_path_abs, args)
- sanitized_args = 'chmod +x %s ; %s' % (script_local_path_abs, sanitized_args)
+ args = "chmod +x %s ; %s" % (script_local_path_abs, args)
+ sanitized_args = "chmod +x %s ; %s" % (
+ script_local_path_abs,
+ sanitized_args,
+ )
env = os.environ.copy()
@@ -122,22 +133,38 @@ def _run(self, action):
st2_env_vars = self._get_common_action_env_variables()
env.update(st2_env_vars)
- LOG.info('Executing action via LocalRunner: %s', self.runner_id)
- LOG.info('[Action info] name: %s, Id: %s, command: %s, user: %s, sudo: %s' %
- (action.name, action.action_exec_id, sanitized_args, action.user, action.sudo))
+ LOG.info("Executing action via LocalRunner: %s", self.runner_id)
+ LOG.info(
+ "[Action info] name: %s, Id: %s, command: %s, user: %s, sudo: %s"
+ % (
+ action.name,
+ action.action_exec_id,
+ sanitized_args,
+ action.user,
+ action.sudo,
+ )
+ )
stdout = StringIO()
stderr = StringIO()
- store_execution_stdout_line = functools.partial(store_execution_output_data,
- output_type='stdout')
- store_execution_stderr_line = functools.partial(store_execution_output_data,
- output_type='stderr')
+ store_execution_stdout_line = functools.partial(
+ store_execution_output_data, output_type="stdout"
+ )
+ store_execution_stderr_line = functools.partial(
+ store_execution_output_data, output_type="stderr"
+ )
- read_and_store_stdout = make_read_and_store_stream_func(execution_db=self.execution,
- action_db=self.action, store_data_func=store_execution_stdout_line)
- read_and_store_stderr = make_read_and_store_stream_func(execution_db=self.execution,
- action_db=self.action, store_data_func=store_execution_stderr_line)
+ read_and_store_stdout = make_read_and_store_stream_func(
+ execution_db=self.execution,
+ action_db=self.action,
+ store_data_func=store_execution_stdout_line,
+ )
+ read_and_store_stderr = make_read_and_store_stream_func(
+ execution_db=self.execution,
+ action_db=self.action,
+ store_data_func=store_execution_stderr_line,
+ )
subprocess = concurrency.get_subprocess_module()
@@ -145,9 +172,10 @@ def _run(self, action):
# Note: We don't need to explicitly escape the argument because we pass command as a list
# to subprocess.Popen and all the arguments are escaped by the function.
if self._sudo_password:
- LOG.debug('Supplying sudo password via stdin')
- echo_process = concurrency.subprocess_popen(['echo', self._sudo_password + '\n'],
- stdout=subprocess.PIPE)
+ LOG.debug("Supplying sudo password via stdin")
+ echo_process = concurrency.subprocess_popen(
+ ["echo", self._sudo_password + "\n"], stdout=subprocess.PIPE
+ )
stdin = echo_process.stdout
else:
stdin = None
@@ -161,57 +189,64 @@ def _run(self, action):
# Ideally os.killpg should have done the trick but for some reason that failed.
# Note: pkill will set the returncode to 143 so we don't need to explicitly set
# it to some non-zero value.
- exit_code, stdout, stderr, timed_out = shell.run_command(cmd=args,
- stdin=stdin,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- shell=True,
- cwd=self._cwd,
- env=env,
- timeout=self._timeout,
- preexec_func=os.setsid,
- kill_func=kill_process,
- read_stdout_func=read_and_store_stdout,
- read_stderr_func=read_and_store_stderr,
- read_stdout_buffer=stdout,
- read_stderr_buffer=stderr)
+ exit_code, stdout, stderr, timed_out = shell.run_command(
+ cmd=args,
+ stdin=stdin,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True,
+ cwd=self._cwd,
+ env=env,
+ timeout=self._timeout,
+ preexec_func=os.setsid,
+ kill_func=kill_process,
+ read_stdout_func=read_and_store_stdout,
+ read_stderr_func=read_and_store_stderr,
+ read_stdout_buffer=stdout,
+ read_stderr_buffer=stderr,
+ )
error = None
if timed_out:
- error = 'Action failed to complete in %s seconds' % (self._timeout)
+ error = "Action failed to complete in %s seconds" % (self._timeout)
exit_code = -1 * exit_code_constants.SIGKILL_EXIT_CODE
# Detect if user provided an invalid sudo password or sudo is not configured for that user
if self._sudo_password:
- if re.search(r'sudo: \d+ incorrect password attempts', stderr):
- match = re.search(r'\[sudo\] password for (.+?)\:', stderr)
+ if re.search(r"sudo: \d+ incorrect password attempts", stderr):
+ match = re.search(r"\[sudo\] password for (.+?)\:", stderr)
if match:
username = match.groups()[0]
else:
- username = 'unknown'
+ username = "unknown"
- error = ('Invalid sudo password provided or sudo is not configured for this user '
- '(%s)' % (username))
+ error = (
+ "Invalid sudo password provided or sudo is not configured for this user "
+ "(%s)" % (username)
+ )
exit_code = -1
- succeeded = (exit_code == exit_code_constants.SUCCESS_EXIT_CODE)
+ succeeded = exit_code == exit_code_constants.SUCCESS_EXIT_CODE
result = {
- 'failed': not succeeded,
- 'succeeded': succeeded,
- 'return_code': exit_code,
- 'stdout': strip_shell_chars(stdout),
- 'stderr': strip_shell_chars(stderr)
+ "failed": not succeeded,
+ "succeeded": succeeded,
+ "return_code": exit_code,
+ "stdout": strip_shell_chars(stdout),
+ "stderr": strip_shell_chars(stderr),
}
if error:
- result['error'] = error
+ result["error"] = error
status = PROC_EXIT_CODE_TO_LIVEACTION_STATUS_MAP.get(
- str(exit_code),
- action_constants.LIVEACTION_STATUS_FAILED
+ str(exit_code), action_constants.LIVEACTION_STATUS_FAILED
)
- return (status, jsonify.json_loads(result, BaseLocalShellRunner.KEYS_TO_TRANSFORM), None)
+ return (
+ status,
+ jsonify.json_loads(result, BaseLocalShellRunner.KEYS_TO_TRANSFORM),
+ None,
+ )
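# Editor's note, not part of the commit: a sketch of how the
# PROC_EXIT_CODE_TO_LIVEACTION_STATUS_MAP defined above resolves a status,
# assuming the module context of base.py. SIGKILL (exit code -9) maps to
# "timed out"; unknown exit codes fall back to "failed".
status = PROC_EXIT_CODE_TO_LIVEACTION_STATUS_MAP.get(
    str(-9), action_constants.LIVEACTION_STATUS_FAILED
)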
diff --git a/contrib/runners/local_runner/local_runner/local_shell_command_runner.py b/contrib/runners/local_runner/local_runner/local_shell_command_runner.py
index 4ae61f3225..cbf603de27 100644
--- a/contrib/runners/local_runner/local_runner/local_shell_command_runner.py
+++ b/contrib/runners/local_runner/local_runner/local_shell_command_runner.py
@@ -23,28 +23,25 @@
from local_runner.base import BaseLocalShellRunner
from local_runner.base import RUNNER_COMMAND
-__all__ = [
- 'LocalShellCommandRunner',
-
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["LocalShellCommandRunner", "get_runner", "get_metadata"]
class LocalShellCommandRunner(BaseLocalShellRunner):
def run(self, action_parameters):
if self.entry_point:
- raise ValueError('entry_point is only valid for local-shell-script runner')
+ raise ValueError("entry_point is only valid for local-shell-script runner")
command = self.runner_parameters.get(RUNNER_COMMAND, None)
- action = ShellCommandAction(name=self.action_name,
- action_exec_id=str(self.liveaction_id),
- command=command,
- user=self._user,
- env_vars=self._env,
- sudo=self._sudo,
- timeout=self._timeout,
- sudo_password=self._sudo_password)
+ action = ShellCommandAction(
+ name=self.action_name,
+ action_exec_id=str(self.liveaction_id),
+ command=command,
+ user=self._user,
+ env_vars=self._env,
+ sudo=self._sudo,
+ timeout=self._timeout,
+ sudo_password=self._sudo_password,
+ )
return self._run(action=action)
@@ -54,7 +51,10 @@ def get_runner():
def get_metadata():
- metadata = get_runner_metadata('local_runner')
- metadata = [runner for runner in metadata if
- runner['runner_module'] == __name__.split('.')[-1]][0]
+ metadata = get_runner_metadata("local_runner")
+ metadata = [
+ runner
+ for runner in metadata
+ if runner["runner_module"] == __name__.split(".")[-1]
+ ][0]
return metadata
diff --git a/contrib/runners/local_runner/local_runner/local_shell_script_runner.py b/contrib/runners/local_runner/local_runner/local_shell_script_runner.py
index 24a0fe6ddb..257e457ca1 100644
--- a/contrib/runners/local_runner/local_runner/local_shell_script_runner.py
+++ b/contrib/runners/local_runner/local_runner/local_shell_script_runner.py
@@ -23,34 +23,31 @@
from local_runner.base import BaseLocalShellRunner
-__all__ = [
- 'LocalShellScriptRunner',
-
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["LocalShellScriptRunner", "get_runner", "get_metadata"]
class LocalShellScriptRunner(BaseLocalShellRunner, GitWorktreeActionRunner):
def run(self, action_parameters):
if not self.entry_point:
- raise ValueError('Missing entry_point action metadata attribute')
+ raise ValueError("Missing entry_point action metadata attribute")
script_local_path_abs = self.entry_point
positional_args, named_args = self._get_script_args(action_parameters)
named_args = self._transform_named_args(named_args)
- action = ShellScriptAction(name=self.action_name,
- action_exec_id=str(self.liveaction_id),
- script_local_path_abs=script_local_path_abs,
- named_args=named_args,
- positional_args=positional_args,
- user=self._user,
- env_vars=self._env,
- sudo=self._sudo,
- timeout=self._timeout,
- cwd=self._cwd,
- sudo_password=self._sudo_password)
+ action = ShellScriptAction(
+ name=self.action_name,
+ action_exec_id=str(self.liveaction_id),
+ script_local_path_abs=script_local_path_abs,
+ named_args=named_args,
+ positional_args=positional_args,
+ user=self._user,
+ env_vars=self._env,
+ sudo=self._sudo,
+ timeout=self._timeout,
+ cwd=self._cwd,
+ sudo_password=self._sudo_password,
+ )
return self._run(action=action)
@@ -60,7 +57,10 @@ def get_runner():
def get_metadata():
- metadata = get_runner_metadata('local_runner')
- metadata = [runner for runner in metadata if
- runner['runner_module'] == __name__.split('.')[-1]][0]
+ metadata = get_runner_metadata("local_runner")
+ metadata = [
+ runner
+ for runner in metadata
+ if runner["runner_module"] == __name__.split(".")[-1]
+ ][0]
return metadata
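
The exploded keyword-argument calls throughout these hunks are black's standard output: once a call no longer fits within the line limit, every argument moves to its own line and gains a trailing comma. Assuming black is installed, the style can be reproduced programmatically:

import black

src = (
    "action = ShellScriptAction(name=self.action_name, "
    "action_exec_id=str(self.liveaction_id), sudo_password=self._sudo_password)\n"
)

# format_str applies the same rules black applies on the command line.
print(black.format_str(src, mode=black.FileMode()))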
diff --git a/contrib/runners/local_runner/setup.py b/contrib/runners/local_runner/setup.py
index feb1cb6554..063314ab74 100644
--- a/contrib/runners/local_runner/setup.py
+++ b/contrib/runners/local_runner/setup.py
@@ -26,32 +26,34 @@
from local_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-local',
+ name="stackstorm-runner-local",
version=__version__,
- description=('Local Shell Command and Script action runner for StackStorm event-driven '
- 'automation platform'),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description=(
+ "Local Shell Command and Script action runner for StackStorm event-driven "
+ "automation platform"
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'local_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"local_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'local-shell-cmd = local_runner.local_shell_command_runner',
- 'local-shell-script = local_runner.local_shell_script_runner',
+ "st2common.runners.runner": [
+ "local-shell-cmd = local_runner.local_shell_command_runner",
+ "local-shell-script = local_runner.local_shell_script_runner",
],
- }
+ },
)
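
The st2common.runners.runner entry points declared above are how the runner modules are discovered at runtime. StackStorm has its own loading machinery, but the underlying mechanism can be illustrated with the stdlib (Python 3.10+ entry_points API, and the package must be installed for the lookup to succeed):

from importlib.metadata import entry_points

for ep in entry_points(group="st2common.runners.runner"):
    if ep.name == "local-shell-cmd":
        # Imports local_runner.local_shell_command_runner and returns the module.
        module = ep.load()
        runner = module.get_runner()
        break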
diff --git a/contrib/runners/local_runner/tests/integration/test_localrunner.py b/contrib/runners/local_runner/tests/integration/test_localrunner.py
index 0e5a2f3efc..05c241f46b 100644
--- a/contrib/runners/local_runner/tests/integration/test_localrunner.py
+++ b/contrib/runners/local_runner/tests/integration/test_localrunner.py
@@ -22,6 +22,7 @@
import st2tests.config as tests_config
from six.moves import range
+
tests_config.parse_args()
from st2common.constants import action as action_constants
@@ -40,13 +41,10 @@
from local_runner.local_shell_command_runner import LocalShellCommandRunner
from local_runner.local_shell_script_runner import LocalShellScriptRunner
-__all__ = [
- 'LocalShellCommandRunnerTestCase',
- 'LocalShellScriptRunnerTestCase'
-]
+__all__ = ["LocalShellCommandRunnerTestCase", "LocalShellScriptRunnerTestCase"]
MOCK_EXECUTION = mock.Mock()
-MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b'
+MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b"
class LocalShellCommandRunnerTestCase(RunnerTestCase, CleanDbTestCase):
@@ -56,108 +54,115 @@ def setUp(self):
super(LocalShellCommandRunnerTestCase, self).setUp()
# False is the default behavior, so the end result should be the same
- cfg.CONF.set_override(name='stream_output', group='actionrunner', override=False)
+ cfg.CONF.set_override(
+ name="stream_output", group="actionrunner", override=False
+ )
def test_shell_command_action_basic(self):
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
- runner = self._get_runner(action_db, cmd='echo 10')
+ runner = self._get_runner(action_db, cmd="echo 10")
runner.pre_run()
status, result, _ = runner.run({})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'], 10)
+ self.assertEqual(result["stdout"], 10)
# End result should be the same when streaming is enabled
- cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True)
+ cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True)
# Verify initial state
output_dbs = ActionExecutionOutput.get_all()
self.assertEqual(len(output_dbs), 0)
- runner = self._get_runner(action_db, cmd='echo 10')
+ runner = self._get_runner(action_db, cmd="echo 10")
runner.pre_run()
status, result, _ = runner.run({})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'], 10)
+ self.assertEqual(result["stdout"], 10)
output_dbs = ActionExecutionOutput.get_all()
self.assertEqual(len(output_dbs), 1)
- self.assertEqual(output_dbs[0].output_type, 'stdout')
- self.assertEqual(output_dbs[0].data, '10\n')
+ self.assertEqual(output_dbs[0].output_type, "stdout")
+ self.assertEqual(output_dbs[0].data, "10\n")
def test_timeout(self):
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
# smaller timeout == faster tests.
- runner = self._get_runner(action_db, cmd='sleep 10', timeout=0.01)
+ runner = self._get_runner(action_db, cmd="sleep 10", timeout=0.01)
runner.pre_run()
status, result, _ = runner.run({})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_TIMED_OUT)
@mock.patch.object(
- shell, 'run_command',
- mock.MagicMock(return_value=(-15, '', '', False)))
+ shell, "run_command", mock.MagicMock(return_value=(-15, "", "", False))
+ )
def test_shutdown(self):
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
- runner = self._get_runner(action_db, cmd='sleep 0.1')
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
+ runner = self._get_runner(action_db, cmd="sleep 0.1")
runner.pre_run()
status, result, _ = runner.run({})
self.assertEqual(status, action_constants.LIVEACTION_STATUS_ABANDONED)
def test_common_st2_env_vars_are_available_to_the_action(self):
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
- runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_API_URL')
+ runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_API_URL")
runner.pre_run()
status, result, _ = runner.run({})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'].strip(), get_full_public_api_url())
+ self.assertEqual(result["stdout"].strip(), get_full_public_api_url())
- runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_AUTH_TOKEN')
+ runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_AUTH_TOKEN")
runner.pre_run()
status, result, _ = runner.run({})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'].strip(), 'mock-token')
+ self.assertEqual(result["stdout"].strip(), "mock-token")
def test_sudo_and_env_variable_preservation(self):
# Verify that the environment variables are correctly preserved when running as a
# root / non-system user
# Note: This test will fail if SETENV option is not present in the sudoers file
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
- cmd = 'echo `whoami` ; echo ${VAR1}'
- env = {'VAR1': 'poniesponies'}
+ cmd = "echo `whoami` ; echo ${VAR1}"
+ env = {"VAR1": "poniesponies"}
runner = self._get_runner(action_db, cmd=cmd, sudo=True, env=env)
runner.pre_run()
status, result, _ = runner.run({})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'].strip(), 'root\nponiesponies')
+ self.assertEqual(result["stdout"].strip(), "root\nponiesponies")
- @mock.patch('st2common.util.concurrency.subprocess_popen')
- @mock.patch('st2common.util.concurrency.spawn')
+ @mock.patch("st2common.util.concurrency.subprocess_popen")
+ @mock.patch("st2common.util.concurrency.spawn")
def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_popen):
# Feature is enabled
- cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True)
+ cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True)
# Note: We need to mock the spawn function so we can test everything in a single event loop
# iteration
@@ -165,78 +170,75 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop
# No output to stdout and no result (implicit None)
mock_stdout = [
- 'stdout line 1\n',
- 'stdout line 2\n',
- ]
- mock_stderr = [
- 'stderr line 1\n',
- 'stderr line 2\n',
- 'stderr line 3\n'
+ "stdout line 1\n",
+ "stdout line 2\n",
]
+ mock_stderr = ["stderr line 1\n", "stderr line 2\n", "stderr line 3\n"]
mock_process = mock.Mock()
mock_process.returncode = 0
mock_popen.return_value = mock_process
mock_process.stdout.closed = False
mock_process.stderr.closed = False
- mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout,
- stop_counter=2)
- mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr,
- stop_counter=3)
+ mock_process.stdout.readline = make_mock_stream_readline(
+ mock_process.stdout, mock_stdout, stop_counter=2
+ )
+ mock_process.stderr.readline = make_mock_stream_readline(
+ mock_process.stderr, mock_stderr, stop_counter=3
+ )
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
- runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_API_URL')
+ runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_API_URL")
runner.pre_run()
status, result, _ = runner.run({})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'], 'stdout line 1\nstdout line 2')
- self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2\nstderr line 3')
- self.assertEqual(result['return_code'], 0)
+ self.assertEqual(result["stdout"], "stdout line 1\nstdout line 2")
+ self.assertEqual(
+ result["stderr"], "stderr line 1\nstderr line 2\nstderr line 3"
+ )
+ self.assertEqual(result["return_code"], 0)
# Verify stdout and stderr lines have been correctly stored in the db
- output_dbs = ActionExecutionOutput.query(output_type='stdout')
+ output_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(output_dbs), 2)
self.assertEqual(output_dbs[0].data, mock_stdout[0])
self.assertEqual(output_dbs[1].data, mock_stdout[1])
- output_dbs = ActionExecutionOutput.query(output_type='stderr')
+ output_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(output_dbs), 3)
self.assertEqual(output_dbs[0].data, mock_stderr[0])
self.assertEqual(output_dbs[1].data, mock_stderr[1])
self.assertEqual(output_dbs[2].data, mock_stderr[2])
- @mock.patch('st2common.util.concurrency.subprocess_popen')
- @mock.patch('st2common.util.concurrency.spawn')
- def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, mock_spawn,
- mock_popen):
+ @mock.patch("st2common.util.concurrency.subprocess_popen")
+ @mock.patch("st2common.util.concurrency.spawn")
+ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(
+ self, mock_spawn, mock_popen
+ ):
# Verify that we correctly retrieve all the output and wait for the stdout and stderr reading
# threads for short-running actions.
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
# Feature is enabled
- cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True)
+ cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True)
# Note: We need to mock the spawn function so we can test everything in a single event loop
# iteration
mock_spawn.side_effect = blocking_eventlet_spawn
# No output to stdout and no result (implicit None)
- mock_stdout = [
- 'stdout line 1\n',
- 'stdout line 2\n'
- ]
- mock_stderr = [
- 'stderr line 1\n',
- 'stderr line 2\n'
- ]
+ mock_stdout = ["stdout line 1\n", "stdout line 2\n"]
+ mock_stderr = ["stderr line 1\n", "stderr line 2\n"]
# We add a sleep to simulate the action process exiting before we finish reading data from stdout and stderr
mock_process = mock.Mock()
@@ -244,11 +246,12 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self,
mock_popen.return_value = mock_process
mock_process.stdout.closed = False
mock_process.stderr.closed = False
- mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout,
- stop_counter=2,
- sleep_delay=1)
- mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr,
- stop_counter=2)
+ mock_process.stdout.readline = make_mock_stream_readline(
+ mock_process.stdout, mock_stdout, stop_counter=2, sleep_delay=1
+ )
+ mock_process.stderr.readline = make_mock_stream_readline(
+ mock_process.stderr, mock_stderr, stop_counter=2
+ )
for index in range(1, 4):
mock_process.stdout.closed = False
@@ -263,12 +266,12 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self,
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'], 'stdout line 1\nstdout line 2')
- self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2')
- self.assertEqual(result['return_code'], 0)
+ self.assertEqual(result["stdout"], "stdout line 1\nstdout line 2")
+ self.assertEqual(result["stderr"], "stderr line 1\nstderr line 2")
+ self.assertEqual(result["return_code"], 0)
# Verify stdout and stderr lines have been correctly stored in the db
- output_dbs = ActionExecutionOutput.query(output_type='stdout')
+ output_dbs = ActionExecutionOutput.query(output_type="stdout")
if index == 1:
db_index_1 = 0
@@ -287,7 +290,7 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self,
self.assertEqual(output_dbs[db_index_1].data, mock_stdout[0])
self.assertEqual(output_dbs[db_index_2].data, mock_stdout[1])
- output_dbs = ActionExecutionOutput.query(output_type='stderr')
+ output_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(output_dbs), (index * 2))
self.assertEqual(output_dbs[db_index_1].data, mock_stderr[0])
self.assertEqual(output_dbs[db_index_2].data, mock_stderr[1])
@@ -295,16 +298,13 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self,
def test_shell_command_sudo_password_is_passed_to_sudo_binary(self):
# Verify that the sudo password is correctly passed to the sudo binary via stdin
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
- sudo_passwords = [
- 'pass 1',
- 'sudopass',
- '$sudo p@ss 2'
- ]
+ sudo_passwords = ["pass 1", "sudopass", "$sudo p@ss 2"]
- cmd = ('{ read sudopass; echo $sudopass; }')
+ cmd = "{ read sudopass; echo $sudopass; }"
# without sudo
for sudo_password in sudo_passwords:
@@ -314,9 +314,8 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self):
status, result, _ = runner.run({})
runner.post_run(status, result)
- self.assertEqual(status,
- action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'], sudo_password)
+ self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(result["stdout"], sudo_password)
# with sudo
for sudo_password in sudo_passwords:
@@ -327,12 +326,13 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self):
status, result, _ = runner.run({})
runner.post_run(status, result)
- self.assertEqual(status,
- action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['stdout'], sudo_password)
+ self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(result["stdout"], sudo_password)
# Verify a new process which provides the password to the command via stdin is created
- with mock.patch('st2common.util.concurrency.subprocess_popen') as mock_subproc_popen:
+ with mock.patch(
+ "st2common.util.concurrency.subprocess_popen"
+ ) as mock_subproc_popen:
index = 0
for sudo_password in sudo_passwords:
runner = self._get_runner(action_db, cmd=cmd)
@@ -349,58 +349,67 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self):
index += 1
- self.assertEqual(call_args[0][0], ['echo', '%s\n' % (sudo_password)])
+ self.assertEqual(call_args[0][0], ["echo", "%s\n" % (sudo_password)])
self.assertEqual(index, len(sudo_passwords))
def test_shell_command_invalid_stdout_password(self):
# Simulate the message printed to stderr by sudo when an invalid sudo password is provided
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']})
- action_db = models['actions']['local.yaml']
-
- cmd = ('echo "[sudo] password for bar: Sorry, try again.\n[sudo] password for bar:'
- ' Sorry, try again.\n[sudo] password for bar: \nsudo: 2 incorrect password '
- 'attempts" 1>&2; exit 1')
+ fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]}
+ )
+ action_db = models["actions"]["local.yaml"]
+
+ cmd = (
+ 'echo "[sudo] password for bar: Sorry, try again.\n[sudo] password for bar:'
+ " Sorry, try again.\n[sudo] password for bar: \nsudo: 2 incorrect password "
+ 'attempts" 1>&2; exit 1'
+ )
runner = self._get_runner(action_db, cmd=cmd)
runner.pre_run()
- runner._sudo_password = 'pass'
+ runner._sudo_password = "pass"
status, result, _ = runner.run({})
runner.post_run(status, result)
- expected_error = ('Invalid sudo password provided or sudo is not configured for this '
- 'user (bar)')
+ expected_error = (
+ "Invalid sudo password provided or sudo is not configured for this "
+ "user (bar)"
+ )
self.assertEqual(status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(result['error'], expected_error)
- self.assertEqual(result['stdout'], '')
+ self.assertEqual(result["error"], expected_error)
+ self.assertEqual(result["stdout"], "")
@staticmethod
- def _get_runner(action_db,
- entry_point=None,
- cmd=None,
- on_behalf_user=None,
- user=None,
- kwarg_op=local_runner.DEFAULT_KWARG_OP,
- timeout=LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT,
- sudo=False,
- env=None):
+ def _get_runner(
+ action_db,
+ entry_point=None,
+ cmd=None,
+ on_behalf_user=None,
+ user=None,
+ kwarg_op=local_runner.DEFAULT_KWARG_OP,
+ timeout=LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT,
+ sudo=False,
+ env=None,
+ ):
runner = LocalShellCommandRunner(uuid.uuid4().hex)
runner.execution = MOCK_EXECUTION
runner.action = action_db
runner.action_name = action_db.name
runner.liveaction_id = uuid.uuid4().hex
runner.entry_point = entry_point
- runner.runner_parameters = {local_runner.RUNNER_COMMAND: cmd,
- local_runner.RUNNER_SUDO: sudo,
- local_runner.RUNNER_ENV: env,
- local_runner.RUNNER_ON_BEHALF_USER: user,
- local_runner.RUNNER_KWARG_OP: kwarg_op,
- local_runner.RUNNER_TIMEOUT: timeout}
+ runner.runner_parameters = {
+ local_runner.RUNNER_COMMAND: cmd,
+ local_runner.RUNNER_SUDO: sudo,
+ local_runner.RUNNER_ENV: env,
+ local_runner.RUNNER_ON_BEHALF_USER: user,
+ local_runner.RUNNER_KWARG_OP: kwarg_op,
+ local_runner.RUNNER_TIMEOUT: timeout,
+ }
runner.context = dict()
runner.callback = dict()
runner.libs_dir_path = None
runner.auth_token = mock.Mock()
- runner.auth_token.token = 'mock-token'
+ runner.auth_token.token = "mock-token"
return runner
@@ -411,22 +420,27 @@ def setUp(self):
super(LocalShellScriptRunnerTestCase, self).setUp()
# False is the default behavior, so the end result should be the same
- cfg.CONF.set_override(name='stream_output', group='actionrunner', override=False)
+ cfg.CONF.set_override(
+ name="stream_output", group="actionrunner", override=False
+ )
def test_script_with_parameters_parameter_serialization(self):
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local_script_with_params.yaml']})
- action_db = models['actions']['local_script_with_params.yaml']
- entry_point = os.path.join(get_fixtures_base_path(),
- 'generic/actions/local_script_with_params.sh')
+ fixtures_pack="generic",
+ fixtures_dict={"actions": ["local_script_with_params.yaml"]},
+ )
+ action_db = models["actions"]["local_script_with_params.yaml"]
+ entry_point = os.path.join(
+ get_fixtures_base_path(), "generic/actions/local_script_with_params.sh"
+ )
action_parameters = {
- 'param_string': 'test string',
- 'param_integer': 1,
- 'param_float': 2.55,
- 'param_boolean': True,
- 'param_list': ['a', 'b', 'c'],
- 'param_object': {'foo': 'bar'}
+ "param_string": "test string",
+ "param_integer": 1,
+ "param_float": 2.55,
+ "param_boolean": True,
+ "param_list": ["a", "b", "c"],
+ "param_object": {"foo": "bar"},
}
runner = self._get_runner(action_db=action_db, entry_point=entry_point)
@@ -434,20 +448,20 @@ def test_script_with_parameters_parameter_serialization(self):
status, result, _ = runner.run(action_parameters=action_parameters)
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('PARAM_STRING=test string', result['stdout'])
- self.assertIn('PARAM_INTEGER=1', result['stdout'])
- self.assertIn('PARAM_FLOAT=2.55', result['stdout'])
- self.assertIn('PARAM_BOOLEAN=1', result['stdout'])
- self.assertIn('PARAM_LIST=a,b,c', result['stdout'])
- self.assertIn('PARAM_OBJECT={"foo": "bar"}', result['stdout'])
+ self.assertIn("PARAM_STRING=test string", result["stdout"])
+ self.assertIn("PARAM_INTEGER=1", result["stdout"])
+ self.assertIn("PARAM_FLOAT=2.55", result["stdout"])
+ self.assertIn("PARAM_BOOLEAN=1", result["stdout"])
+ self.assertIn("PARAM_LIST=a,b,c", result["stdout"])
+ self.assertIn('PARAM_OBJECT={"foo": "bar"}', result["stdout"])
action_parameters = {
- 'param_string': 'test string',
- 'param_integer': 1,
- 'param_float': 2.55,
- 'param_boolean': False,
- 'param_list': ['a', 'b', 'c'],
- 'param_object': {'foo': 'bar'}
+ "param_string": "test string",
+ "param_integer": 1,
+ "param_float": 2.55,
+ "param_boolean": False,
+ "param_list": ["a", "b", "c"],
+ "param_object": {"foo": "bar"},
}
runner = self._get_runner(action_db=action_db, entry_point=entry_point)
@@ -456,12 +470,12 @@ def test_script_with_parameters_parameter_serialization(self):
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('PARAM_BOOLEAN=0', result['stdout'])
+ self.assertIn("PARAM_BOOLEAN=0", result["stdout"])
action_parameters = {
- 'param_string': '',
- 'param_integer': None,
- 'param_float': None,
+ "param_string": "",
+ "param_integer": None,
+ "param_float": None,
}
runner = self._get_runner(action_db=action_db, entry_point=entry_point)
@@ -470,24 +484,24 @@ def test_script_with_parameters_parameter_serialization(self):
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('PARAM_STRING=\n', result['stdout'])
- self.assertIn('PARAM_INTEGER=\n', result['stdout'])
- self.assertIn('PARAM_FLOAT=\n', result['stdout'])
+ self.assertIn("PARAM_STRING=\n", result["stdout"])
+ self.assertIn("PARAM_INTEGER=\n", result["stdout"])
+ self.assertIn("PARAM_FLOAT=\n", result["stdout"])
# End result should be the same when streaming is enabled
- cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True)
+ cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True)
# Verify initial state
output_dbs = ActionExecutionOutput.get_all()
self.assertEqual(len(output_dbs), 0)
action_parameters = {
- 'param_string': 'test string',
- 'param_integer': 1,
- 'param_float': 2.55,
- 'param_boolean': True,
- 'param_list': ['a', 'b', 'c'],
- 'param_object': {'foo': 'bar'}
+ "param_string": "test string",
+ "param_integer": 1,
+ "param_float": 2.55,
+ "param_boolean": True,
+ "param_list": ["a", "b", "c"],
+ "param_object": {"foo": "bar"},
}
runner = self._get_runner(action_db=action_db, entry_point=entry_point)
@@ -496,26 +510,26 @@ def test_script_with_parameters_parameter_serialization(self):
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('PARAM_STRING=test string', result['stdout'])
- self.assertIn('PARAM_INTEGER=1', result['stdout'])
- self.assertIn('PARAM_FLOAT=2.55', result['stdout'])
- self.assertIn('PARAM_BOOLEAN=1', result['stdout'])
- self.assertIn('PARAM_LIST=a,b,c', result['stdout'])
- self.assertIn('PARAM_OBJECT={"foo": "bar"}', result['stdout'])
-
- output_dbs = ActionExecutionOutput.query(output_type='stdout')
+ self.assertIn("PARAM_STRING=test string", result["stdout"])
+ self.assertIn("PARAM_INTEGER=1", result["stdout"])
+ self.assertIn("PARAM_FLOAT=2.55", result["stdout"])
+ self.assertIn("PARAM_BOOLEAN=1", result["stdout"])
+ self.assertIn("PARAM_LIST=a,b,c", result["stdout"])
+ self.assertIn('PARAM_OBJECT={"foo": "bar"}', result["stdout"])
+
+ output_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(output_dbs), 6)
- self.assertEqual(output_dbs[0].data, 'PARAM_STRING=test string\n')
+ self.assertEqual(output_dbs[0].data, "PARAM_STRING=test string\n")
self.assertEqual(output_dbs[5].data, 'PARAM_OBJECT={"foo": "bar"}\n')
- output_dbs = ActionExecutionOutput.query(output_type='stderr')
+ output_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(output_dbs), 0)
- @mock.patch('st2common.util.concurrency.subprocess_popen')
- @mock.patch('st2common.util.concurrency.spawn')
+ @mock.patch("st2common.util.concurrency.subprocess_popen")
+ @mock.patch("st2common.util.concurrency.spawn")
def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_popen):
# Feature is enabled
- cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True)
+ cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True)
# Note: We need to mock the spawn function so we can test everything in a single event loop
# iteration
@@ -523,40 +537,41 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop
# No output to stdout and no result (implicit None)
mock_stdout = [
- 'stdout line 1\n',
- 'stdout line 2\n',
- 'stdout line 3\n',
- 'stdout line 4\n'
- ]
- mock_stderr = [
- 'stderr line 1\n',
- 'stderr line 2\n',
- 'stderr line 3\n'
+ "stdout line 1\n",
+ "stdout line 2\n",
+ "stdout line 3\n",
+ "stdout line 4\n",
]
+ mock_stderr = ["stderr line 1\n", "stderr line 2\n", "stderr line 3\n"]
mock_process = mock.Mock()
mock_process.returncode = 0
mock_popen.return_value = mock_process
mock_process.stdout.closed = False
mock_process.stderr.closed = False
- mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout,
- stop_counter=4)
- mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr,
- stop_counter=3)
+ mock_process.stdout.readline = make_mock_stream_readline(
+ mock_process.stdout, mock_stdout, stop_counter=4
+ )
+ mock_process.stderr.readline = make_mock_stream_readline(
+ mock_process.stderr, mock_stderr, stop_counter=3
+ )
models = self.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['local_script_with_params.yaml']})
- action_db = models['actions']['local_script_with_params.yaml']
- entry_point = os.path.join(get_fixtures_base_path(),
- 'generic/actions/local_script_with_params.sh')
+ fixtures_pack="generic",
+ fixtures_dict={"actions": ["local_script_with_params.yaml"]},
+ )
+ action_db = models["actions"]["local_script_with_params.yaml"]
+ entry_point = os.path.join(
+ get_fixtures_base_path(), "generic/actions/local_script_with_params.sh"
+ )
action_parameters = {
- 'param_string': 'test string',
- 'param_integer': 1,
- 'param_float': 2.55,
- 'param_boolean': True,
- 'param_list': ['a', 'b', 'c'],
- 'param_object': {'foo': 'bar'}
+ "param_string": "test string",
+ "param_integer": 1,
+ "param_float": 2.55,
+ "param_boolean": True,
+ "param_list": ["a", "b", "c"],
+ "param_object": {"foo": "bar"},
}
runner = self._get_runner(action_db=action_db, entry_point=entry_point)
@@ -564,20 +579,24 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop
status, result, _ = runner.run(action_parameters=action_parameters)
runner.post_run(status, result)
- self.assertEqual(result['stdout'],
- 'stdout line 1\nstdout line 2\nstdout line 3\nstdout line 4')
- self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2\nstderr line 3')
- self.assertEqual(result['return_code'], 0)
+ self.assertEqual(
+ result["stdout"],
+ "stdout line 1\nstdout line 2\nstdout line 3\nstdout line 4",
+ )
+ self.assertEqual(
+ result["stderr"], "stderr line 1\nstderr line 2\nstderr line 3"
+ )
+ self.assertEqual(result["return_code"], 0)
# Verify stdout and stderr lines have been correctly stored in the db
- output_dbs = ActionExecutionOutput.query(output_type='stdout')
+ output_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(output_dbs), 4)
self.assertEqual(output_dbs[0].data, mock_stdout[0])
self.assertEqual(output_dbs[1].data, mock_stdout[1])
self.assertEqual(output_dbs[2].data, mock_stdout[2])
self.assertEqual(output_dbs[3].data, mock_stdout[3])
- output_dbs = ActionExecutionOutput.query(output_type='stderr')
+ output_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(output_dbs), 3)
self.assertEqual(output_dbs[0].data, mock_stderr[0])
self.assertEqual(output_dbs[1].data, mock_stderr[1])
@@ -585,30 +604,36 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop
def test_shell_script_action(self):
models = self.fixtures_loader.load_models(
- fixtures_pack='localrunner_pack', fixtures_dict={'actions': ['text_gen.yml']})
- action_db = models['actions']['text_gen.yml']
+ fixtures_pack="localrunner_pack",
+ fixtures_dict={"actions": ["text_gen.yml"]},
+ )
+ action_db = models["actions"]["text_gen.yml"]
entry_point = self.fixtures_loader.get_fixture_file_path_abs(
- 'localrunner_pack', 'actions', 'text_gen.py')
+ "localrunner_pack", "actions", "text_gen.py"
+ )
runner = self._get_runner(action_db, entry_point=entry_point)
runner.pre_run()
- status, result, _ = runner.run({'chars': 1000})
+ status, result, _ = runner.run({"chars": 1000})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(len(result['stdout']), 1000)
+ self.assertEqual(len(result["stdout"]), 1000)
def test_large_stdout(self):
models = self.fixtures_loader.load_models(
- fixtures_pack='localrunner_pack', fixtures_dict={'actions': ['text_gen.yml']})
- action_db = models['actions']['text_gen.yml']
+ fixtures_pack="localrunner_pack",
+ fixtures_dict={"actions": ["text_gen.yml"]},
+ )
+ action_db = models["actions"]["text_gen.yml"]
entry_point = self.fixtures_loader.get_fixture_file_path_abs(
- 'localrunner_pack', 'actions', 'text_gen.py')
+ "localrunner_pack", "actions", "text_gen.py"
+ )
runner = self._get_runner(action_db, entry_point=entry_point)
runner.pre_run()
char_count = 10 ** 6 # Note 10^7 succeeds but ends up being slow.
- status, result, _ = runner.run({'chars': char_count})
+ status, result, _ = runner.run({"chars": char_count})
runner.post_run(status, result)
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(len(result['stdout']), char_count)
+ self.assertEqual(len(result["stdout"]), char_count)
def _get_runner(self, action_db, entry_point):
runner = LocalShellScriptRunner(uuid.uuid4().hex)
@@ -622,5 +647,5 @@ def _get_runner(self, action_db, entry_point):
runner.callback = dict()
runner.libs_dir_path = None
runner.auth_token = mock.Mock()
- runner.auth_token.token = 'mock-token'
+ runner.auth_token.token = "mock-token"
return runner
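
Several tests above route mock_process.stdout.readline through make_mock_stream_readline so the runner's reader threads see a fixed set of lines followed by EOF. The helper itself lives in st2tests; a re-implementation consistent with how it is called here (mock_stream, mock_data, stop_counter, sleep_delay) might look like this illustrative sketch:

import time

def make_mock_stream_readline(mock_stream, mock_data, stop_counter=1, sleep_delay=0):
    mock_stream.counter = 0

    def mock_readline():
        if mock_stream.counter >= stop_counter:
            # Simulate EOF so the reader loop stops consuming the stream.
            mock_stream.closed = True
            return ""
        if sleep_delay:
            time.sleep(sleep_delay)
        line = mock_data[mock_stream.counter]
        mock_stream.counter += 1
        return line

    return mock_readline

class FakeStream:
    closed = False

stream = FakeStream()
stream.readline = make_mock_stream_readline(
    stream, ["line 1\n", "line 2\n"], stop_counter=2
)

while not stream.closed:
    print(stream.readline(), end="")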
diff --git a/contrib/runners/noop_runner/dist_utils.py b/contrib/runners/noop_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/noop_runner/dist_utils.py
+++ b/contrib/runners/noop_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
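
fetch_requirements() treats VCS requirement lines specially: it expects an #egg= fragment and recovers the requirement name from it, first trying the form followed by one of the characters in the [&|@] class and then falling back to a name that runs to the end of the line. A standalone sketch of that extraction (the requirement lines are made up):

import re

lines = [
    "git+https://github.com/org/repo.git#egg=my-package",
    "-e git+https://github.com/org/repo.git@v1.0#egg=other-package&subdirectory=src",
]

for line in lines:
    req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
    if not req_name:
        # No trailing fragment separator; take everything after #egg=.
        req_name = re.findall(".*#egg=(.+?)$", line)
    else:
        req_name = req_name[0]
    print(req_name[0])  # my-package, then other-package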
diff --git a/contrib/runners/noop_runner/noop_runner/__init__.py b/contrib/runners/noop_runner/noop_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/noop_runner/noop_runner/__init__.py
+++ b/contrib/runners/noop_runner/noop_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/noop_runner/noop_runner/noop_runner.py b/contrib/runners/noop_runner/noop_runner/noop_runner.py
index 0eb745218a..b4dda10fd5 100644
--- a/contrib/runners/noop_runner/noop_runner/noop_runner.py
+++ b/contrib/runners/noop_runner/noop_runner/noop_runner.py
@@ -22,12 +22,7 @@
from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED
import st2common.util.jsonify as jsonify
-__all__ = [
- 'NoopRunner',
-
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["NoopRunner", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
@@ -36,7 +31,8 @@ class NoopRunner(ActionRunner):
"""
Runner which does absolutely nothing. No-op action.
"""
- KEYS_TO_TRANSFORM = ['stdout', 'stderr']
+
+ KEYS_TO_TRANSFORM = ["stdout", "stderr"]
def __init__(self, runner_id):
super(NoopRunner, self).__init__(runner_id=runner_id)
@@ -46,14 +42,15 @@ def pre_run(self):
def run(self, action_parameters):
- LOG.info('Executing action via NoopRunner: %s', self.runner_id)
- LOG.info('[Action info] name: %s, Id: %s',
- self.action_name, str(self.execution_id))
+ LOG.info("Executing action via NoopRunner: %s", self.runner_id)
+ LOG.info(
+ "[Action info] name: %s, Id: %s", self.action_name, str(self.execution_id)
+ )
result = {
- 'failed': False,
- 'succeeded': True,
- 'return_code': 0,
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
}
status = LIVEACTION_STATUS_SUCCEEDED
@@ -65,4 +62,4 @@ def get_runner():
def get_metadata():
- return get_runner_metadata('noop_runner')[0]
+ return get_runner_metadata("noop_runner")[0]
diff --git a/contrib/runners/noop_runner/setup.py b/contrib/runners/noop_runner/setup.py
index 30b00bd68b..94b518c55f 100644
--- a/contrib/runners/noop_runner/setup.py
+++ b/contrib/runners/noop_runner/setup.py
@@ -26,30 +26,30 @@
from noop_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-noop',
+ name="stackstorm-runner-noop",
version=__version__,
- description=('No-Op action runner for StackStorm event-driven automation platform'),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description=("No-Op action runner for StackStorm event-driven automation platform"),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'noop_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"noop_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'noop = noop_runner.noop_runner',
+ "st2common.runners.runner": [
+ "noop = noop_runner.noop_runner",
],
- }
+ },
)
diff --git a/contrib/runners/noop_runner/tests/unit/test_nooprunner.py b/contrib/runners/noop_runner/tests/unit/test_nooprunner.py
index 6783404ffb..98c66c33cd 100644
--- a/contrib/runners/noop_runner/tests/unit/test_nooprunner.py
+++ b/contrib/runners/noop_runner/tests/unit/test_nooprunner.py
@@ -19,6 +19,7 @@
import mock
import st2tests.config as tests_config
+
tests_config.parse_args()
from unittest2 import TestCase
@@ -33,16 +34,17 @@ class TestNoopRunner(TestCase):
def test_noop_command_executes(self):
models = TestNoopRunner.fixtures_loader.load_models(
- fixtures_pack='generic', fixtures_dict={'actions': ['noop.yaml']})
+ fixtures_pack="generic", fixtures_dict={"actions": ["noop.yaml"]}
+ )
- action_db = models['actions']['noop.yaml']
+ action_db = models["actions"]["noop.yaml"]
runner = TestNoopRunner._get_runner(action_db)
status, result, _ = runner.run({})
self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(result['failed'], False)
- self.assertEqual(result['succeeded'], True)
- self.assertEqual(result['return_code'], 0)
+ self.assertEqual(result["failed"], False)
+ self.assertEqual(result["succeeded"], True)
+ self.assertEqual(result["return_code"], 0)
@staticmethod
def _get_runner(action_db):
@@ -55,5 +57,5 @@ def _get_runner(action_db):
runner.callback = dict()
runner.libs_dir_path = None
runner.auth_token = mock.Mock()
- runner.auth_token.token = 'mock-token'
+ runner.auth_token.token = "mock-token"
return runner
diff --git a/contrib/runners/orquesta_runner/dist_utils.py b/contrib/runners/orquesta_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/orquesta_runner/dist_utils.py
+++ b/contrib/runners/orquesta_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
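
get_version_string() (repeated above because dist_utils.py is vendored into each runner package) recovers the version with a multiline regex over the package's __init__.py. The same match works against an in-memory string:

import re

content = '# package init\n__version__ = "3.5dev"\n'

version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
    print(version_match.group(1))  # 3.5dev
else:
    raise RuntimeError("Unable to find version string.")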
diff --git a/contrib/runners/orquesta_runner/orquesta_functions/runtime.py b/contrib/runners/orquesta_runner/orquesta_functions/runtime.py
index f5986392d7..e71dcafd40 100644
--- a/contrib/runners/orquesta_runner/orquesta_functions/runtime.py
+++ b/contrib/runners/orquesta_runner/orquesta_functions/runtime.py
@@ -33,15 +33,15 @@ def format_task_result(instances):
instance = instances[-1]
return {
- 'task_execution_id': str(instance.id),
- 'workflow_execution_id': instance.workflow_execution,
- 'task_name': instance.task_id,
- 'task_id': instance.task_id,
- 'route': instance.task_route,
- 'result': instance.result,
- 'status': instance.status,
- 'start_timestamp': str(instance.start_timestamp),
- 'end_timestamp': str(instance.end_timestamp)
+ "task_execution_id": str(instance.id),
+ "workflow_execution_id": instance.workflow_execution,
+ "task_name": instance.task_id,
+ "task_id": instance.task_id,
+ "route": instance.task_route,
+ "result": instance.result,
+ "status": instance.status,
+ "start_timestamp": str(instance.start_timestamp),
+ "end_timestamp": str(instance.end_timestamp),
}
@@ -54,17 +54,17 @@ def task(context, task_id=None, route=None):
current_task = {}
if task_id is None:
- task_id = current_task['id']
+ task_id = current_task["id"]
if route is None:
- route = current_task.get('route', 0)
+ route = current_task.get("route", 0)
try:
- workflow_state = context['__state'] or {}
+ workflow_state = context["__state"] or {}
except KeyError:
workflow_state = {}
- task_state_pointers = workflow_state.get('tasks') or {}
+ task_state_pointers = workflow_state.get("tasks") or {}
task_state_entry_uid = constants.TASK_STATE_ROUTE_FORMAT % (task_id, str(route))
task_state_entry_idx = task_state_pointers.get(task_state_entry_uid)
@@ -72,9 +72,11 @@ def task(context, task_id=None, route=None):
# use an earlier route before the split to find the specific task.
if task_state_entry_idx is None:
if route > 0:
- current_route_details = workflow_state['routes'][route]
+ current_route_details = workflow_state["routes"][route]
# Reverse the list because we want to start with the next longest route.
- for idx, prev_route_details in enumerate(reversed(workflow_state['routes'][:route])):
+ for idx, prev_route_details in enumerate(
+ reversed(workflow_state["routes"][:route])
+ ):
if len(set(prev_route_details) - set(current_route_details)) == 0:
# The index is from a reversed list, so we need to calculate
# the index of the item in the list before the reverse.
@@ -83,17 +85,15 @@ def task(context, task_id=None, route=None):
else:
# Otherwise, get the task flow entry and use the
# task id and route to query the database.
- task_state_seqs = workflow_state.get('sequence') or []
+ task_state_seqs = workflow_state.get("sequence") or []
task_state_entry = task_state_seqs[task_state_entry_idx]
- route = task_state_entry['route']
- st2_ctx = context['__vars']['st2']
- workflow_execution_id = st2_ctx['workflow_execution_id']
+ route = task_state_entry["route"]
+ st2_ctx = context["__vars"]["st2"]
+ workflow_execution_id = st2_ctx["workflow_execution_id"]
# Query the database by the workflow execution ID, task ID, and task route.
instances = wf_db_access.TaskExecution.query(
- workflow_execution=workflow_execution_id,
- task_id=task_id,
- task_route=route
+ workflow_execution=workflow_execution_id, task_id=task_id, task_route=route
)
if not instances:
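
The route fallback in task() above walks earlier routes in reverse and picks the first one whose details are a subset of the current route's details, mapping the reversed index back onto the original list. A distilled sketch of that arithmetic (the per-route details are invented; real entries come from workflow_state["routes"]):

# Hypothetical route details, shaped like workflow_state["routes"].
routes = [
    [],                           # route 0: trunk before any split
    ["task1__t0"],                # route 1
    ["task1__t0", "task2__t0"],   # route 2: extends route 1
]

route = 2
current_route_details = routes[route]

for idx, prev_route_details in enumerate(reversed(routes[:route])):
    if len(set(prev_route_details) - set(current_route_details)) == 0:
        # The index came from a reversed list, so map it back to the
        # position in the original ordering.
        route = route - idx - 1
        break

print(route)  # 1, the closest ancestor route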
diff --git a/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py b/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py
index 35cae92cd7..ed23507a1b 100644
--- a/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py
+++ b/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py
@@ -29,26 +29,28 @@
def st2kv_(context, key, **kwargs):
if not isinstance(key, six.string_types):
- raise TypeError('Given key is not typeof string.')
+ raise TypeError("Given key is not typeof string.")
- decrypt = kwargs.get('decrypt', False)
+ decrypt = kwargs.get("decrypt", False)
if not isinstance(decrypt, bool):
- raise TypeError('Decrypt parameter is not typeof bool.')
+ raise TypeError("Decrypt parameter is not typeof bool.")
try:
- username = context['__vars']['st2']['user']
+ username = context["__vars"]["st2"]["user"]
except KeyError:
- raise KeyError('Could not get user from context.')
+ raise KeyError("Could not get user from context.")
try:
user_db = auth_db_access.User.get(username)
except Exception as e:
- raise Exception('Failed to retrieve User object for user "%s", "%s"' %
- (username, six.text_type(e)))
+ raise Exception(
+ 'Failed to retrieve User object for user "%s", "%s"'
+ % (username, six.text_type(e))
+ )
- has_default = 'default' in kwargs
- default_value = kwargs.get('default')
+ has_default = "default" in kwargs
+ default_value = kwargs.get("default")
try:
return kvp_util.get_key(key=key, user_db=user_db, decrypt=decrypt)
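
st2kv_() validates its inputs before touching the datastore: the key must be a string, the decrypt kwarg a bool, and the user comes from the workflow's st2 context. A runnable sketch of just those guard clauses (the context dict, key, and username are examples; the datastore lookup is omitted, and the original uses six.string_types for Python 2 compatibility):

def validate_st2kv_inputs(context, key, decrypt=False):
    # Error messages copied verbatim from st2kv_ above.
    if not isinstance(key, str):
        raise TypeError("Given key is not typeof string.")

    if not isinstance(decrypt, bool):
        raise TypeError("Decrypt parameter is not typeof bool.")

    try:
        username = context["__vars"]["st2"]["user"]
    except KeyError:
        raise KeyError("Could not get user from context.")

    return username, decrypt

ctx = {"__vars": {"st2": {"user": "stanley"}}}
print(validate_st2kv_inputs(ctx, "system.debug", decrypt=True))  # ('stanley', True)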
diff --git a/contrib/runners/orquesta_runner/orquesta_runner/__init__.py b/contrib/runners/orquesta_runner/orquesta_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/orquesta_runner/orquesta_runner/__init__.py
+++ b/contrib/runners/orquesta_runner/orquesta_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py b/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py
index b59642609e..62f2492ae4 100644
--- a/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py
+++ b/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py
@@ -37,71 +37,72 @@
from st2common.util import api as api_util
from st2common.util import ujson
-__all__ = [
- 'OrquestaRunner',
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["OrquestaRunner", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
class OrquestaRunner(runners.AsyncActionRunner):
-
@staticmethod
def get_workflow_definition(entry_point):
- with open(entry_point, 'r') as def_file:
+ with open(entry_point, "r") as def_file:
return def_file.read()
def _get_notify_config(self):
return (
- notify_api_models.NotificationsHelper.from_model(notify_model=self.liveaction.notify)
+ notify_api_models.NotificationsHelper.from_model(
+ notify_model=self.liveaction.notify
+ )
if self.liveaction.notify
else None
)
def _construct_context(self, wf_ex):
ctx = ujson.fast_deepcopy(self.context)
- ctx['workflow_execution'] = str(wf_ex.id)
+ ctx["workflow_execution"] = str(wf_ex.id)
return ctx
def _construct_st2_context(self):
st2_ctx = {
- 'st2': {
- 'action_execution_id': str(self.execution.id),
- 'api_url': api_util.get_full_public_api_url(),
- 'user': self.execution.context.get('user', cfg.CONF.system_user.user),
- 'pack': self.execution.context.get('pack', None),
- 'action': self.execution.action.get('ref', None),
- 'runner': self.execution.action.get('runner_type', None)
+ "st2": {
+ "action_execution_id": str(self.execution.id),
+ "api_url": api_util.get_full_public_api_url(),
+ "user": self.execution.context.get("user", cfg.CONF.system_user.user),
+ "pack": self.execution.context.get("pack", None),
+ "action": self.execution.action.get("ref", None),
+ "runner": self.execution.action.get("runner_type", None),
}
}
- if self.execution.context.get('api_user'):
- st2_ctx['st2']['api_user'] = self.execution.context.get('api_user')
+ if self.execution.context.get("api_user"):
+ st2_ctx["st2"]["api_user"] = self.execution.context.get("api_user")
- if self.execution.context.get('source_channel'):
- st2_ctx['st2']['source_channel'] = self.execution.context.get('source_channel')
+ if self.execution.context.get("source_channel"):
+ st2_ctx["st2"]["source_channel"] = self.execution.context.get(
+ "source_channel"
+ )
if self.execution.context:
- st2_ctx['parent'] = self.execution.context
+ st2_ctx["parent"] = self.execution.context
return st2_ctx
def _handle_workflow_return_value(self, wf_ex_db):
if wf_ex_db.status in wf_statuses.COMPLETED_STATUSES:
status = wf_ex_db.status
- result = {'output': wf_ex_db.output or None}
+ result = {"output": wf_ex_db.output or None}
if wf_ex_db.status in wf_statuses.ABENDED_STATUSES:
- result['errors'] = wf_ex_db.errors
+ result["errors"] = wf_ex_db.errors
for wf_ex_error in wf_ex_db.errors:
- msg = 'Workflow execution completed with errors.'
- wf_svc.update_progress(wf_ex_db, '%s %s' % (msg, str(wf_ex_error)), log=False)
- LOG.error('[%s] %s', str(self.execution.id), msg, extra=wf_ex_error)
+ msg = "Workflow execution completed with errors."
+ wf_svc.update_progress(
+ wf_ex_db, "%s %s" % (msg, str(wf_ex_error)), log=False
+ )
+ LOG.error("[%s] %s", str(self.execution.id), msg, extra=wf_ex_error)
return (status, result, self.context)
@@ -115,8 +116,8 @@ def _handle_workflow_return_value(self, wf_ex_db):
def run(self, action_parameters):
# If there is an action execution reference for rerun and a task is specified,
# then rerun the existing workflow execution.
- rerun_options = self.context.get('re-run', {})
- rerun_task_options = rerun_options.get('tasks', [])
+ rerun_options = self.context.get("re-run", {})
+ rerun_task_options = rerun_options.get("tasks", [])
if self.rerun_ex_ref and rerun_task_options:
return self.rerun_workflow(self.rerun_ex_ref, options=rerun_options)
@@ -131,14 +132,16 @@ def start_workflow(self, action_parameters):
# Request workflow execution.
st2_ctx = self._construct_st2_context()
notify_cfg = self._get_notify_config()
- wf_ex_db = wf_svc.request(wf_def, self.execution, st2_ctx, notify_cfg=notify_cfg)
+ wf_ex_db = wf_svc.request(
+ wf_def, self.execution, st2_ctx, notify_cfg=notify_cfg
+ )
except wf_exc.WorkflowInspectionError as e:
status = ac_const.LIVEACTION_STATUS_FAILED
- result = {'errors': e.args[1], 'output': None}
+ result = {"errors": e.args[1], "output": None}
return (status, result, self.context)
except Exception as e:
status = ac_const.LIVEACTION_STATUS_FAILED
- result = {'errors': [{'message': six.text_type(e)}], 'output': None}
+ result = {"errors": [{"message": six.text_type(e)}], "output": None}
return (status, result, self.context)
return self._handle_workflow_return_value(wf_ex_db)
@@ -146,13 +149,13 @@ def start_workflow(self, action_parameters):
def rerun_workflow(self, ac_ex_ref, options=None):
try:
# Request rerun of workflow execution.
- wf_ex_id = ac_ex_ref.context.get('workflow_execution')
+ wf_ex_id = ac_ex_ref.context.get("workflow_execution")
st2_ctx = self._construct_st2_context()
- st2_ctx['workflow_execution_id'] = wf_ex_id
+ st2_ctx["workflow_execution_id"] = wf_ex_id
wf_ex_db = wf_svc.request_rerun(self.execution, st2_ctx, options=options)
except Exception as e:
status = ac_const.LIVEACTION_STATUS_FAILED
- result = {'errors': [{'message': six.text_type(e)}], 'output': None}
+ result = {"errors": [{"message": six.text_type(e)}], "output": None}
return (status, result, self.context)
return self._handle_workflow_return_value(wf_ex_db)
@@ -160,8 +163,8 @@ def rerun_workflow(self, ac_ex_ref, options=None):
@staticmethod
def task_pauseable(ac_ex):
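+        # A child execution is pauseable only if it is itself a workflow and is currently running.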
wf_ex_pauseable = (
- ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and
- ac_ex.status == ac_const.LIVEACTION_STATUS_RUNNING
+ ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES
+ and ac_ex.status == ac_const.LIVEACTION_STATUS_RUNNING
)
return wf_ex_pauseable
@@ -175,26 +178,24 @@ def pause(self):
child_ex = ex_db_access.ActionExecution.get(id=child_ex_id)
if self.task_pauseable(child_ex):
ac_svc.request_pause(
- lv_db_access.LiveAction.get(id=child_ex.liveaction['id']),
- self.context.get('user', None)
+ lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]),
+ self.context.get("user", None),
)
- if wf_ex_db.status == wf_statuses.PAUSING or ac_svc.is_children_active(self.liveaction.id):
+ if wf_ex_db.status == wf_statuses.PAUSING or ac_svc.is_children_active(
+ self.liveaction.id
+ ):
status = ac_const.LIVEACTION_STATUS_PAUSING
else:
status = ac_const.LIVEACTION_STATUS_PAUSED
- return (
- status,
- self.liveaction.result,
- self.liveaction.context
- )
+ return (status, self.liveaction.result, self.liveaction.context)
@staticmethod
def task_resumeable(ac_ex):
wf_ex_resumeable = (
- ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and
- ac_ex.status == ac_const.LIVEACTION_STATUS_PAUSED
+ ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES
+ and ac_ex.status == ac_const.LIVEACTION_STATUS_PAUSED
)
return wf_ex_resumeable
@@ -208,26 +209,26 @@ def resume(self):
child_ex = ex_db_access.ActionExecution.get(id=child_ex_id)
if self.task_resumeable(child_ex):
ac_svc.request_resume(
- lv_db_access.LiveAction.get(id=child_ex.liveaction['id']),
- self.context.get('user', None)
+ lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]),
+ self.context.get("user", None),
)
return (
wf_ex_db.status if wf_ex_db else ac_const.LIVEACTION_STATUS_RUNNING,
self.liveaction.result,
- self.liveaction.context
+ self.liveaction.context,
)
@staticmethod
def task_cancelable(ac_ex):
wf_ex_cancelable = (
- ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and
- ac_ex.status in ac_const.LIVEACTION_CANCELABLE_STATES
+ ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES
+ and ac_ex.status in ac_const.LIVEACTION_CANCELABLE_STATES
)
ac_ex_cancelable = (
- ac_ex.runner['name'] not in ac_const.WORKFLOW_RUNNER_TYPES and
- ac_ex.status in ac_const.LIVEACTION_DELAYED_STATES
+ ac_ex.runner["name"] not in ac_const.WORKFLOW_RUNNER_TYPES
+ and ac_ex.status in ac_const.LIVEACTION_DELAYED_STATES
)
return wf_ex_cancelable or ac_ex_cancelable
@@ -242,8 +243,10 @@ def cancel(self):
# If the workflow execution is not found because the action execution is canceled
# before the workflow execution is created or if the workflow execution is
# already completed, then ignore the exception and proceed with cancellation.
- except (wf_svc_exc.WorkflowExecutionNotFoundException,
- wf_svc_exc.WorkflowExecutionIsCompletedException):
+ except (
+ wf_svc_exc.WorkflowExecutionNotFoundException,
+ wf_svc_exc.WorkflowExecutionIsCompletedException,
+ ):
pass
# If there is an unknown exception, then log the error. Continue with the
# cancellation sequence below to cancel children and determine the final status.
@@ -253,19 +256,22 @@ def cancel(self):
# execution will be in an unknown state.
except Exception:
_, ex, tb = sys.exc_info()
- msg = 'Error encountered when canceling workflow execution.'
- LOG.exception('[%s] %s', str(self.execution.id), msg)
- msg = 'Error encountered when canceling workflow execution. %s'
+ msg = "Error encountered when canceling workflow execution."
+ LOG.exception("[%s] %s", str(self.execution.id), msg)
+ msg = "Error encountered when canceling workflow execution. %s"
wf_svc.update_progress(wf_ex_db, msg % str(ex), log=False)
- result = {'error': msg % str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))}
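+            # Surface the error message and a traceback truncated to 20 frames in the action result.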
+ result = {
+ "error": msg % str(ex),
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ }
# Request cancellation of tasks that are workflows and still running.
for child_ex_id in self.execution.children:
child_ex = ex_db_access.ActionExecution.get(id=child_ex_id)
if self.task_cancelable(child_ex):
ac_svc.request_cancellation(
- lv_db_access.LiveAction.get(id=child_ex.liveaction['id']),
- self.context.get('user', None)
+ lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]),
+ self.context.get("user", None),
)
status = (
@@ -277,7 +283,7 @@ def cancel(self):
return (
status,
result if result else self.liveaction.result,
- self.liveaction.context
+ self.liveaction.context,
)
@@ -286,4 +292,4 @@ def get_runner():
def get_metadata():
- return runners.get_metadata('orquesta_runner')[0]
+ return runners.get_metadata("orquesta_runner")[0]
diff --git a/contrib/runners/orquesta_runner/setup.py b/contrib/runners/orquesta_runner/setup.py
index 5dac5ed34e..859a8b6050 100644
--- a/contrib/runners/orquesta_runner/setup.py
+++ b/contrib/runners/orquesta_runner/setup.py
@@ -26,62 +26,64 @@
from orquesta_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
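+# Parse requirements.txt into install_requires entries and dependency links for setup().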
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-orquesta',
+ name="stackstorm-runner-orquesta",
version=__version__,
- description='Orquesta workflow runner for StackStorm event-driven automation platform',
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="Orquesta workflow runner for StackStorm event-driven automation platform",
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'orquesta_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"orquesta_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'orquesta = orquesta_runner.orquesta_runner',
+ "st2common.runners.runner": [
+ "orquesta = orquesta_runner.orquesta_runner",
],
- 'orquesta.expressions.functions': [
- 'st2kv = orquesta_functions.st2kv:st2kv_',
- 'task = orquesta_functions.runtime:task',
- 'basename = st2common.expressions.functions.path:basename',
- 'dirname = st2common.expressions.functions.path:dirname',
- 'from_json_string = st2common.expressions.functions.data:from_json_string',
- 'from_yaml_string = st2common.expressions.functions.data:from_yaml_string',
- 'json_dump = st2common.expressions.functions.data:to_json_string',
- 'json_parse = st2common.expressions.functions.data:from_json_string',
- 'json_escape = st2common.expressions.functions.data:json_escape',
- 'jsonpath_query = st2common.expressions.functions.data:jsonpath_query',
- 'regex_match = st2common.expressions.functions.regex:regex_match',
- 'regex_replace = st2common.expressions.functions.regex:regex_replace',
- 'regex_search = st2common.expressions.functions.regex:regex_search',
- 'regex_substring = st2common.expressions.functions.regex:regex_substring',
- ('to_human_time_from_seconds = '
- 'st2common.expressions.functions.time:to_human_time_from_seconds'),
- 'to_json_string = st2common.expressions.functions.data:to_json_string',
- 'to_yaml_string = st2common.expressions.functions.data:to_yaml_string',
- 'use_none = st2common.expressions.functions.data:use_none',
- 'version_compare = st2common.expressions.functions.version:version_compare',
- 'version_more_than = st2common.expressions.functions.version:version_more_than',
- 'version_less_than = st2common.expressions.functions.version:version_less_than',
- 'version_equal = st2common.expressions.functions.version:version_equal',
- 'version_match = st2common.expressions.functions.version:version_match',
- 'version_bump_major = st2common.expressions.functions.version:version_bump_major',
- 'version_bump_minor = st2common.expressions.functions.version:version_bump_minor',
- 'version_bump_patch = st2common.expressions.functions.version:version_bump_patch',
- 'version_strip_patch = st2common.expressions.functions.version:version_strip_patch',
- 'yaml_dump = st2common.expressions.functions.data:to_yaml_string',
- 'yaml_parse = st2common.expressions.functions.data:from_yaml_string'
+ "orquesta.expressions.functions": [
+ "st2kv = orquesta_functions.st2kv:st2kv_",
+ "task = orquesta_functions.runtime:task",
+ "basename = st2common.expressions.functions.path:basename",
+ "dirname = st2common.expressions.functions.path:dirname",
+ "from_json_string = st2common.expressions.functions.data:from_json_string",
+ "from_yaml_string = st2common.expressions.functions.data:from_yaml_string",
+ "json_dump = st2common.expressions.functions.data:to_json_string",
+ "json_parse = st2common.expressions.functions.data:from_json_string",
+ "json_escape = st2common.expressions.functions.data:json_escape",
+ "jsonpath_query = st2common.expressions.functions.data:jsonpath_query",
+ "regex_match = st2common.expressions.functions.regex:regex_match",
+ "regex_replace = st2common.expressions.functions.regex:regex_replace",
+ "regex_search = st2common.expressions.functions.regex:regex_search",
+ "regex_substring = st2common.expressions.functions.regex:regex_substring",
+ (
+ "to_human_time_from_seconds = "
+ "st2common.expressions.functions.time:to_human_time_from_seconds"
+ ),
+ "to_json_string = st2common.expressions.functions.data:to_json_string",
+ "to_yaml_string = st2common.expressions.functions.data:to_yaml_string",
+ "use_none = st2common.expressions.functions.data:use_none",
+ "version_compare = st2common.expressions.functions.version:version_compare",
+ "version_more_than = st2common.expressions.functions.version:version_more_than",
+ "version_less_than = st2common.expressions.functions.version:version_less_than",
+ "version_equal = st2common.expressions.functions.version:version_equal",
+ "version_match = st2common.expressions.functions.version:version_match",
+ "version_bump_major = st2common.expressions.functions.version:version_bump_major",
+ "version_bump_minor = st2common.expressions.functions.version:version_bump_minor",
+ "version_bump_patch = st2common.expressions.functions.version:version_bump_patch",
+ "version_strip_patch = st2common.expressions.functions.version:version_strip_patch",
+ "yaml_dump = st2common.expressions.functions.data:to_yaml_string",
+ "yaml_parse = st2common.expressions.functions.data:from_yaml_string",
],
- }
+ },
)
diff --git a/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py b/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py
index 8734bce072..0e273f6e83 100644
--- a/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py
+++ b/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py
@@ -21,78 +21,67 @@
class DatastoreFunctionTest(base.TestWorkflowExecution):
@classmethod
- def set_kvp(cls, name, value, scope='system', secret=False):
+ def set_kvp(cls, name, value, scope="system", secret=False):
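+        # Create or update a key-value pair in the datastore under the given scope.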
kvp = models.KeyValuePair(
- id=name,
- name=name,
- value=value,
- scope=scope,
- secret=secret
+ id=name, name=name, value=value, scope=scope, secret=secret
)
cls.st2client.keys.update(kvp)
@classmethod
- def del_kvp(cls, name, scope='system'):
- kvp = models.KeyValuePair(
- id=name,
- name=name,
- scope=scope
- )
+ def del_kvp(cls, name, scope="system"):
+ kvp = models.KeyValuePair(id=name, name=name, scope=scope)
cls.st2client.keys.delete(kvp)
def test_st2kv_system_scope(self):
- key = 'lakshmi'
- value = 'kanahansnasnasdlsajks'
+ key = "lakshmi"
+ value = "kanahansnasnasdlsajks"
self.set_kvp(key, value)
- wf_name = 'examples.orquesta-st2kv'
- wf_input = {'key_name': 'system.%s' % key}
+ wf_name = "examples.orquesta-st2kv"
+ wf_input = {"key_name": "system.%s" % key}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value', output.result['output'])
- self.assertEqual(value, output.result['output']['value'])
+ self.assertIn("output", output.result)
+ self.assertIn("value", output.result["output"])
+ self.assertEqual(value, output.result["output"]["value"])
self.del_kvp(key)
def test_st2kv_user_scope(self):
- key = 'winson'
- value = 'SoDiamondEng'
+ key = "winson"
+ value = "SoDiamondEng"
- self.set_kvp(key, value, 'user')
- wf_name = 'examples.orquesta-st2kv'
- wf_input = {'key_name': key}
+ self.set_kvp(key, value, "user")
+ wf_name = "examples.orquesta-st2kv"
+ wf_input = {"key_name": key}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value', output.result['output'])
- self.assertEqual(value, output.result['output']['value'])
+ self.assertIn("output", output.result)
+ self.assertIn("value", output.result["output"])
+ self.assertEqual(value, output.result["output"]["value"])
# self.del_kvp(key)
def test_st2kv_decrypt(self):
- key = 'kami'
- value = 'eggplant'
+ key = "kami"
+ value = "eggplant"
self.set_kvp(key, value, secret=True)
- wf_name = 'examples.orquesta-st2kv'
- wf_input = {
- 'key_name': 'system.%s' % key,
- 'decrypt': True
- }
+ wf_name = "examples.orquesta-st2kv"
+ wf_input = {"key_name": "system.%s" % key, "decrypt": True}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value', output.result['output'])
- self.assertEqual(value, output.result['output']['value'])
+ self.assertIn("output", output.result)
+ self.assertIn("value", output.result["output"])
+ self.assertEqual(value, output.result["output"]["value"])
self.del_kvp(key)
diff --git a/contrib/runners/orquesta_runner/tests/unit/base.py b/contrib/runners/orquesta_runner/tests/unit/base.py
index dbd2895721..d3e518fab7 100644
--- a/contrib/runners/orquesta_runner/tests/unit/base.py
+++ b/contrib/runners/orquesta_runner/tests/unit/base.py
@@ -19,13 +19,13 @@
def get_wf_fixture_meta_data(fixture_pack_path, wf_meta_file_name):
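+    # Load the workflow metadata file from the fixture pack and derive the fully qualified action name.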
- wf_meta_file_path = fixture_pack_path + '/actions/' + wf_meta_file_name
+ wf_meta_file_path = fixture_pack_path + "/actions/" + wf_meta_file_name
wf_meta_content = loader.load_meta_file(wf_meta_file_path)
- wf_name = wf_meta_content['pack'] + '.' + wf_meta_content['name']
+ wf_name = wf_meta_content["pack"] + "." + wf_meta_content["name"]
return {
- 'file_name': wf_meta_file_name,
- 'file_path': wf_meta_file_path,
- 'content': wf_meta_content,
- 'name': wf_name
+ "file_name": wf_meta_file_name,
+ "file_path": wf_meta_file_path,
+ "content": wf_meta_content,
+ "name": wf_name,
}
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_basic.py b/contrib/runners/orquesta_runner/tests/unit/test_basic.py
index 7fc2255ed2..5f5c60a012 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_basic.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_basic.py
@@ -26,6 +26,7 @@
# XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -51,37 +52,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaRunnerTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaRunnerTest, cls).setUpClass()
@@ -91,8 +100,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -103,14 +111,15 @@ def get_runner_class(cls, runner_name):
return runners.get_runner(runner_name, runner_name).__class__
@mock.patch.object(
- runners_utils,
- 'invoke_post_run',
- mock.MagicMock(return_value=None))
+ runners_utils, "invoke_post_run", mock.MagicMock(return_value=None)
+ )
def test_run_workflow(self):
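+        # Execute the sequential fixture workflow end to end and verify its context, task states, and output.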
- username = 'stanley'
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ username = "stanley"
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "Thanos"}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# The main action execution for this workflow is not under the context of another workflow.
@@ -120,9 +129,13 @@ def test_run_workflow(self):
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertTrue(lv_ac_db.action_is_workflow)
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
wf_ex_db = wf_ex_dbs[0]
# Check required attributes.
@@ -134,26 +147,24 @@ def test_run_workflow(self):
# Check context in the workflow execution.
expected_wf_ex_ctx = {
- 'st2': {
- 'workflow_execution_id': str(wf_ex_db.id),
- 'action_execution_id': str(ac_ex_db.id),
- 'api_url': 'http://127.0.0.1/v1',
- 'user': username,
- 'pack': 'orquesta_tests',
- 'action': 'orquesta_tests.sequential',
- 'runner': 'orquesta'
+ "st2": {
+ "workflow_execution_id": str(wf_ex_db.id),
+ "action_execution_id": str(ac_ex_db.id),
+ "api_url": "http://127.0.0.1/v1",
+ "user": username,
+ "pack": "orquesta_tests",
+ "action": "orquesta_tests.sequential",
+ "runner": "orquesta",
},
- 'parent': {
- 'pack': 'orquesta_tests'
- }
+ "parent": {"pack": "orquesta_tests"},
}
self.assertDictEqual(wf_ex_db.context, expected_wf_ex_ctx)
# Check context in the liveaction.
expected_lv_ac_ctx = {
- 'workflow_execution': str(wf_ex_db.id),
- 'pack': 'orquesta_tests'
+ "workflow_execution": str(wf_ex_db.id),
+ "pack": "orquesta_tests",
}
self.assertDictEqual(lv_ac_db.context, expected_lv_ac_ctx)
@@ -161,24 +172,26 @@ def test_run_workflow(self):
# Check graph.
self.assertIsNotNone(wf_ex_db.graph)
self.assertIsInstance(wf_ex_db.graph, dict)
- self.assertIn('nodes', wf_ex_db.graph)
- self.assertIn('adjacency', wf_ex_db.graph)
+ self.assertIn("nodes", wf_ex_db.graph)
+ self.assertIn("adjacency", wf_ex_db.graph)
# Check task states.
self.assertIsNotNone(wf_ex_db.state)
self.assertIsInstance(wf_ex_db.state, dict)
- self.assertIn('tasks', wf_ex_db.state)
- self.assertIn('sequence', wf_ex_db.state)
+ self.assertIn("tasks", wf_ex_db.state)
+ self.assertIn("sequence", wf_ex_db.state)
# Check input.
self.assertDictEqual(wf_ex_db.input, wf_input)
# Assert task1 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
- self.assertEqual(tk1_lv_ac_db.context.get('user'), username)
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
+ self.assertEqual(tk1_lv_ac_db.context.get("user"), username)
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db))
@@ -192,11 +205,13 @@ def test_run_workflow(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert task2 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
- self.assertEqual(tk2_lv_ac_db.context.get('user'), username)
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
+ self.assertEqual(tk2_lv_ac_db.context.get("user"), username)
self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk2_ac_ex_db))
@@ -210,11 +225,13 @@ def test_run_workflow(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert task3 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
- tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id'])
- self.assertEqual(tk3_lv_ac_db.context.get('user'), username)
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
+ tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"])
+ self.assertEqual(tk3_lv_ac_db.context.get("user"), username)
self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk3_ac_ex_db))
@@ -234,48 +251,60 @@ def test_run_workflow(self):
self.assertEqual(runners_utils.invoke_post_run.call_count, 1)
# Check workflow output.
- expected_output = {'msg': '%s, All your base are belong to us!' % wf_input['who']}
+ expected_output = {
+ "msg": "%s, All your base are belong to us!" % wf_input["who"]
+ }
self.assertDictEqual(wf_ex_db.output, expected_output)
# Check liveaction and action execution result.
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
self.assertDictEqual(lv_ac_db.result, expected_result)
self.assertDictEqual(ac_ex_db.result, expected_result)
def test_run_workflow_with_unicode_input(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': '薩諾斯'}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "薩諾斯"}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED)
# Process task2.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(tk2_ac_ex_db)
tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id)
self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED)
# Process task3.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
- tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id'])
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
+ tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"])
self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(tk3_ac_ex_db)
tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id)
@@ -290,33 +319,41 @@ def test_run_workflow_with_unicode_input(self):
self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Check workflow output.
- wf_input_val = wf_input['who'].decode('utf-8') if six.PY2 else wf_input['who']
- expected_output = {'msg': '%s, All your base are belong to us!' % wf_input_val}
+ wf_input_val = wf_input["who"].decode("utf-8") if six.PY2 else wf_input["who"]
+ expected_output = {"msg": "%s, All your base are belong to us!" % wf_input_val}
self.assertDictEqual(wf_ex_db.output, expected_output)
# Check liveaction and action execution result.
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
self.assertDictEqual(lv_ac_db.result, expected_result)
self.assertDictEqual(ac_ex_db.result, expected_result)
def test_run_workflow_action_config_context(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'config-context.yaml')
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "config-context.yaml")
wf_input = {}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert task1 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db))
@@ -332,59 +369,77 @@ def test_run_workflow_action_config_context(self):
self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Verify that config_context works.
- self.assertEqual(wf_ex_db.output, {'msg': 'value of config key a'})
+ self.assertEqual(wf_ex_db.output, {"msg": "value of config key a"})
def test_run_workflow_with_action_less_tasks(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'action-less-tasks.yaml')
- wf_input = {'name': 'Thanos'}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "action-less-tasks.yaml"
+ )
+ wf_input = {"name": "Thanos"}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert task1 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))
+ tk1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )
self.assertEqual(len(tk1_ac_ex_dbs), 0)
self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED)
# Assert task2 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion.
wf_svc.handle_action_execution_completion(tk2_ac_ex_db)
# Assert task3 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
- tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id'])
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
+ tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"])
self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion.
wf_svc.handle_action_execution_completion(tk3_ac_ex_db)
# Assert task4 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task4'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task4"}
tk4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk4_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk4_ex_db.id))
+ tk4_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk4_ex_db.id)
+ )
self.assertEqual(len(tk4_ac_ex_dbs), 0)
self.assertEqual(tk4_ex_db.status, wf_statuses.SUCCEEDED)
# Assert task5 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task5'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task5"}
tk5_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk5_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk5_ex_db.id))[0]
- tk5_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk5_ac_ex_db.liveaction['id'])
+ tk5_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk5_ex_db.id)
+ )[0]
+ tk5_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk5_ac_ex_db.liveaction["id"])
self.assertEqual(tk5_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion.
@@ -399,65 +454,95 @@ def test_run_workflow_with_action_less_tasks(self):
self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Check workflow output.
- expected_output = {'greeting': '%s, All your base are belong to us!' % wf_input['name']}
- expected_output['greeting'] = expected_output['greeting'].upper()
+ expected_output = {
+ "greeting": "%s, All your base are belong to us!" % wf_input["name"]
+ }
+ expected_output["greeting"] = expected_output["greeting"].upper()
self.assertDictEqual(wf_ex_db.output, expected_output)
# Check liveaction and action execution result.
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
self.assertDictEqual(lv_ac_db.result, expected_result)
self.assertDictEqual(ac_ex_db.result, expected_result)
@mock.patch.object(
- pc_svc, 'apply_post_run_policies',
- mock.MagicMock(return_value=None))
+ pc_svc, "apply_post_run_policies", mock.MagicMock(return_value=None)
+ )
def test_handle_action_execution_completion(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
# Identify the records for the tasks.
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0]
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )[0]
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
# Manually notify action execution completion for the tasks.
# Assert policies are not applied in the notifier.
- t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
notifier.get_notifier().process(t1_t1_ac_ex_db)
self.assertFalse(pc_svc.apply_post_run_policies.called)
- t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))
+ t1_tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )
self.assertEqual(len(t1_tk_ex_dbs), 1)
workflows.get_engine().process(t1_t1_ac_ex_db)
self.assertTrue(pc_svc.apply_post_run_policies.called)
pc_svc.apply_post_run_policies.reset_mock()
- t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1]
- t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0]
+ t1_t2_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[1]
+ t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t2_ex_db.id)
+ )[0]
notifier.get_notifier().process(t1_t2_ac_ex_db)
self.assertFalse(pc_svc.apply_post_run_policies.called)
- t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))
+ t1_tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )
self.assertEqual(len(t1_tk_ex_dbs), 2)
workflows.get_engine().process(t1_t2_ac_ex_db)
self.assertTrue(pc_svc.apply_post_run_policies.called)
pc_svc.apply_post_run_policies.reset_mock()
- t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2]
- t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0]
+ t1_t3_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[2]
+ t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t3_ex_db.id)
+ )[0]
notifier.get_notifier().process(t1_t3_ac_ex_db)
self.assertFalse(pc_svc.apply_post_run_policies.called)
- t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))
+ t1_tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )
self.assertEqual(len(t1_tk_ex_dbs), 3)
workflows.get_engine().process(t1_t3_ac_ex_db)
self.assertTrue(pc_svc.apply_post_run_policies.called)
@@ -466,19 +551,25 @@ def test_handle_action_execution_completion(self):
t1_ac_ex_db = ex_db_access.ActionExecution.get_by_id(t1_ac_ex_db.id)
notifier.get_notifier().process(t1_ac_ex_db)
self.assertFalse(pc_svc.apply_post_run_policies.called)
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
workflows.get_engine().process(t1_ac_ex_db)
self.assertTrue(pc_svc.apply_post_run_policies.called)
pc_svc.apply_post_run_policies.reset_mock()
- t2_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ t2_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
t2_ex_db = wf_db_access.TaskExecution.query(**t2_ex_db_qry)[0]
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0]
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_ex_db.id)
+ )[0]
self.assertEqual(t2_ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
notifier.get_notifier().process(t2_ac_ex_db)
self.assertFalse(pc_svc.apply_post_run_policies.called)
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 2)
workflows.get_engine().process(t2_ac_ex_db)
self.assertTrue(pc_svc.apply_post_run_policies.called)
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_cancel.py b/contrib/runners/orquesta_runner/tests/unit/test_cancel.py
index 145bd1f3b4..b49fd0f77b 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_cancel.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_cancel.py
@@ -24,6 +24,7 @@
# XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -45,37 +46,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaRunnerCancelTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaRunnerCancelTest, cls).setUpClass()
@@ -85,8 +94,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -96,15 +104,15 @@ def setUpClass(cls):
def get_runner_class(cls, runner_name):
return runners.get_runner(runner_name, runner_name).__class__
- @mock.patch.object(
- ac_svc, 'is_children_active',
- mock.MagicMock(return_value=True))
+ @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=True))
def test_cancel(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
requester = cfg.CONF.system_user.user
lv_ac_db, ac_ex_db = ac_svc.request_cancellation(lv_ac_db, requester)
@@ -112,23 +120,33 @@ def test_cancel(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELING)
def test_cancel_workflow_cascade_down_to_subworkflow(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the subworkflow.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(len(wf_ex_dbs), 1)
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
- tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))
+ tk_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ac_ex_dbs), 1)
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id'])
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ tk_ac_ex_dbs[0].liveaction["id"]
+ )
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Cancel the main workflow.
@@ -145,23 +163,33 @@ def test_cancel_workflow_cascade_down_to_subworkflow(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED)
def test_cancel_subworkflow_cascade_up_to_workflow(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the subworkflow.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(len(wf_ex_dbs), 1)
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
- tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))
+ tk_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ac_ex_dbs), 1)
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id'])
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ tk_ac_ex_dbs[0].liveaction["id"]
+ )
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Cancel the subworkflow.
@@ -183,34 +211,50 @@ def test_cancel_subworkflow_cascade_up_to_workflow(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED)
def test_cancel_subworkflow_cascade_up_to_workflow_with_other_subworkflows(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the subworkflow.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(len(wf_ex_dbs), 1)
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ex_dbs), 2)
- tk1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))
+ tk1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )
self.assertEqual(len(tk1_ac_ex_dbs), 1)
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_dbs[0].liveaction['id'])
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ tk1_ac_ex_dbs[0].liveaction["id"]
+ )
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
- tk2_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))
+ tk2_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[1].id)
+ )
self.assertEqual(len(tk2_ac_ex_dbs), 1)
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_dbs[0].liveaction['id'])
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ tk2_ac_ex_dbs[0].liveaction["id"]
+ )
self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Cancel the subworkflow, which should cascade up to the root.
requester = cfg.CONF.system_user.user
- tk1_lv_ac_db, tk1_ac_ex_db = ac_svc.request_cancellation(tk1_lv_ac_db, requester)
+ tk1_lv_ac_db, tk1_ac_ex_db = ac_svc.request_cancellation(
+ tk1_lv_ac_db, requester
+ )
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELING)
# Assert the main workflow is canceling.
@@ -239,15 +283,21 @@ def test_cancel_subworkflow_cascade_up_to_workflow_with_other_subworkflows(self)
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED)
def test_cancel_before_wf_ex_db_created(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Delete the workflow execution to mock the issue where the record has not been created yet.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False, dispatch_trigger=False)
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ wf_db_access.WorkflowExecution.delete(
+ wf_ex_db, publish=False, dispatch_trigger=False
+ )
# Cancel the action execution.
requester = cfg.CONF.system_user.user
@@ -256,15 +306,19 @@ def test_cancel_before_wf_ex_db_created(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED)
def test_cancel_after_wf_ex_db_completed(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Delete the workflow execution to mock the issue where the workflow is already completed
# but the liveaction and action execution have not had time to be updated.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
wf_ex_db.status = wf_ex_statuses.SUCCEEDED
wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False)
@@ -275,14 +329,16 @@ def test_cancel_after_wf_ex_db_completed(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED)
@mock.patch.object(
- wf_svc, 'request_cancellation',
- mock.MagicMock(side_effect=Exception('foobar')))
+ wf_svc, "request_cancellation", mock.MagicMock(side_effect=Exception("foobar"))
+ )
def test_cancel_unexpected_exception(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Cancel the action execution.
requester = cfg.CONF.system_user.user
@@ -297,4 +353,6 @@ def test_cancel_unexpected_exception(self):
# to raise an exception and the records will be stuck in a canceling
    # status and the user is unable to easily clean up.
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED)
- self.assertIn('Error encountered when canceling', lv_ac_db.result.get('error', ''))
+ self.assertIn(
+ "Error encountered when canceling", lv_ac_db.result.get("error", "")
+ )
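
Most hunks in these test files follow one mechanical pattern: black explodes any call that overflows its default 88-character line length into one argument per line. A minimal sketch reproducing the wrapping seen above, assuming a black release new enough to expose `black.Mode` (roughly 20.8b0+) and the default line length:

```python
# Sketch: reproduce black's call-wrapping rule on a line from this diff.
# Requires the "black" package; format_str only parses the source, so the
# undefined names (wf_db_access, ac_ex_db) are fine here.
import black

SRC = '''\
class T:
    def test(self):
        wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
'''

print(black.format_str(SRC, mode=black.Mode()), end="")
# Output matches the hunk above:
# class T:
#     def test(self):
#         wf_ex_db = wf_db_access.WorkflowExecution.query(
#             action_execution=str(ac_ex_db.id)
#         )[0]
```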
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_context.py b/contrib/runners/orquesta_runner/tests/unit/test_context.py
index 373f512e87..bce5a50873 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_context.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_context.py
@@ -24,6 +24,7 @@
 # XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -43,37 +44,45 @@
from st2tests.mocks import liveaction as mock_lv_ac_xport
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaContextTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaContextTest, cls).setUpClass()
@@ -83,24 +92,31 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
def test_runtime_context(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'runtime-context.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "runtime-context.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
    # Complete the workflow.
wf_svc.handle_action_execution_completion(t1_ac_ex_db)
@@ -113,59 +129,75 @@ def test_runtime_context(self):
# Check result.
expected_st2_ctx = {
- 'action_execution_id': str(ac_ex_db.id),
- 'api_url': 'http://127.0.0.1/v1',
- 'user': 'stanley',
- 'pack': 'orquesta_tests',
- 'action': 'orquesta_tests.runtime-context',
- 'runner': 'orquesta'
+ "action_execution_id": str(ac_ex_db.id),
+ "api_url": "http://127.0.0.1/v1",
+ "user": "stanley",
+ "pack": "orquesta_tests",
+ "action": "orquesta_tests.runtime-context",
+ "runner": "orquesta",
}
expected_st2_ctx_with_wf_ex_id = copy.deepcopy(expected_st2_ctx)
- expected_st2_ctx_with_wf_ex_id['workflow_execution_id'] = str(wf_ex_db.id)
+ expected_st2_ctx_with_wf_ex_id["workflow_execution_id"] = str(wf_ex_db.id)
expected_output = {
- 'st2_ctx_at_input': expected_st2_ctx,
- 'st2_ctx_at_vars': expected_st2_ctx,
- 'st2_ctx_at_publish': expected_st2_ctx_with_wf_ex_id,
- 'st2_ctx_at_output': expected_st2_ctx_with_wf_ex_id
+ "st2_ctx_at_input": expected_st2_ctx,
+ "st2_ctx_at_vars": expected_st2_ctx,
+ "st2_ctx_at_publish": expected_st2_ctx_with_wf_ex_id,
+ "st2_ctx_at_output": expected_st2_ctx_with_wf_ex_id,
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
self.assertDictEqual(lv_ac_db.result, expected_result)
def test_action_context_sys_user(self):
- wf_name = 'subworkflow-default-value-from-action-context'
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_name = "subworkflow-default-value-from-action-context"
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING)
self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
# Complete subworkflow under task1.
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"}
t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db)
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"}
t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0]
+ t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t2_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db)
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"}
t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0]
+ t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t3_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db)
t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id))
@@ -184,44 +216,60 @@ def test_action_context_sys_user(self):
# Check result.
expected_result = {
- 'output': {
- 'msg': 'stanley, All your base are belong to us!'
- }
+ "output": {"msg": "stanley, All your base are belong to us!"}
}
self.assertDictEqual(lv_ac_db.result, expected_result)
def test_action_context_api_user(self):
- wf_name = 'subworkflow-default-value-from-action-context'
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], context={'api_user': 'Thanos'})
+ wf_name = "subworkflow-default-value-from-action-context"
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], context={"api_user": "Thanos"}
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING)
self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
# Complete subworkflow under task1.
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"}
t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db)
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"}
t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0]
+ t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t2_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db)
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"}
t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0]
+ t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t3_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db)
t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id))
@@ -239,45 +287,57 @@ def test_action_context_api_user(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Check result.
- expected_result = {
- 'output': {
- 'msg': 'Thanos, All your base are belong to us!'
- }
- }
+ expected_result = {"output": {"msg": "Thanos, All your base are belong to us!"}}
self.assertDictEqual(lv_ac_db.result, expected_result)
def test_action_context_no_channel(self):
- wf_name = 'subworkflow-source-channel-from-action-context'
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_name = "subworkflow-source-channel-from-action-context"
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING)
self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
# Complete subworkflow under task1.
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"}
t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db)
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"}
t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0]
+ t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t2_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db)
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"}
t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0]
+ t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t3_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db)
t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id))
@@ -296,45 +356,60 @@ def test_action_context_no_channel(self):
# Check result.
expected_result = {
- 'output': {
- 'msg': 'no_channel, All your base are belong to us!'
- }
+ "output": {"msg": "no_channel, All your base are belong to us!"}
}
self.assertDictEqual(lv_ac_db.result, expected_result)
def test_action_context_source_channel(self):
- wf_name = 'subworkflow-source-channel-from-action-context'
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'],
- context={'source_channel': 'general'})
+ wf_name = "subworkflow-source-channel-from-action-context"
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], context={"source_channel": "general"}
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING)
self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
# Complete subworkflow under task1.
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"}
t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db)
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"}
t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0]
+ t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t2_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db)
- query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"}
t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0]
+ t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t3_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db)
t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id))
@@ -353,9 +428,7 @@ def test_action_context_source_channel(self):
# Check result.
expected_result = {
- 'output': {
- 'msg': 'general, All your base are belong to us!'
- }
+ "output": {"msg": "general, All your base are belong to us!"}
}
self.assertDictEqual(lv_ac_db.result, expected_result)
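
The reformatted decorator stacks at the top of each test class all use the same `mock.patch.object` idiom: an explicit `MagicMock` with a `side_effect` replaces the publisher method for the duration of the test. A self-contained sketch of that idiom, with a hypothetical `Publisher` class standing in for the st2 publishers patched above:

```python
from unittest import mock


class Publisher:
    def publish_state(self, payload):
        raise RuntimeError("would talk to the real message bus")


def fake_publish_state(payload):
    # Stand-in for MockLiveActionPublisher.publish_state in the tests above.
    return {"published": payload}


# Because the replacement mock is passed explicitly as the third argument,
# mock.patch.object does not inject it into the decorated function's
# signature, which is why the test methods above keep plain (self).
@mock.patch.object(
    Publisher,
    "publish_state",
    mock.MagicMock(side_effect=fake_publish_state),
)
def run_test():
    assert Publisher().publish_state("ok") == {"published": "ok"}


run_test()
```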
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py b/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py
index 00d26f0155..d1c0c249ab 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py
@@ -26,6 +26,7 @@
 # XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -47,37 +48,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaRunnerTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaRunnerTest, cls).setUpClass()
@@ -87,8 +96,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -99,22 +107,30 @@ def get_runner_class(cls, runner_name):
return runners.get_runner(runner_name, runner_name).__class__
def assert_data_flow(self, data):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'data-flow.yaml')
- wf_input = {'a1': data}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "data-flow.yaml")
+ wf_input = {"a1": data}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert task1 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion.
@@ -127,10 +143,12 @@ def assert_data_flow(self, data):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert task2 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion.
@@ -143,10 +161,12 @@ def assert_data_flow(self, data):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert task3 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
- tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id'])
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
+ tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"])
self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion.
@@ -164,20 +184,20 @@ def assert_data_flow(self, data):
# Check workflow output.
expected_output = {
- 'a5': wf_input['a1'] if six.PY3 else wf_input['a1'].decode('utf-8'),
- 'b5': wf_input['a1'] if six.PY3 else wf_input['a1'].decode('utf-8')
+ "a5": wf_input["a1"] if six.PY3 else wf_input["a1"].decode("utf-8"),
+ "b5": wf_input["a1"] if six.PY3 else wf_input["a1"].decode("utf-8"),
}
self.assertDictEqual(wf_ex_db.output, expected_output)
# Check liveaction and action execution result.
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
self.assertDictEqual(lv_ac_db.result, expected_result)
self.assertDictEqual(ac_ex_db.result, expected_result)
def test_string(self):
- self.assert_data_flow('xyz')
+ self.assert_data_flow("xyz")
def test_unicode_string(self):
- self.assert_data_flow('床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉')
+ self.assert_data_flow("床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉")
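
The `six.PY3` conditional in `assert_data_flow` is what lets the unicode test above share one expected-output construction across Python 2 and 3. A standalone sketch of that branch, runnable under Python 3 where the `decode` arm is never taken:

```python
import six

wf_input = {"a1": "床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉"}

# Mirrors the expected-output construction in assert_data_flow: on Python 3
# the value is already text; on Python 2 the UTF-8 bytes need decoding first.
expected_output = {
    "a5": wf_input["a1"] if six.PY3 else wf_input["a1"].decode("utf-8"),
    "b5": wf_input["a1"] if six.PY3 else wf_input["a1"].decode("utf-8"),
}

assert expected_output == {"a5": wf_input["a1"], "b5": wf_input["a1"]}
```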
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_delay.py b/contrib/runners/orquesta_runner/tests/unit/test_delay.py
index 66834f9952..d2535c8f03 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_delay.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_delay.py
@@ -23,6 +23,7 @@
 # XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -43,37 +44,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaRunnerDelayTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaRunnerDelayTest, cls).setUpClass()
@@ -83,8 +92,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -94,17 +102,25 @@ def test_delay(self):
expected_delay_sec = 1
expected_delay_msec = expected_delay_sec * 1000
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'delay.yaml')
- wf_input = {'delay': expected_delay_sec}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "delay.yaml")
+ wf_input = {"delay": expected_delay_sec}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
- lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING)
+ lv_ac_db = self._wait_on_status(
+ lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING
+ )
# Identify records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0]
# Assert delay value is rendered and assigned.
@@ -116,20 +132,28 @@ def test_delay_for_with_items(self):
expected_delay_sec = 1
expected_delay_msec = expected_delay_sec * 1000
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-delay.yaml')
- wf_input = {'delay': expected_delay_sec}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items-delay.yaml")
+ wf_input = {"delay": expected_delay_sec}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
- lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ lv_ac_db = self._wait_on_status(
+ lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING
+ )
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process the with items task.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))
# Assert delay value is rendered and assigned.
@@ -166,20 +190,30 @@ def test_delay_for_with_items_concurrency(self):
expected_delay_sec = 1
expected_delay_msec = expected_delay_sec * 1000
- wf_input = {'concurrency': concurrency, 'delay': expected_delay_sec}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency-delay.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"concurrency": concurrency, "delay": expected_delay_sec}
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "with-items-concurrency-delay.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
- lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ lv_ac_db = self._wait_on_status(
+ lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING
+ )
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process the first set of action executions from with items concurrency.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))
# Assert the number of concurrent items is correct.
@@ -211,7 +245,9 @@ def test_delay_for_with_items_concurrency(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Process the second set of action executions from with items concurrency.
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))
# Assert delay value is rendered and assigned only to the first set of action executions.
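
Beyond line wrapping, the other change repeated throughout these hunks is black's string normalization: single-quoted literals become double-quoted wherever that does not introduce escapes. A short sketch, under the same assumption as before (black 20.8b0+ with default settings):

```python
import black

SRC = "query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}\n"

# The line is under 88 characters, so only the quote style changes.
print(black.format_str(SRC, mode=black.Mode()), end="")
# query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
```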
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py b/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py
index 6f140040ca..d06d335993 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py
@@ -24,6 +24,7 @@
 # XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -48,41 +49,50 @@
from st2common.models.db.execution_queue import ActionExecutionSchedulingQueueItemDB
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaErrorHandlingTest(st2tests.WorkflowTestCase):
ensure_indexes = True
ensure_indexes_models = [
WorkflowExecutionDB,
TaskExecutionDB,
- ActionExecutionSchedulingQueueItemDB
+ ActionExecutionSchedulingQueueItemDB,
]
@classmethod
@@ -94,8 +104,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -104,78 +113,86 @@ def setUpClass(cls):
def test_fail_inspection(self):
expected_errors = [
{
- 'type': 'content',
- 'message': 'The action "std.noop" is not registered in the database.',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action',
- 'spec_path': 'tasks.task3.action'
+ "type": "content",
+ "message": 'The action "std.noop" is not registered in the database.',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action",
+ "spec_path": "tasks.task3.action",
},
{
- 'type': 'context',
- 'language': 'yaql',
- 'expression': '<% ctx().foobar %>',
- 'message': 'Variable "foobar" is referenced before assignment.',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input',
- 'spec_path': 'tasks.task1.input',
+ "type": "context",
+ "language": "yaql",
+ "expression": "<% ctx().foobar %>",
+ "message": 'Variable "foobar" is referenced before assignment.',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input",
+ "spec_path": "tasks.task1.input",
},
{
- 'type': 'expression',
- 'language': 'yaql',
- 'expression': '<% <% succeeded() %>',
- 'message': (
- 'Parse error: unexpected \'<\' at '
- 'position 0 of expression \'<% succeeded()\''
+ "type": "expression",
+ "language": "yaql",
+ "expression": "<% <% succeeded() %>",
+ "message": (
+ "Parse error: unexpected '<' at "
+ "position 0 of expression '<% succeeded()'"
),
- 'schema_path': (
- r'properties.tasks.patternProperties.^\w+$.'
- 'properties.next.items.properties.when'
+ "schema_path": (
+ r"properties.tasks.patternProperties.^\w+$."
+ "properties.next.items.properties.when"
),
- 'spec_path': 'tasks.task2.next[0].when'
+ "spec_path": "tasks.task2.next[0].when",
},
{
- 'type': 'syntax',
- 'message': (
- '[{\'cmd\': \'echo <% ctx().macro %>\'}] is '
- 'not valid under any of the given schemas'
+ "type": "syntax",
+ "message": (
+ "[{'cmd': 'echo <% ctx().macro %>'}] is "
+ "not valid under any of the given schemas"
),
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf',
- 'spec_path': 'tasks.task2.input'
- }
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf",
+ "spec_path": "tasks.task2.input",
+ },
]
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-inspection.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertIn('errors', lv_ac_db.result)
- self.assertListEqual(lv_ac_db.result['errors'], expected_errors)
+ self.assertIn("errors", lv_ac_db.result)
+ self.assertListEqual(lv_ac_db.result["errors"], expected_errors)
def test_fail_input_rendering(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% abs(4).value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% abs(4).value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
- )
+ ),
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-input-rendering.yaml')
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-input-rendering.yaml"
+ )
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution for task is not started and workflow failed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 0)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -188,28 +205,36 @@ def test_fail_input_rendering(self):
def test_fail_vars_rendering(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% abs(4).value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% abs(4).value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
- )
+ ),
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-vars-rendering.yaml')
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-vars-rendering.yaml"
+ )
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution for task is not started and workflow failed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 0)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -222,30 +247,38 @@ def test_fail_vars_rendering(self):
def test_fail_start_task_action(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% ctx().func.value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% ctx().func.value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
),
- 'task_id': 'task1',
- 'route': 0
+ "task_id": "task1",
+ "route": 0,
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-start-task-action.yaml')
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-start-task-action.yaml"
+ )
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution for task is not started and workflow failed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 0)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -258,31 +291,37 @@ def test_fail_start_task_action(self):
def test_fail_start_task_input_expr_eval(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% ctx().msg1.value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% ctx().msg1.value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
),
- 'task_id': 'task1',
- 'route': 0
+ "task_id": "task1",
+ "route": 0,
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_file = 'fail-start-task-input-expr-eval.yaml'
+ wf_file = "fail-start-task-input-expr-eval.yaml"
wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file)
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution for task is not started and workflow failed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 0)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -294,37 +333,40 @@ def test_fail_start_task_input_expr_eval(self):
def test_fail_start_task_input_value_type(self):
if six.PY3:
- msg = 'Value "{\'x\': \'foobar\'}" must either be a string or None. Got "dict".'
+ msg = "Value \"{'x': 'foobar'}\" must either be a string or None. Got \"dict\"."
else:
- msg = 'Value "{u\'x\': u\'foobar\'}" must either be a string or None. Got "dict".'
+ msg = "Value \"{u'x': u'foobar'}\" must either be a string or None. Got \"dict\"."
- msg = 'ValueError: ' + msg
+ msg = "ValueError: " + msg
expected_errors = [
- {
- 'type': 'error',
- 'message': msg,
- 'task_id': 'task1',
- 'route': 0
- }
+ {"type": "error", "message": msg, "task_id": "task1", "route": 0}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_file = 'fail-start-task-input-value-type.yaml'
+ wf_file = "fail-start-task-input-value-type.yaml"
wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file)
- wf_input = {'var1': {'x': 'foobar'}}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"var1": {"x": "foobar"}}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert workflow and task executions failed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
- tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
+ tk_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
self.assertEqual(tk_ex_db.status, wf_statuses.FAILED)
- self.assertDictEqual(tk_ex_db.result, {'errors': expected_errors})
+ self.assertDictEqual(tk_ex_db.result, {"errors": expected_errors})
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -337,29 +379,35 @@ def test_fail_start_task_input_value_type(self):
def test_fail_next_task_action(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% ctx().func.value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% ctx().func.value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
),
- 'task_id': 'task2',
- 'route': 0
+ "task_id": "task2",
+ "route": 0,
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-action.yaml')
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-task-action.yaml")
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert task1 is already completed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0]
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id'])
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ tk_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_db.id)
+ )[0]
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"])
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion for task1 which has an error in publish.
@@ -370,7 +418,9 @@ def test_fail_next_task_action(self):
self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -383,29 +433,37 @@ def test_fail_next_task_action(self):
def test_fail_next_task_input_expr_eval(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% ctx().msg2.value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% ctx().msg2.value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
),
- 'task_id': 'task2',
- 'route': 0
+ "task_id": "task2",
+ "route": 0,
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-input-expr-eval.yaml')
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-task-input-expr-eval.yaml"
+ )
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert task1 is already completed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0]
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id'])
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ tk_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_db.id)
+ )[0]
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"])
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion for task1 which has an error in publish.
@@ -416,7 +474,9 @@ def test_fail_next_task_input_expr_eval(self):
self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -428,34 +488,37 @@ def test_fail_next_task_input_expr_eval(self):
def test_fail_next_task_input_value_type(self):
if six.PY3:
- msg = 'Value "{\'x\': \'foobar\'}" must either be a string or None. Got "dict".'
+ msg = "Value \"{'x': 'foobar'}\" must either be a string or None. Got \"dict\"."
else:
- msg = 'Value "{u\'x\': u\'foobar\'}" must either be a string or None. Got "dict".'
+ msg = "Value \"{u'x': u'foobar'}\" must either be a string or None. Got \"dict\"."
- msg = 'ValueError: ' + msg
+ msg = "ValueError: " + msg
expected_errors = [
- {
- 'type': 'error',
- 'message': msg,
- 'task_id': 'task2',
- 'route': 0
- }
+ {"type": "error", "message": msg, "task_id": "task2", "route": 0}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_file = 'fail-task-input-value-type.yaml'
+ wf_file = "fail-task-input-value-type.yaml"
wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file)
- wf_input = {'var1': {'x': 'foobar'}}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"var1": {"x": "foobar"}}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert task1 is already completed and workflow execution is still running.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
@@ -465,11 +528,13 @@ def test_fail_next_task_input_value_type(self):
# Assert workflow execution and task2 execution failed.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id))
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
- tk2_ex_db = wf_db_access.TaskExecution.query(task_id='task2')[0]
+ tk2_ex_db = wf_db_access.TaskExecution.query(task_id="task2")[0]
self.assertEqual(tk2_ex_db.status, wf_statuses.FAILED)
- self.assertDictEqual(tk2_ex_db.result, {'errors': expected_errors})
+ self.assertDictEqual(tk2_ex_db.result, {"errors": expected_errors})
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -482,37 +547,47 @@ def test_fail_next_task_input_value_type(self):
def test_fail_task_execution(self):
expected_errors = [
{
- 'type': 'error',
- 'message': 'Execution failed. See result for details.',
- 'task_id': 'task1',
- 'result': {
- 'stdout': '',
- 'stderr': 'boom!',
- 'return_code': 1,
- 'failed': True,
- 'succeeded': False
- }
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ "task_id": "task1",
+ "result": {
+ "stdout": "",
+ "stderr": "boom!",
+ "return_code": 1,
+ "failed": True,
+ "succeeded": False,
+ },
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-execution.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-task-execution.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Process task1.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
# Assert workflow state and result.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id))
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
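
# Aside (not part of the patch): every expected_errors entry in these tests
# shares one shape, with required "type" and "message" keys plus optional
# "task_id", "route", "task_transition_id", and a "result" payload when an
# execution failed. A minimal, self-contained shape check; the helper name is
# hypothetical:
def assert_error_shape(error):
    missing = {"type", "message"} - set(error)
    assert not missing, "error entry missing keys: %s" % missing

assert_error_shape(
    {
        "type": "error",
        "message": "Execution failed. See result for details.",
        "task_id": "task1",
        "result": {"failed": True, "succeeded": False},
    }
)
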
@@ -525,28 +600,36 @@ def test_fail_task_execution(self):
def test_fail_task_transition(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
+ "type": "error",
+ "message": (
"YaqlEvaluationException: Unable to resolve key 'foobar' in expression "
"'<% succeeded() and result().foobar %>' from context."
),
- 'task_transition_id': 'task2__t0',
- 'task_id': 'task1',
- 'route': 0
+ "task_transition_id": "task2__t0",
+ "task_id": "task1",
+ "route": 0,
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-transition.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-task-transition.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert task1 is already completed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0]
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id'])
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ tk_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_db.id)
+ )[0]
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"])
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
         # Manually handle action execution completion for task1; evaluating the task transition criteria then fails.
@@ -557,7 +640,9 @@ def test_fail_task_transition(self):
self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -570,29 +655,37 @@ def test_fail_task_transition(self):
def test_fail_task_publish(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% foobar() %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% foobar() %>'. NoFunctionRegisteredException: "
'Unknown function "foobar"'
),
- 'task_transition_id': 'task2__t0',
- 'task_id': 'task1',
- 'route': 0
+ "task_transition_id": "task2__t0",
+ "task_id": "task1",
+ "route": 0,
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-publish.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-task-publish.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert task1 is already completed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0]
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id'])
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ tk_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_db.id)
+ )[0]
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"])
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually handle action execution completion for task1 which has an error in publish.
@@ -603,7 +696,9 @@ def test_fail_task_publish(self):
self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
@@ -616,26 +711,34 @@ def test_fail_task_publish(self):
def test_fail_output_rendering(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% abs(4).value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% abs(4).value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
- )
+ ),
}
]
- expected_result = {'output': None, 'errors': expected_errors}
+ expected_result = {"output": None, "errors": expected_errors}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-output-rendering.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-output-rendering.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert task1 is already completed.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
- tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0]
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id'])
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
+ tk_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_db.id)
+ )[0]
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"])
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
         # Manually handle action execution completion for task1; rendering the workflow output then fails.
@@ -646,7 +749,9 @@ def test_fail_output_rendering(self):
self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
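
# Aside (not part of the patch): sort_workflow_errors is applied before each
# list comparison so the assertions do not depend on the order in which errors
# were recorded. A plausible implementation, assuming the helper simply sorts
# by task and route (the real one lives in the shared test base class):
def sort_workflow_errors(errors):
    return sorted(errors, key=lambda e: (e.get("task_id", ""), e.get("route", 0)))
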
@@ -657,50 +762,51 @@ def test_fail_output_rendering(self):
self.assertDictEqual(ac_ex_db.result, expected_result)
def test_output_on_error(self):
- expected_output = {
- 'progress': 25
- }
+ expected_output = {"progress": 25}
expected_errors = [
{
- 'type': 'error',
- 'task_id': 'task2',
- 'message': 'Execution failed. See result for details.',
- 'result': {
- 'failed': True,
- 'return_code': 1,
- 'stderr': '',
- 'stdout': '',
- 'succeeded': False
- }
+ "type": "error",
+ "task_id": "task2",
+ "message": "Execution failed. See result for details.",
+ "result": {
+ "failed": True,
+ "return_code": 1,
+ "stderr": "",
+ "stdout": "",
+ "succeeded": False,
+ },
}
]
- expected_result = {
- 'errors': expected_errors,
- 'output': expected_output
- }
+ expected_result = {"errors": expected_errors, "output": expected_output}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'output-on-error.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "output-on-error.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
# Assert task1 is already completed and workflow execution is still running.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert task2 is already completed and workflow execution has failed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
wf_svc.handle_action_execution_completion(tk2_ac_ex_db)
@@ -718,26 +824,32 @@ def test_output_on_error(self):
self.assertDictEqual(ac_ex_db.result, expected_result)
def test_fail_manually(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-manually.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-manually.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
         # Assert task1 and the workflow execution failed due to the failure in the task transition.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.FAILED)
# Assert log task is scheduled even though the workflow execution failed manually.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'log'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "log"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(tk2_ac_ex_db)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
@@ -746,38 +858,44 @@ def test_fail_manually(self):
# Check errors and output.
expected_errors = [
{
- 'task_id': 'fail',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.'
+ "task_id": "fail",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
},
{
- 'task_id': 'task1',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.',
- 'result': {
- 'failed': True,
- 'return_code': 1,
- 'stderr': '',
- 'stdout': '',
- 'succeeded': False
- }
- }
+ "task_id": "task1",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ "result": {
+ "failed": True,
+ "return_code": 1,
+ "stderr": "",
+ "stdout": "",
+ "succeeded": False,
+ },
+ },
]
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
def test_fail_manually_with_recovery_failure(self):
- wf_file = 'fail-manually-with-recovery-failure.yaml'
+ wf_file = "fail-manually-with-recovery-failure.yaml"
wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file)
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
         # Assert task1 and the workflow execution failed due to the failure in the task transition.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
@@ -785,10 +903,12 @@ def test_fail_manually_with_recovery_failure(self):
# Assert recover task is scheduled even though the workflow execution failed manually.
         # The recover task in the workflow is set up to fail.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'recover'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "recover"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
wf_svc.handle_action_execution_completion(tk2_ac_ex_db)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
@@ -797,61 +917,70 @@ def test_fail_manually_with_recovery_failure(self):
# Check errors and output.
expected_errors = [
{
- 'task_id': 'fail',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.'
+ "task_id": "fail",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
},
{
- 'task_id': 'recover',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.',
- 'result': {
- 'failed': True,
- 'return_code': 1,
- 'stderr': '',
- 'stdout': '',
- 'succeeded': False
- }
+ "task_id": "recover",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ "result": {
+ "failed": True,
+ "return_code": 1,
+ "stderr": "",
+ "stdout": "",
+ "succeeded": False,
+ },
},
{
- 'task_id': 'task1',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.',
- 'result': {
- 'failed': True,
- 'return_code': 1,
- 'stderr': '',
- 'stdout': '',
- 'succeeded': False
- }
- }
+ "task_id": "task1",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ "result": {
+ "failed": True,
+ "return_code": 1,
+ "stderr": "",
+ "stdout": "",
+ "succeeded": False,
+ },
+ },
]
- self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors)
+ self.assertListEqual(
+ self.sort_workflow_errors(wf_ex_db.errors), expected_errors
+ )
@mock.patch.object(
- runners_utils,
- 'invoke_post_run',
- mock.MagicMock(return_value=None))
+ runners_utils, "invoke_post_run", mock.MagicMock(return_value=None)
+ )
def test_include_result_to_error_log(self):
- username = 'stanley'
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ username = "stanley"
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "Thanos"}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
wf_ex_db = wf_ex_dbs[0]
# Assert task1 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
- self.assertEqual(tk1_lv_ac_db.context.get('user'), username)
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
+ self.assertEqual(tk1_lv_ac_db.context.get("user"), username)
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Manually override and fail the action execution and write some result.
@@ -862,11 +991,13 @@ def test_include_result_to_error_log(self):
tk1_lv_ac_db,
ac_const.LIVEACTION_STATUS_FAILED,
result=result,
- publish=False
+ publish=False,
)
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED)
self.assertDictEqual(tk1_lv_ac_db.result, result)
@@ -882,14 +1013,10 @@ def test_include_result_to_error_log(self):
# Assert result is included in the error log.
expected_errors = [
{
- 'message': 'Execution failed. See result for details.',
- 'type': 'error',
- 'task_id': 'task1',
- 'result': {
- '127.0.0.1': {
- 'hostname': 'foobar'
- }
- }
+ "message": "Execution failed. See result for details.",
+ "type": "error",
+ "task_id": "task1",
+ "result": {"127.0.0.1": {"hostname": "foobar"}},
}
]
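
# Aside (not part of the patch): the hunks above repeat one lookup chain many
# times: workflow execution -> task execution -> action execution -> live
# action. A sketch of a helper the tests could share; the function itself is
# hypothetical, while the imports mirror the aliases these test modules use:
from st2common.persistence import execution as ex_db_access
from st2common.persistence import liveaction as lv_db_access
from st2common.persistence import workflow as wf_db_access

def get_task_live_action(wf_ex_db, task_id):
    # Resolve the LiveActionDB behind one task of a workflow execution.
    tk_ex_db = wf_db_access.TaskExecution.query(
        workflow_execution=str(wf_ex_db.id), task_id=task_id
    )[0]
    tk_ac_ex_db = ex_db_access.ActionExecution.query(
        task_execution=str(tk_ex_db.id)
    )[0]
    return lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"])
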
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py
index d8c416f13a..faa92bd03a 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py
@@ -23,6 +23,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -44,37 +45,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaFunctionTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaFunctionTest, cls).setUpClass()
@@ -84,30 +93,35 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
def _execute_workflow(self, wf_name, expected_output):
- wf_file = wf_name + '.yaml'
+ wf_file = wf_name + ".yaml"
wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file)
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert task1 is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db))
@@ -123,149 +137,139 @@ def _execute_workflow(self, wf_name, expected_output):
self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Check workflow output, liveaction result, and action execution result.
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
self.assertDictEqual(wf_ex_db.output, expected_output)
self.assertDictEqual(lv_ac_db.result, expected_result)
self.assertDictEqual(ac_ex_db.result, expected_result)
def test_data_functions_in_yaql(self):
- wf_name = 'yaql-data-functions'
+ wf_name = "yaql-data-functions"
expected_output = {
- 'data_json_str_1': '{"foo": {"bar": "foobar"}}',
- 'data_json_str_2': '{"foo": {"bar": "foobar"}}',
- 'data_json_str_3': '{"foo": {"bar": "foobar"}}',
- 'data_json_obj_1': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_2': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_3': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_4': {'foo': {'bar': 'foobar'}},
- 'data_yaml_str_1': 'foo:\n bar: foobar\n',
- 'data_yaml_str_2': 'foo:\n bar: foobar\n',
- 'data_query_1': ['foobar'],
- 'data_none_str': data_funcs.NONE_MAGIC_VALUE,
- 'data_str': 'foobar'
+ "data_json_str_1": '{"foo": {"bar": "foobar"}}',
+ "data_json_str_2": '{"foo": {"bar": "foobar"}}',
+ "data_json_str_3": '{"foo": {"bar": "foobar"}}',
+ "data_json_obj_1": {"foo": {"bar": "foobar"}},
+ "data_json_obj_2": {"foo": {"bar": "foobar"}},
+ "data_json_obj_3": {"foo": {"bar": "foobar"}},
+ "data_json_obj_4": {"foo": {"bar": "foobar"}},
+ "data_yaml_str_1": "foo:\n bar: foobar\n",
+ "data_yaml_str_2": "foo:\n bar: foobar\n",
+ "data_query_1": ["foobar"],
+ "data_none_str": data_funcs.NONE_MAGIC_VALUE,
+ "data_str": "foobar",
}
self._execute_workflow(wf_name, expected_output)
def test_data_functions_in_jinja(self):
- wf_name = 'jinja-data-functions'
+ wf_name = "jinja-data-functions"
expected_output = {
- 'data_json_str_1': '{"foo": {"bar": "foobar"}}',
- 'data_json_str_2': '{"foo": {"bar": "foobar"}}',
- 'data_json_str_3': '{"foo": {"bar": "foobar"}}',
- 'data_json_obj_1': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_2': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_3': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_4': {'foo': {'bar': 'foobar'}},
- 'data_yaml_str_1': 'foo:\n bar: foobar\n',
- 'data_yaml_str_2': 'foo:\n bar: foobar\n',
- 'data_query_1': ['foobar'],
- 'data_pipe_str_1': '{"foo": {"bar": "foobar"}}',
- 'data_none_str': data_funcs.NONE_MAGIC_VALUE,
- 'data_str': 'foobar',
- 'data_list_str': '- a: 1\n b: 2\n- x: 3\n y: 4\n'
+ "data_json_str_1": '{"foo": {"bar": "foobar"}}',
+ "data_json_str_2": '{"foo": {"bar": "foobar"}}',
+ "data_json_str_3": '{"foo": {"bar": "foobar"}}',
+ "data_json_obj_1": {"foo": {"bar": "foobar"}},
+ "data_json_obj_2": {"foo": {"bar": "foobar"}},
+ "data_json_obj_3": {"foo": {"bar": "foobar"}},
+ "data_json_obj_4": {"foo": {"bar": "foobar"}},
+ "data_yaml_str_1": "foo:\n bar: foobar\n",
+ "data_yaml_str_2": "foo:\n bar: foobar\n",
+ "data_query_1": ["foobar"],
+ "data_pipe_str_1": '{"foo": {"bar": "foobar"}}',
+ "data_none_str": data_funcs.NONE_MAGIC_VALUE,
+ "data_str": "foobar",
+ "data_list_str": "- a: 1\n b: 2\n- x: 3\n y: 4\n",
}
self._execute_workflow(wf_name, expected_output)
def test_path_functions_in_yaql(self):
- wf_name = 'yaql-path-functions'
+ wf_name = "yaql-path-functions"
- expected_output = {
- 'basename': 'file.txt',
- 'dirname': '/path/to/some'
- }
+ expected_output = {"basename": "file.txt", "dirname": "/path/to/some"}
self._execute_workflow(wf_name, expected_output)
def test_path_functions_in_jinja(self):
- wf_name = 'jinja-path-functions'
+ wf_name = "jinja-path-functions"
- expected_output = {
- 'basename': 'file.txt',
- 'dirname': '/path/to/some'
- }
+ expected_output = {"basename": "file.txt", "dirname": "/path/to/some"}
self._execute_workflow(wf_name, expected_output)
def test_regex_functions_in_yaql(self):
- wf_name = 'yaql-regex-functions'
+ wf_name = "yaql-regex-functions"
expected_output = {
- 'match': True,
- 'replace': 'wxyz',
- 'search': True,
- 'substring': '668 Infinite Dr'
+ "match": True,
+ "replace": "wxyz",
+ "search": True,
+ "substring": "668 Infinite Dr",
}
self._execute_workflow(wf_name, expected_output)
def test_regex_functions_in_jinja(self):
- wf_name = 'jinja-regex-functions'
+ wf_name = "jinja-regex-functions"
expected_output = {
- 'match': True,
- 'replace': 'wxyz',
- 'search': True,
- 'substring': '668 Infinite Dr'
+ "match": True,
+ "replace": "wxyz",
+ "search": True,
+ "substring": "668 Infinite Dr",
}
self._execute_workflow(wf_name, expected_output)
def test_time_functions_in_yaql(self):
- wf_name = 'yaql-time-functions'
+ wf_name = "yaql-time-functions"
- expected_output = {
- 'time': '3h25m45s'
- }
+ expected_output = {"time": "3h25m45s"}
self._execute_workflow(wf_name, expected_output)
def test_time_functions_in_jinja(self):
- wf_name = 'jinja-time-functions'
+ wf_name = "jinja-time-functions"
- expected_output = {
- 'time': '3h25m45s'
- }
+ expected_output = {"time": "3h25m45s"}
self._execute_workflow(wf_name, expected_output)
def test_version_functions_in_yaql(self):
- wf_name = 'yaql-version-functions'
+ wf_name = "yaql-version-functions"
expected_output = {
- 'compare_equal': 0,
- 'compare_more_than': -1,
- 'compare_less_than': 1,
- 'equal': True,
- 'more_than': False,
- 'less_than': False,
- 'match': True,
- 'bump_major': '1.0.0',
- 'bump_minor': '0.11.0',
- 'bump_patch': '0.10.1',
- 'strip_patch': '0.10'
+ "compare_equal": 0,
+ "compare_more_than": -1,
+ "compare_less_than": 1,
+ "equal": True,
+ "more_than": False,
+ "less_than": False,
+ "match": True,
+ "bump_major": "1.0.0",
+ "bump_minor": "0.11.0",
+ "bump_patch": "0.10.1",
+ "strip_patch": "0.10",
}
self._execute_workflow(wf_name, expected_output)
def test_version_functions_in_jinja(self):
- wf_name = 'jinja-version-functions'
+ wf_name = "jinja-version-functions"
expected_output = {
- 'compare_equal': 0,
- 'compare_more_than': -1,
- 'compare_less_than': 1,
- 'equal': True,
- 'more_than': False,
- 'less_than': False,
- 'match': True,
- 'bump_major': '1.0.0',
- 'bump_minor': '0.11.0',
- 'bump_patch': '0.10.1',
- 'strip_patch': '0.10'
+ "compare_equal": 0,
+ "compare_more_than": -1,
+ "compare_less_than": 1,
+ "equal": True,
+ "more_than": False,
+ "less_than": False,
+ "match": True,
+ "bump_major": "1.0.0",
+ "bump_minor": "0.11.0",
+ "bump_patch": "0.10.1",
+ "strip_patch": "0.10",
}
self._execute_workflow(wf_name, expected_output)
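
# Aside (not part of the patch): most hunks in this file follow one black
# rule. A literal stays on a single line when the whole statement fits within
# black's default 88 columns; otherwise it is exploded to one element per line
# with a trailing comma. A standalone illustration using values from the
# tests above:
expected_output = {"basename": "file.txt", "dirname": "/path/to/some"}  # fits

expected_output = {  # too long for one line, so black explodes it
    "bump_major": "1.0.0",
    "bump_minor": "0.11.0",
    "bump_patch": "0.10.1",
    "compare_equal": 0,
    "compare_less_than": 1,
    "compare_more_than": -1,
    "strip_patch": "0.10",
}
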
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py
index 846afa19f0..3004857bee 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py
@@ -23,6 +23,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from orquesta_functions import st2kv
@@ -37,14 +38,13 @@
from st2common.util import keyvalue as kvp_util
-MOCK_CTX = {'__vars': {'st2': {'user': 'stanley'}}}
-MOCK_CTX_NO_USER = {'__vars': {'st2': {}}}
+MOCK_CTX = {"__vars": {"st2": {"user": "stanley"}}}
+MOCK_CTX_NO_USER = {"__vars": {"st2": {}}}
class DatastoreFunctionTest(unittest2.TestCase):
-
def test_missing_user_context(self):
- self.assertRaises(KeyError, st2kv.st2kv_, MOCK_CTX_NO_USER, 'foo')
+ self.assertRaises(KeyError, st2kv.st2kv_, MOCK_CTX_NO_USER, "foo")
def test_invalid_input(self):
self.assertRaises(TypeError, st2kv.st2kv_, None, 123)
@@ -55,35 +55,29 @@ def test_invalid_input(self):
class UserScopeDatastoreFunctionTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(UserScopeDatastoreFunctionTest, cls).setUpClass()
- user = auth_db.UserDB(name='stanley')
+ user = auth_db.UserDB(name="stanley")
user.save()
scope = kvp_const.FULL_USER_SCOPE
cls.kvps = {}
# Plain keys
- keys = {
- 'stanley:foo': 'bar',
- 'stanley:foo_empty': '',
- 'stanley:foo_null': None
- }
+ keys = {"stanley:foo": "bar", "stanley:foo_empty": "", "stanley:foo_null": None}
for k, v in six.iteritems(keys):
instance = kvp_db.KeyValuePairDB(name=k, value=v, scope=scope)
cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance)
# Secret key
- keys = {
- 'stanley:fu': 'bar',
- 'stanley:fu_empty': ''
- }
+ keys = {"stanley:fu": "bar", "stanley:fu_empty": ""}
for k, v in six.iteritems(keys):
value = crypto.symmetric_encrypt(kvp_api.KeyValuePairAPI.crypto_key, v)
- instance = kvp_db.KeyValuePairDB(name=k, value=value, scope=scope, secret=True)
+ instance = kvp_db.KeyValuePairDB(
+ name=k, value=value, scope=scope, secret=True
+ )
cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance)
@classmethod
@@ -94,9 +88,9 @@ def tearDownClass(cls):
super(UserScopeDatastoreFunctionTest, cls).tearDownClass()
def test_key_exists(self):
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foo'), 'bar')
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foo_empty'), '')
- self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'foo_null'))
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foo"), "bar")
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foo_empty"), "")
+ self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "foo_null"))
def test_key_does_not_exist(self):
self.assertRaisesRegexp(
@@ -104,65 +98,61 @@ def test_key_does_not_exist(self):
'The key ".*" does not exist in the StackStorm datastore.',
st2kv.st2kv_,
MOCK_CTX,
- 'foobar'
+ "foobar",
)
def test_key_does_not_exist_but_return_default(self):
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foobar', default='foosball'), 'foosball')
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foobar', default=''), '')
- self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'foobar', default=None))
+ self.assertEqual(
+ st2kv.st2kv_(MOCK_CTX, "foobar", default="foosball"), "foosball"
+ )
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foobar", default=""), "")
+ self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "foobar", default=None))
def test_key_decrypt(self):
- self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu'), 'bar')
- self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu', decrypt=False), 'bar')
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'fu', decrypt=True), 'bar')
- self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty'), '')
- self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty', decrypt=False), '')
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty', decrypt=True), '')
+ self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu"), "bar")
+ self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu", decrypt=False), "bar")
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "fu", decrypt=True), "bar")
+ self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty"), "")
+ self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty", decrypt=False), "")
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty", decrypt=True), "")
@mock.patch.object(
- kvp_util, 'get_key',
- mock.MagicMock(side_effect=Exception('Mock failure.')))
+ kvp_util, "get_key", mock.MagicMock(side_effect=Exception("Mock failure."))
+ )
def test_get_key_exception(self):
self.assertRaisesRegexp(
exc.ExpressionEvaluationException,
- 'Mock failure.',
+ "Mock failure.",
st2kv.st2kv_,
MOCK_CTX,
- 'foo'
+ "foo",
)
class SystemScopeDatastoreFunctionTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(SystemScopeDatastoreFunctionTest, cls).setUpClass()
- user = auth_db.UserDB(name='stanley')
+ user = auth_db.UserDB(name="stanley")
user.save()
scope = kvp_const.FULL_SYSTEM_SCOPE
cls.kvps = {}
# Plain key
- keys = {
- 'foo': 'bar',
- 'foo_empty': '',
- 'foo_null': None
- }
+ keys = {"foo": "bar", "foo_empty": "", "foo_null": None}
for k, v in six.iteritems(keys):
instance = kvp_db.KeyValuePairDB(name=k, value=v, scope=scope)
cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance)
# Secret key
- keys = {
- 'fu': 'bar',
- 'fu_empty': ''
- }
+ keys = {"fu": "bar", "fu_empty": ""}
for k, v in six.iteritems(keys):
value = crypto.symmetric_encrypt(kvp_api.KeyValuePairAPI.crypto_key, v)
- instance = kvp_db.KeyValuePairDB(name=k, value=value, scope=scope, secret=True)
+ instance = kvp_db.KeyValuePairDB(
+ name=k, value=value, scope=scope, secret=True
+ )
cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance)
@classmethod
@@ -173,9 +163,9 @@ def tearDownClass(cls):
super(SystemScopeDatastoreFunctionTest, cls).tearDownClass()
def test_key_exists(self):
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foo'), 'bar')
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foo_empty'), '')
- self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'system.foo_null'))
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foo"), "bar")
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foo_empty"), "")
+ self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "system.foo_null"))
def test_key_does_not_exist(self):
self.assertRaisesRegexp(
@@ -183,30 +173,34 @@ def test_key_does_not_exist(self):
'The key ".*" does not exist in the StackStorm datastore.',
st2kv.st2kv_,
MOCK_CTX,
- 'foo'
+ "foo",
)
def test_key_does_not_exist_but_return_default(self):
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default='foosball'), 'foosball')
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default=''), '')
- self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default=None))
+ self.assertEqual(
+ st2kv.st2kv_(MOCK_CTX, "system.foobar", default="foosball"), "foosball"
+ )
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foobar", default=""), "")
+ self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "system.foobar", default=None))
def test_key_decrypt(self):
- self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu'), 'bar')
- self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu', decrypt=False), 'bar')
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu', decrypt=True), 'bar')
- self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty'), '')
- self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty', decrypt=False), '')
- self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty', decrypt=True), '')
+ self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu"), "bar")
+ self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu", decrypt=False), "bar")
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.fu", decrypt=True), "bar")
+ self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu_empty"), "")
+ self.assertNotEqual(
+ st2kv.st2kv_(MOCK_CTX, "system.fu_empty", decrypt=False), ""
+ )
+ self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.fu_empty", decrypt=True), "")
@mock.patch.object(
- kvp_util, 'get_key',
- mock.MagicMock(side_effect=Exception('Mock failure.')))
+ kvp_util, "get_key", mock.MagicMock(side_effect=Exception("Mock failure."))
+ )
def test_get_key_exception(self):
self.assertRaisesRegexp(
exc.ExpressionEvaluationException,
- 'Mock failure.',
+ "Mock failure.",
st2kv.st2kv_,
MOCK_CTX,
- 'system.foo'
+ "system.foo",
)
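
# Aside (not part of the patch): these st2kv tests call assertRaisesRegexp,
# the deprecated Python 2 era alias of assertRaisesRegex. black only changes
# formatting, so the alias survives this patch. The modern spelling, shown
# with a self-contained example rather than the st2kv fixtures:
import unittest

class AliasExample(unittest.TestCase):
    def test_modern_spelling(self):
        # assertRaisesRegex(expected_exception, regexp, callable, *args)
        self.assertRaisesRegex(ValueError, "invalid literal", int, "not-a-number")

if __name__ == "__main__":
    unittest.main()
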
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py
index 146e7ee39e..46ffb861e3 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py
@@ -23,6 +23,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -43,37 +44,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaFunctionTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaFunctionTest, cls).setUpClass()
@@ -83,42 +92,57 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
- def _execute_workflow(self, wf_name, expected_task_sequence, expected_output,
- expected_status=wf_statuses.SUCCEEDED, expected_errors=None):
- wf_file = wf_name + '.yaml'
+ def _execute_workflow(
+ self,
+ wf_name,
+ expected_task_sequence,
+ expected_output,
+ expected_status=wf_statuses.SUCCEEDED,
+ expected_errors=None,
+ ):
+ wf_file = wf_name + ".yaml"
wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file)
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
for task_id, route in expected_task_sequence:
tk_ex_dbs = wf_db_access.TaskExecution.query(
- workflow_execution=str(wf_ex_db.id),
- task_id=task_id,
- task_route=route
+ workflow_execution=str(wf_ex_db.id), task_id=task_id, task_route=route
)
if len(tk_ex_dbs) <= 0:
break
- tk_ex_db = sorted(tk_ex_dbs, key=lambda x: x.start_timestamp)[len(tk_ex_dbs) - 1]
- tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0]
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id'])
+ tk_ex_db = sorted(tk_ex_dbs, key=lambda x: x.start_timestamp)[
+ len(tk_ex_dbs) - 1
+ ]
+ tk_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_db.id)
+ )[0]
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ tk_ac_ex_db.liveaction["id"]
+ )
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk_ac_ex_db))
+ self.assertTrue(
+ wf_svc.is_action_execution_under_workflow_context(tk_ac_ex_db)
+ )
wf_svc.handle_action_execution_completion(tk_ac_ex_db)
@@ -131,10 +155,10 @@ def _execute_workflow(self, wf_name, expected_task_sequence, expected_output,
self.assertEqual(ac_ex_db.status, expected_status)
# Check workflow output, liveaction result, and action execution result.
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
if expected_errors is not None:
- expected_result['errors'] = expected_errors
+ expected_result["errors"] = expected_errors
if expected_output is not None:
self.assertDictEqual(wf_ex_db.output, expected_output)
@@ -143,83 +167,81 @@ def _execute_workflow(self, wf_name, expected_task_sequence, expected_output,
self.assertDictEqual(ac_ex_db.result, expected_result)
def test_task_functions_in_yaql(self):
- wf_name = 'yaql-task-functions'
+ wf_name = "yaql-task-functions"
expected_task_sequence = [
- ('task1', 0),
- ('task3', 0),
- ('task6', 0),
- ('task7', 0),
- ('task2', 0),
- ('task4', 0),
- ('task8', 1),
- ('task8', 2),
- ('task4', 0),
- ('task9', 1),
- ('task9', 2),
- ('task5', 0)
+ ("task1", 0),
+ ("task3", 0),
+ ("task6", 0),
+ ("task7", 0),
+ ("task2", 0),
+ ("task4", 0),
+ ("task8", 1),
+ ("task8", 2),
+ ("task4", 0),
+ ("task9", 1),
+ ("task9", 2),
+ ("task5", 0),
]
expected_output = {
- 'last_task4_result': 'False',
- 'task9__1__parent': 'task8__1',
- 'task9__2__parent': 'task8__2',
- 'that_task_by_name': 'task1',
- 'this_task_by_name': 'task1',
- 'this_task_no_arg': 'task1'
+ "last_task4_result": "False",
+ "task9__1__parent": "task8__1",
+ "task9__2__parent": "task8__2",
+ "that_task_by_name": "task1",
+ "this_task_by_name": "task1",
+ "this_task_no_arg": "task1",
}
self._execute_workflow(wf_name, expected_task_sequence, expected_output)
def test_task_functions_in_jinja(self):
- wf_name = 'jinja-task-functions'
+ wf_name = "jinja-task-functions"
expected_task_sequence = [
- ('task1', 0),
- ('task3', 0),
- ('task6', 0),
- ('task7', 0),
- ('task2', 0),
- ('task4', 0),
- ('task8', 1),
- ('task8', 2),
- ('task4', 0),
- ('task9', 1),
- ('task9', 2),
- ('task5', 0)
+ ("task1", 0),
+ ("task3", 0),
+ ("task6", 0),
+ ("task7", 0),
+ ("task2", 0),
+ ("task4", 0),
+ ("task8", 1),
+ ("task8", 2),
+ ("task4", 0),
+ ("task9", 1),
+ ("task9", 2),
+ ("task5", 0),
]
expected_output = {
- 'last_task4_result': 'False',
- 'task9__1__parent': 'task8__1',
- 'task9__2__parent': 'task8__2',
- 'that_task_by_name': 'task1',
- 'this_task_by_name': 'task1',
- 'this_task_no_arg': 'task1'
+ "last_task4_result": "False",
+ "task9__1__parent": "task8__1",
+ "task9__2__parent": "task8__2",
+ "that_task_by_name": "task1",
+ "this_task_by_name": "task1",
+ "this_task_no_arg": "task1",
}
self._execute_workflow(wf_name, expected_task_sequence, expected_output)
def test_task_nonexistent_in_yaql(self):
- wf_name = 'yaql-task-nonexistent'
+ wf_name = "yaql-task-nonexistent"
- expected_task_sequence = [
- ('task1', 0)
- ]
+ expected_task_sequence = [("task1", 0)]
expected_output = None
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% task("task0") %>\'. ExpressionEvaluationException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% task(\"task0\") %>'. ExpressionEvaluationException: "
'Unable to find task execution for "task0".'
),
- 'task_transition_id': 'continue__t0',
- 'task_id': 'task1',
- 'route': 0
+ "task_transition_id": "continue__t0",
+ "task_id": "task1",
+ "route": 0,
}
]
@@ -228,29 +250,27 @@ def test_task_nonexistent_in_yaql(self):
expected_task_sequence,
expected_output,
expected_status=ac_const.LIVEACTION_STATUS_FAILED,
- expected_errors=expected_errors
+ expected_errors=expected_errors,
)
def test_task_nonexistent_in_jinja(self):
- wf_name = 'jinja-task-nonexistent'
+ wf_name = "jinja-task-nonexistent"
- expected_task_sequence = [
- ('task1', 0)
- ]
+ expected_task_sequence = [("task1", 0)]
expected_output = None
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'JinjaEvaluationException: Unable to evaluate expression '
- '\'{{ task("task0") }}\'. ExpressionEvaluationException: '
+ "type": "error",
+ "message": (
+ "JinjaEvaluationException: Unable to evaluate expression "
+ "'{{ task(\"task0\") }}'. ExpressionEvaluationException: "
'Unable to find task execution for "task0".'
),
- 'task_transition_id': 'continue__t0',
- 'task_id': 'task1',
- 'route': 0
+ "task_transition_id": "continue__t0",
+ "task_id": "task1",
+ "route": 0,
}
]
@@ -259,5 +279,5 @@ def test_task_nonexistent_in_jinja(self):
expected_task_sequence,
expected_output,
expected_status=ac_const.LIVEACTION_STATUS_FAILED,
- expected_errors=expected_errors
+ expected_errors=expected_errors,
)
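
# Aside (not part of the patch): the decorator stacks above show the same
# call-site rule black applies everywhere: one line when the call fits,
# otherwise one argument per line with a trailing comma. A runnable,
# self-contained sketch of patching with an explicit replacement object, as
# these tests do (the class and function names are invented):
from unittest import mock

class Publisher:
    @staticmethod
    def publish_state(payload):
        raise RuntimeError("should be mocked in tests")

@mock.patch.object(Publisher, "publish_state", mock.MagicMock(return_value=None))
def run():
    # Inside the patched scope the MagicMock swallows the call.
    Publisher.publish_state({"status": "running"})

run()
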
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py b/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py
index 3e84d7bce8..8dfdf24a84 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py
@@ -23,6 +23,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -45,37 +46,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaRunnerTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaRunnerTest, cls).setUpClass()
@@ -85,30 +94,35 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
def test_inquiry(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-approval.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "ask-approval.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Assert start task is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t1_ac_ex_db)
t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id)
self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED)
@@ -118,10 +132,15 @@ def test_inquiry(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert get approval task is already pending.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'}
+ query_filters = {
+ "workflow_execution": str(wf_ex_db.id),
+ "task_id": "get_approval",
+ }
t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_ex_db.id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING)
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id)
@@ -133,12 +152,16 @@ def test_inquiry(self):
# Respond to the inquiry and check status.
inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db)
- inquiry_response = {'approved': True}
+ inquiry_response = {"approved": True}
inquiry_service.respond(inquiry_api, inquiry_response)
t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id))
- self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id))
- self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id))
self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
@@ -148,11 +171,15 @@ def test_inquiry(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert the final task is completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"}
t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0]
- t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id'])
- self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t3_ex_db.id)
+ )[0]
+ t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t3_ac_ex_db)
t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id)
self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED)
@@ -162,22 +189,30 @@ def test_inquiry(self):
self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED)
def test_consecutive_inquiries(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-consecutive-approvals.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "ask-consecutive-approvals.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Assert start task is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t1_ac_ex_db)
t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id)
self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED)
@@ -187,10 +222,15 @@ def test_consecutive_inquiries(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert get approval task is already pending.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'}
+ query_filters = {
+ "workflow_execution": str(wf_ex_db.id),
+ "task_id": "get_approval",
+ }
t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_ex_db.id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING)
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id)
@@ -202,12 +242,16 @@ def test_consecutive_inquiries(self):
# Respond to the inquiry and check status.
inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db)
- inquiry_response = {'approved': True}
+ inquiry_response = {"approved": True}
inquiry_service.respond(inquiry_api, inquiry_response)
t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id))
- self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id))
- self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id))
self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
@@ -217,10 +261,15 @@ def test_consecutive_inquiries(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert the final task is completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_confirmation'}
+ query_filters = {
+ "workflow_execution": str(wf_ex_db.id),
+ "task_id": "get_confirmation",
+ }
t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0]
- t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id'])
+ t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t3_ex_db.id)
+ )[0]
+ t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"])
self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING)
workflows.get_engine().process(t3_ac_ex_db)
t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id)
@@ -232,12 +281,16 @@ def test_consecutive_inquiries(self):
# Respond to the inquiry and check status.
inquiry_api = inqy_api_models.InquiryAPI.from_model(t3_ac_ex_db)
- inquiry_response = {'approved': True}
+ inquiry_response = {"approved": True}
inquiry_service.respond(inquiry_api, inquiry_response)
t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t3_lv_ac_db.id))
- self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
t3_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t3_ac_ex_db.id))
- self.assertEqual(t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t3_ac_ex_db)
t3_ex_db = wf_db_access.TaskExecution.get_by_id(str(t3_ex_db.id))
self.assertEqual(t3_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
@@ -247,11 +300,15 @@ def test_consecutive_inquiries(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert the final task is completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"}
t4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t4_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t4_ex_db.id))[0]
- t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction['id'])
- self.assertEqual(t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t4_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t4_ex_db.id)
+ )[0]
+ t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t4_ac_ex_db)
t4_ex_db = wf_db_access.TaskExecution.get_by_id(t4_ex_db.id)
self.assertEqual(t4_ex_db.status, wf_statuses.SUCCEEDED)
@@ -261,22 +318,30 @@ def test_consecutive_inquiries(self):
self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED)
def test_parallel_inquiries(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-parallel-approvals.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "ask-parallel-approvals.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Assert start task is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t1_ac_ex_db)
t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id)
self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED)
@@ -286,10 +351,12 @@ def test_parallel_inquiries(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert get approval task is already pending.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'ask_jack'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "ask_jack"}
t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_ex_db.id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING)
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id)
@@ -300,10 +367,12 @@ def test_parallel_inquiries(self):
self.assertEqual(wf_ex_db.status, wf_statuses.PAUSING)
# Assert get approval task is already pending.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'ask_jill'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "ask_jill"}
t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0]
- t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id'])
+ t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t3_ex_db.id)
+ )[0]
+ t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"])
self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING)
workflows.get_engine().process(t3_ac_ex_db)
t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id)
@@ -315,12 +384,16 @@ def test_parallel_inquiries(self):
# Respond to the inquiry and check status.
inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db)
- inquiry_response = {'approved': True}
+ inquiry_response = {"approved": True}
inquiry_service.respond(inquiry_api, inquiry_response)
t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id))
- self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id))
- self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id))
self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
@@ -332,12 +405,16 @@ def test_parallel_inquiries(self):
# Respond to the inquiry and check status.
inquiry_api = inqy_api_models.InquiryAPI.from_model(t3_ac_ex_db)
- inquiry_response = {'approved': True}
+ inquiry_response = {"approved": True}
inquiry_service.respond(inquiry_api, inquiry_response)
t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t3_lv_ac_db.id))
- self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
t3_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t3_ac_ex_db.id))
- self.assertEqual(t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t3_ac_ex_db)
t3_ex_db = wf_db_access.TaskExecution.get_by_id(str(t3_ex_db.id))
self.assertEqual(t3_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
@@ -347,11 +424,15 @@ def test_parallel_inquiries(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert the final task is completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"}
t4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t4_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t4_ex_db.id))[0]
- t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction['id'])
- self.assertEqual(t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t4_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t4_ex_db.id)
+ )[0]
+ t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t4_ac_ex_db)
t4_ex_db = wf_db_access.TaskExecution.get_by_id(t4_ex_db.id)
self.assertEqual(t4_ex_db.status, wf_statuses.SUCCEEDED)
@@ -361,22 +442,30 @@ def test_parallel_inquiries(self):
self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED)
def test_nested_inquiry(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-nested-approval.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "ask-nested-approval.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Assert start task is already completed.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t1_ac_ex_db)
t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id)
self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED)
@@ -386,23 +475,36 @@ def test_nested_inquiry(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert the subworkflow is already started.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'}
+ query_filters = {
+ "workflow_execution": str(wf_ex_db.id),
+ "task_id": "get_approval",
+ }
t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_ex_db.id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id)
self.assertEqual(t2_ex_db.status, wf_statuses.RUNNING)
- t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0]
+ t2_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t2_ac_ex_db.id)
+ )[0]
self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING)
# Process task1 of subworkflow.
- query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'start'}
+ query_filters = {"workflow_execution": str(t2_wf_ex_db.id), "task_id": "start"}
t2_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0]
- t2_t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t1_ac_ex_db.liveaction['id'])
- self.assertEqual(t2_t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t1_ex_db.id)
+ )[0]
+ t2_t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ t2_t1_ac_ex_db.liveaction["id"]
+ )
+ self.assertEqual(
+ t2_t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t2_t1_ac_ex_db)
t2_t1_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t1_ex_db.id)
self.assertEqual(t2_t1_ex_db.status, wf_statuses.SUCCEEDED)
@@ -410,11 +512,20 @@ def test_nested_inquiry(self):
self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING)
# Process inquiry task of subworkflow and assert the subworkflow is paused.
- query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'get_approval'}
+ query_filters = {
+ "workflow_execution": str(t2_wf_ex_db.id),
+ "task_id": "get_approval",
+ }
t2_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0]
- t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t2_ac_ex_db.liveaction['id'])
- self.assertEqual(t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING)
+ t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t2_ex_db.id)
+ )[0]
+ t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ t2_t2_ac_ex_db.liveaction["id"]
+ )
+ self.assertEqual(
+ t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING
+ )
workflows.get_engine().process(t2_t2_ac_ex_db)
t2_t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t2_ex_db.id)
self.assertEqual(t2_t2_ex_db.status, wf_statuses.PENDING)
@@ -422,8 +533,10 @@ def test_nested_inquiry(self):
self.assertEqual(t2_wf_ex_db.status, wf_statuses.PAUSED)
# Process the corresponding task in parent workflow and assert the task is paused.
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_ex_db.id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PAUSED)
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id)
@@ -435,34 +548,50 @@ def test_nested_inquiry(self):
# Respond to the inquiry and check status.
inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_t2_ac_ex_db)
- inquiry_response = {'approved': True}
+ inquiry_response = {"approved": True}
inquiry_service.respond(inquiry_api, inquiry_response)
t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_t2_lv_ac_db.id))
- self.assertEqual(t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
t2_t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_t2_ac_ex_db.id))
- self.assertEqual(t2_t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t2_t2_ac_ex_db)
t2_t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_t2_ex_db.id))
- self.assertEqual(t2_t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ t2_t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Assert the main workflow is running again.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Complete the rest of the subworkflow
- query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'finish'}
+ query_filters = {"workflow_execution": str(t2_wf_ex_db.id), "task_id": "finish"}
t2_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0]
- t2_t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t3_ac_ex_db.liveaction['id'])
- self.assertEqual(t2_t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t3_ex_db.id)
+ )[0]
+ t2_t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ t2_t3_ac_ex_db.liveaction["id"]
+ )
+ self.assertEqual(
+ t2_t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t2_t3_ac_ex_db)
t2_t3_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t3_ex_db.id)
self.assertEqual(t2_t3_ex_db.status, wf_statuses.SUCCEEDED)
t2_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t2_wf_ex_db.id))
self.assertEqual(t2_wf_ex_db.status, wf_statuses.SUCCEEDED)
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
- self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_ex_db.id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id)
self.assertEqual(t2_ex_db.status, wf_statuses.SUCCEEDED)
@@ -470,11 +599,15 @@ def test_nested_inquiry(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Complete the rest of the main workflow
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"}
t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0]
- t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id'])
- self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t3_ex_db.id)
+ )[0]
+ t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflows.get_engine().process(t3_ac_ex_db)
t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id)
self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED)
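Note: black reflows the class-level @mock.patch.object decorator stacks seen above: calls that fit within its 88-column limit collapse onto one line with no trailing comma, while longer ones get one argument per line plus a trailing comma. The reformatting does not change decorator semantics; decorators in a stack still apply bottom-up, and a patch applied to a TestCase class covers every test method. A self-contained sketch of the same pattern (Publisher and the payload are hypothetical, for illustration only):

    import unittest
    from unittest import mock

    class Publisher:
        @staticmethod
        def publish_state(payload):
            raise RuntimeError("would publish to the message bus")

    # Fits in 88 columns, so black keeps it on one line, no trailing comma.
    @mock.patch.object(Publisher, "publish_state", mock.MagicMock(return_value=None))
    class PublisherTest(unittest.TestCase):
        def test_publish_state_is_mocked(self):
            # The MagicMock replaces the real method for the duration of the test.
            self.assertIsNone(Publisher.publish_state({"status": "succeeded"}))

    if __name__ == "__main__":
        unittest.main()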
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_notify.py b/contrib/runners/orquesta_runner/tests/unit/test_notify.py
index dc8131f100..6ca125d855 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_notify.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_notify.py
@@ -25,6 +25,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -47,57 +48,60 @@
from st2tests.mocks import liveaction as mock_lv_ac_xport
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
MOCK_NOTIFY = {
- 'on-complete': {
- 'data': {
- 'source_channel': 'baloney',
- 'user': 'lakstorm'
- },
- 'routes': [
- 'hubot'
- ]
+ "on-complete": {
+ "data": {"source_channel": "baloney", "user": "lakstorm"},
+ "routes": ["hubot"],
}
}
@mock.patch.object(
- notifier.Notifier,
- '_post_notify_triggers',
- mock.MagicMock(return_value=None))
+ notifier.Notifier, "_post_notify_triggers", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
- notifier.Notifier,
- '_post_generic_trigger',
- mock.MagicMock(return_value=None))
+ notifier.Notifier, "_post_generic_trigger", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(side_effect=mock_ac_ex_xport.MockExecutionPublisher.publish_update))
+ "publish_update",
+ mock.MagicMock(side_effect=mock_ac_ex_xport.MockExecutionPublisher.publish_update),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaNotifyTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaNotifyTest, cls).setUpClass()
@@ -107,177 +111,181 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
def test_no_notify(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Check that notify is setup correctly in the db record.
self.assertDictEqual(wf_ex_db.notify, {})
def test_no_notify_task_list(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY)
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Check that notify is setup correctly in the db record.
- expected_notify = {
- 'config': MOCK_NOTIFY,
- 'tasks': []
- }
+ expected_notify = {"config": MOCK_NOTIFY, "tasks": []}
self.assertDictEqual(wf_ex_db.notify, expected_notify)
def test_custom_notify_task_list(self):
- wf_input = {'notify': ['task1']}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"notify": ["task1"]}
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY)
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Check that notify is setup correctly in the db record.
- expected_notify = {
- 'config': MOCK_NOTIFY,
- 'tasks': wf_input['notify']
- }
+ expected_notify = {"config": MOCK_NOTIFY, "tasks": wf_input["notify"]}
self.assertDictEqual(wf_ex_db.notify, expected_notify)
def test_default_notify_task_list(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'notify.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "notify.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY)
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Check that notify is setup correctly in the db record.
- expected_notify = {
- 'config': MOCK_NOTIFY,
- 'tasks': ['task1', 'task2', 'task3']
- }
+ expected_notify = {"config": MOCK_NOTIFY, "tasks": ["task1", "task2", "task3"]}
self.assertDictEqual(wf_ex_db.notify, expected_notify)
def test_notify_task_list_bad_item_value(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY)
expected_schema_failure_test_cases = [
- 'task1', # Notify must be type of list.
- [123], # Item has to be type of string.
- [''], # String value cannot be empty.
- [' '], # String value cannot be just spaces.
- ['\t\t'], # String value cannot be just tabs.
- ['init task'], # String value cannot have space.
- ['init-task'], # String value cannot have dash.
- ['task1', 'task1'] # String values have to be unique.
+ "task1", # Notify must be type of list.
+ [123], # Item has to be type of string.
+ [""], # String value cannot be empty.
+ [" "], # String value cannot be just spaces.
+ [" "], # String value cannot be just tabs.
+ ["init task"], # String value cannot have space.
+ ["init-task"], # String value cannot have dash.
+ ["task1", "task1"], # String values have to be unique.
]
for notify_tasks in expected_schema_failure_test_cases:
- lv_ac_db.parameters = {'notify': notify_tasks}
+ lv_ac_db.parameters = {"notify": notify_tasks}
try:
self.assertRaises(
- jsonschema.ValidationError,
- action_service.request,
- lv_ac_db
+ jsonschema.ValidationError, action_service.request, lv_ac_db
)
except Exception as e:
- raise AssertionError('%s: %s' % (six.text_type(e), notify_tasks))
+ raise AssertionError("%s: %s" % (six.text_type(e), notify_tasks))
def test_notify_task_list_nonexistent_task(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY)
- lv_ac_db.parameters = {'notify': ['init_task']}
+ lv_ac_db.parameters = {"notify": ["init_task"]}
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
expected_result = {
- 'output': None,
- 'errors': [
+ "output": None,
+ "errors": [
{
- 'message': (
- 'The following tasks in the notify parameter do not '
- 'exist in the workflow definition: init_task.'
+ "message": (
+ "The following tasks in the notify parameter do not "
+ "exist in the workflow definition: init_task."
)
}
- ]
+ ],
}
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED)
self.assertDictEqual(lv_ac_db.result, expected_result)
def test_notify_task_list_item_value(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY)
- expected_schema_success_test_cases = [
- [],
- ['task1'],
- ['task1', 'task2']
- ]
+ expected_schema_success_test_cases = [[], ["task1"], ["task1", "task2"]]
for notify_tasks in expected_schema_success_test_cases:
- lv_ac_db.parameters = {'notify': notify_tasks}
+ lv_ac_db.parameters = {"notify": notify_tasks}
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+ self.assertEqual(
+ lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING
+ )
def test_cascade_notify_to_tasks(self):
- wf_input = {'notify': ['task2']}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"notify": ["task2"]}
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY)
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Assert task1 notify is not set.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertIsNone(tk1_lv_ac_db.notify)
- self.assertEqual(tk1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ tk1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
self.assertFalse(notifier.Notifier._post_notify_triggers.called)
notifier.Notifier._post_notify_triggers.reset_mock()
@@ -289,13 +297,19 @@ def test_cascade_notify_to_tasks(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert task2 notify is set.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
- notify = notify_api_models.NotificationsHelper.from_model(notify_model=tk2_lv_ac_db.notify)
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
+ notify = notify_api_models.NotificationsHelper.from_model(
+ notify_model=tk2_lv_ac_db.notify
+ )
self.assertEqual(notify, MOCK_NOTIFY)
- self.assertEqual(tk2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ tk2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
self.assertTrue(notifier.Notifier._post_notify_triggers.called)
notifier.Notifier._post_notify_triggers.reset_mock()
@@ -307,12 +321,16 @@ def test_cascade_notify_to_tasks(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Assert task3 notify is not set.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
- tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id'])
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
+ tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"])
self.assertIsNone(tk3_lv_ac_db.notify)
- self.assertEqual(tk3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ self.assertEqual(
+ tk3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
self.assertFalse(notifier.Notifier._post_notify_triggers.called)
notifier.Notifier._post_notify_triggers.reset_mock()
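Note: the expected_notify hunks in test_notify.py above also show the inverse rule: container literals that fit within black's line limit are collapsed onto a single line. Taken directly from the patch, shown side by side for clarity:

    # Before: hand-wrapped dict with single-quoted keys.
    expected_notify = {
        'config': MOCK_NOTIFY,
        'tasks': []
    }

    # After black: fits in 88 columns, so it becomes one line with
    # normalized double quotes.
    expected_notify = {"config": MOCK_NOTIFY, "tasks": []}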
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py b/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py
index 5bae5bab27..f23084b527 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py
@@ -22,6 +22,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -45,12 +46,14 @@
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
FAIL_SCHEMA = {
@@ -61,25 +64,32 @@
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaRunnerTest(RunnerTestCase, st2tests.ExecutionDbTestCase):
@classmethod
def setUpClass(cls):
@@ -90,8 +100,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -102,28 +111,40 @@ def get_runner_class(cls, runner_name):
return runners.get_runner(runner_name, runner_name).__class__
def test_adherence_to_output_schema(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential_with_schema.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "sequential_with_schema.yaml"
+ )
+ wf_input = {"who": "Thanos"}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
wf_ex_db = wf_ex_dbs[0]
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(tk2_ac_ex_db)
tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(tk3_ac_ex_db)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
@@ -134,30 +155,39 @@ def test_adherence_to_output_schema(self):
def test_fail_incorrect_output_schema(self):
wf_meta = base.get_wf_fixture_meta_data(
- TEST_PACK_PATH,
- 'sequential_with_broken_schema.yaml'
+ TEST_PACK_PATH, "sequential_with_broken_schema.yaml"
+ )
+ wf_input = {"who": "Thanos"}
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
)
- wf_input = {'who': 'Thanos'}
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
wf_ex_db = wf_ex_dbs[0]
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(tk2_ac_ex_db)
tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
wf_svc.handle_action_execution_completion(tk3_ac_ex_db)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
@@ -167,9 +197,9 @@ def test_fail_incorrect_output_schema(self):
self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED)
expected_result = {
- 'error': "Additional properties are not allowed",
- 'message': 'Error validating output. See error output for more details.'
+ "error": "Additional properties are not allowed",
+ "message": "Error validating output. See error output for more details.",
}
- self.assertIn(expected_result['error'], ac_ex_db.result['error'])
- self.assertEqual(expected_result['message'], ac_ex_db.result['message'])
+ self.assertIn(expected_result["error"], ac_ex_db.result["error"])
+ self.assertEqual(expected_result["message"], ac_ex_db.result["message"])
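Note: the single-to-double quote churn that dominates this patch is black's string normalization. It changes quote style only, never string contents, and it drops escapes that the new quote style makes unnecessary. For example (illustrative values):

    pack = 'orquesta_tests'     # before
    pack = "orquesta_tests"     # after: preferred double quotes

    msg = 'it\'s paused'        # before: escaped apostrophe
    msg = "it's paused"         # after: switching quotes removes the escape

Projects that want black's layout rules without this rewrite can pass --skip-string-normalization on the command line, or set skip-string-normalization = true under [tool.black] in pyproject.toml.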
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py b/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py
index c2021b379e..6ade390029 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py
@@ -24,6 +24,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -46,37 +47,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaRunnerPauseResumeTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaRunnerPauseResumeTest, cls).setUpClass()
@@ -86,8 +95,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -97,56 +105,68 @@ def setUpClass(cls):
def get_runner_class(cls, runner_name):
return runners.get_runner(runner_name, runner_name).__class__
- @mock.patch.object(
- ac_svc, 'is_children_active',
- mock.MagicMock(return_value=False))
+ @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=False))
def test_pause(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
- @mock.patch.object(
- ac_svc, 'is_children_active',
- mock.MagicMock(return_value=True))
+ @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=True))
def test_pause_with_active_children(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
def test_pause_subworkflow_not_cascade_up_to_workflow(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the subworkflow.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(len(wf_ex_dbs), 1)
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
- tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))
+ tk_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ac_ex_dbs), 1)
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id'])
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ tk_ac_ex_dbs[0].liveaction["id"]
+ )
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the subworkflow.
- tk_lv_ac_db, tk_ac_ex_db = ac_svc.request_pause(tk_lv_ac_db, cfg.CONF.system_user.user)
+ tk_lv_ac_db, tk_ac_ex_db = ac_svc.request_pause(
+ tk_lv_ac_db, cfg.CONF.system_user.user
+ )
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Assert the main workflow is still running.
@@ -154,38 +174,52 @@ def test_pause_subworkflow_not_cascade_up_to_workflow(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
def test_pause_workflow_cascade_down_to_subworkflow(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(len(wf_ex_dbs), 1)
wf_ex_db = wf_ex_dbs[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
tk_ex_db = tk_ex_dbs[0]
- tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))
+ tk_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_db.id)
+ )
self.assertEqual(len(tk_ac_ex_dbs), 1)
tk_ac_ex_db = tk_ac_ex_dbs[0]
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id'])
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"])
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Identify the records for the subworkflow.
- sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(tk_ac_ex_db.id))
+ sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(tk_ac_ex_db.id)
+ )
self.assertEqual(len(sub_wf_ex_dbs), 1)
sub_wf_ex_db = sub_wf_ex_dbs[0]
- sub_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(sub_wf_ex_db.id))
+ sub_tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(sub_wf_ex_db.id)
+ )
self.assertEqual(len(sub_tk_ex_dbs), 1)
sub_tk_ex_db = sub_tk_ex_dbs[0]
- sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(sub_tk_ex_db.id))
+ sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(sub_tk_ex_db.id)
+ )
self.assertEqual(len(sub_tk_ac_ex_dbs), 1)
# Pause the main workflow and assert it is pausing because the subworkflow is still running.
@@ -213,32 +247,48 @@ def test_pause_workflow_cascade_down_to_subworkflow(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED)
def test_pause_subworkflow_while_another_subworkflow_running(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 2)
# Identify the records for the subworkflows.
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
- t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0]
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[1].id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
+ t2_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t2_ac_ex_db.id)
+ )[0]
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING)
# Pause the subworkflow.
- t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user)
+ t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(
+ t1_lv_ac_db, cfg.CONF.system_user.user
+ )
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Assert the main workflow is still running.
@@ -246,12 +296,16 @@ def test_pause_subworkflow_while_another_subworkflow_running(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the task in the subworkflow.
- t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t1_ac_ex_db)
# Assert the subworkflow is paused and manually notify the paused of the
@@ -267,18 +321,30 @@ def test_pause_subworkflow_while_another_subworkflow_running(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the tasks in the other subworkflow.
- t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0]
- t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0]
+ t2_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[0]
+ t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t1_ac_ex_db)
- t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1]
- t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0]
+ t2_t2_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[1]
+ t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t2_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t2_ac_ex_db)
- t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2]
- t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0]
+ t2_t3_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[2]
+ t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t3_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t3_ac_ex_db)
t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id))
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
@@ -293,32 +359,48 @@ def test_pause_subworkflow_while_another_subworkflow_running(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED)
def test_pause_subworkflow_while_another_subworkflow_completed(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 2)
# Identify the records for the subworkflows.
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
- t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0]
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[1].id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
+ t2_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t2_ac_ex_db.id)
+ )[0]
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING)
# Pause the subworkflow.
- t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user)
+ t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(
+ t1_lv_ac_db, cfg.CONF.system_user.user
+ )
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Assert the main workflow is still running.
@@ -326,18 +408,30 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the tasks in the other subworkflow.
- t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0]
- t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0]
+ t2_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[0]
+ t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t1_ac_ex_db)
- t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1]
- t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0]
+ t2_t2_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[1]
+ t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t2_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t2_ac_ex_db)
- t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2]
- t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0]
+ t2_t3_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[2]
+ t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t3_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t3_ac_ex_db)
t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id))
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
@@ -352,12 +446,16 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the target subworkflow is still pausing.
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Manually notify action execution completion for the task in the subworkflow.
- t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t1_ac_ex_db)
# Assert the subworkflow is paused and manually notify the paused of the
@@ -372,15 +470,15 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self):
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED)
- @mock.patch.object(
- ac_svc, 'is_children_active',
- mock.MagicMock(return_value=False))
+ @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=False))
def test_resume(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Pause the workflow.
lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user)
@@ -388,63 +486,93 @@ def test_resume(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Identify the records for the running task(s) and manually complete them.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
- tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id'])
+ tk_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(
+ tk_ac_ex_dbs[0].liveaction["id"]
+ )
self.assertEqual(tk_ac_ex_dbs[0].status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(tk_ac_ex_dbs[0])
# Ensure the workflow is paused.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED, lv_ac_db.result)
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED, lv_ac_db.result
+ )
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(wf_ex_dbs[0].status, wf_statuses.PAUSED)
# Resume the workflow.
lv_ac_db, ac_ex_db = ac_svc.request_resume(lv_ac_db, cfg.CONF.system_user.user)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(wf_ex_dbs[0].status, wf_statuses.RUNNING)
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_dbs[0].id)
+ )
self.assertEqual(len(tk_ex_dbs), 2)
def test_resume_cascade_to_subworkflow(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(len(wf_ex_dbs), 1)
wf_ex_db = wf_ex_dbs[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
tk_ex_db = tk_ex_dbs[0]
- tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))
+ tk_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_db.id)
+ )
self.assertEqual(len(tk_ac_ex_dbs), 1)
tk_ac_ex_db = tk_ac_ex_dbs[0]
- tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id'])
+ tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"])
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Identify the records for the subworkflow.
- sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(tk_ac_ex_db.id))
+ sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(tk_ac_ex_db.id)
+ )
self.assertEqual(len(sub_wf_ex_dbs), 1)
sub_wf_ex_db = sub_wf_ex_dbs[0]
- sub_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(sub_wf_ex_db.id))
+ sub_tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(sub_wf_ex_db.id)
+ )
self.assertEqual(len(sub_tk_ex_dbs), 1)
sub_tk_ex_db = sub_tk_ex_dbs[0]
- sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(sub_tk_ex_db.id))
+ sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(sub_tk_ex_db.id)
+ )
self.assertEqual(len(sub_tk_ac_ex_dbs), 1)
# Pause the main workflow and assert it is pausing because the subworkflow is still running.
@@ -481,32 +609,48 @@ def test_resume_cascade_to_subworkflow(self):
self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
def test_resume_from_each_subworkflow_when_parent_is_paused(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 2)
# Identify the records for the subworkflows.
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
- t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0]
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[1].id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
+ t2_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t2_ac_ex_db.id)
+ )[0]
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING)
# Pause one of the subworkflows.
- t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user)
+ t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(
+ t1_lv_ac_db, cfg.CONF.system_user.user
+ )
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Assert the main workflow is still running.
@@ -514,12 +658,16 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the task in the subworkflow.
- t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t1_ac_ex_db)
# Assert the subworkflow is paused and manually notify the paused of the
@@ -535,11 +683,13 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the other subworkflow.
- t2_lv_ac_db, t2_ac_ex_db = ac_svc.request_pause(t2_lv_ac_db, cfg.CONF.system_user.user)
+ t2_lv_ac_db, t2_ac_ex_db = ac_svc.request_pause(
+ t2_lv_ac_db, cfg.CONF.system_user.user
+ )
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Assert the main workflow is still running.
@@ -547,8 +697,12 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the task in the subworkflow.
- t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0]
- t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0]
+ t2_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[0]
+ t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t1_ac_ex_db)
# Assert the subworkflow is paused and manually notify the paused of the
@@ -564,7 +718,9 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED)
# Resume the subworkflow and assert it is running.
- t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user)
+ t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(
+ t1_lv_ac_db, cfg.CONF.system_user.user
+ )
t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id))
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
@@ -573,11 +729,19 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the tasks in the subworkflow.
- t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1]
- t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0]
+ t1_t2_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[1]
+ t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t2_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t2_ac_ex_db)
- t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2]
- t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0]
+ t1_t3_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[2]
+ t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t3_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t3_ac_ex_db)
t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id))
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
@@ -592,32 +756,48 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED)
def test_resume_from_subworkflow_when_parent_is_paused(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 2)
# Identify the records for the subworkflows.
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
- t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0]
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[1].id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
+ t2_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t2_ac_ex_db.id)
+ )[0]
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING)
# Pause the subworkflow.
- t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user)
+ t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(
+ t1_lv_ac_db, cfg.CONF.system_user.user
+ )
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Assert the main workflow is still running.
@@ -625,12 +805,16 @@ def test_resume_from_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the task in the subworkflow.
- t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t1_ac_ex_db)
# Assert the subworkflow is paused and manually notify the paused of the
@@ -646,18 +830,30 @@ def test_resume_from_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the tasks in the other subworkflow.
- t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0]
- t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0]
+ t2_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[0]
+ t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t1_ac_ex_db)
- t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1]
- t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0]
+ t2_t2_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[1]
+ t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t2_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t2_ac_ex_db)
- t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2]
- t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0]
+ t2_t3_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[2]
+ t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t3_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t3_ac_ex_db)
t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id))
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
@@ -672,7 +868,9 @@ def test_resume_from_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED)
# Resume the subworkflow and assert it is running.
- t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user)
+ t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(
+ t1_lv_ac_db, cfg.CONF.system_user.user
+ )
t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id))
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
@@ -681,11 +879,19 @@ def test_resume_from_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the tasks in the subworkflow.
- t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1]
- t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0]
+ t1_t2_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[1]
+ t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t2_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t2_ac_ex_db)
- t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2]
- t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0]
+ t1_t3_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[2]
+ t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t3_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t3_ac_ex_db)
t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id))
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
@@ -696,12 +902,16 @@ def test_resume_from_subworkflow_when_parent_is_paused(self):
workflows.get_engine().process(t1_ac_ex_db)
# Assert task3 has started and completed.
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 3)
- t3_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ t3_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
t3_ex_db = wf_db_access.TaskExecution.query(**t3_ex_db_qry)[0]
- t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0]
- t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id'])
+ t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t3_ex_db.id)
+ )[0]
+ t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"])
self.assertEqual(t3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(t3_ac_ex_db)
@@ -710,32 +920,48 @@ def test_resume_from_subworkflow_when_parent_is_paused(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_resume_from_subworkflow_when_parent_is_running(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 2)
# Identify the records for the subworkflows.
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0]
- t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id'])
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )[0]
+ t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"])
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0]
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
- t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0]
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[1].id)
+ )[0]
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
+ t2_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t2_ac_ex_db.id)
+ )[0]
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING)
# Pause the subworkflow.
- t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user)
+ t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(
+ t1_lv_ac_db, cfg.CONF.system_user.user
+ )
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING)
# Assert the main workflow is still running.
@@ -743,12 +969,16 @@ def test_resume_from_subworkflow_when_parent_is_running(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the task in the subworkflow.
- t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0]
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t1_ac_ex_db)
# Assert the subworkflow is paused and manually notify the paused of the
@@ -764,11 +994,13 @@ def test_resume_from_subworkflow_when_parent_is_running(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Resume the subworkflow and assert it is running.
- t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user)
+ t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(
+ t1_lv_ac_db, cfg.CONF.system_user.user
+ )
t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id))
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
@@ -777,15 +1009,23 @@ def test_resume_from_subworkflow_when_parent_is_running(self):
self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Assert the other subworkflow is still running.
- t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id'])
+ t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"])
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
# Manually notify action execution completion for the tasks in the subworkflow.
- t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1]
- t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0]
+ t1_t2_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[1]
+ t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t2_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t2_ac_ex_db)
- t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2]
- t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0]
+ t1_t3_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[2]
+ t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t3_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_t3_ac_ex_db)
t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id))
self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
@@ -796,14 +1036,26 @@ def test_resume_from_subworkflow_when_parent_is_running(self):
workflows.get_engine().process(t1_ac_ex_db)
# Manually notify action execution completion for the tasks in the other subworkflow.
- t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0]
- t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0]
+ t2_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[0]
+ t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t1_ac_ex_db)
- t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1]
- t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0]
+ t2_t2_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[1]
+ t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t2_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t2_ac_ex_db)
- t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2]
- t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0]
+ t2_t3_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t2_wf_ex_db.id)
+ )[2]
+ t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_t3_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_t3_ac_ex_db)
t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id))
self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
@@ -814,12 +1066,16 @@ def test_resume_from_subworkflow_when_parent_is_running(self):
workflows.get_engine().process(t2_ac_ex_db)
# Assert task3 has started and completed.
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 3)
- t3_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ t3_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
t3_ex_db = wf_db_access.TaskExecution.query(**t3_ex_db_qry)[0]
- t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0]
- t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id'])
+ t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t3_ex_db.id)
+ )[0]
+ t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"])
self.assertEqual(t3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(t3_ac_ex_db)
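Every hunk in the file above applies the same black rule: a call that no longer fits the default 88-column limit is exploded into one argument per line with a trailing comma. A minimal runnable sketch of that rule follows; query and its keyword arguments are illustrative placeholders, not helpers from this patch.

def query(**kwargs):  # stand-in for the DB access helpers used in these tests
    return kwargs

# Too wide for black's 88-column default, so black rewrites it as below.
flat = query(workflow_execution="5ff9bb0dc0ffee", task_execution="c0ffee5ff9bb0d", task_id="task1")

# Black's output: hugging parentheses, one argument per line, magic trailing comma.
split = query(
    workflow_execution="5ff9bb0dc0ffee",
    task_execution="c0ffee5ff9bb0d",
    task_id="task1",
)

assert flat == split  # the reflow is purely cosmetic; behavior is unchanged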
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_policies.py b/contrib/runners/orquesta_runner/tests/unit/test_policies.py
index 2595609f63..81ab639262 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_policies.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_policies.py
@@ -23,6 +23,7 @@
# XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -46,37 +47,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaRunnerTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaRunnerTest, cls).setUpClass()
@@ -86,7 +95,7 @@ def setUpClass(cls):
policiesregistrar.register_policy_types(st2common)
# Register test pack(s).
- registrar_options = {'use_pack_cache': False, 'fail_on_failure': True}
+ registrar_options = {"use_pack_cache": False, "fail_on_failure": True}
actions_registrar = actionsregistrar.ActionsRegistrar(**registrar_options)
policies_registrar = policiesregistrar.PolicyRegistrar(**registrar_options)
@@ -106,27 +115,37 @@ def tearDown(self):
ac_ex_db.delete()
def test_retry_policy_applied_on_workflow_failure(self):
- wf_name = 'sequential'
- wf_ac_ref = TEST_PACK + '.' + wf_name
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_name = "sequential"
+ wf_ac_ref = TEST_PACK + "." + wf_name
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Ensure there is only one execution recorded.
self.assertEqual(len(lv_db_access.LiveAction.query(action=wf_ac_ref)), 1)
# Identify the records for the workflow and task.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )[0]
t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
# Manually set the status to failed.
ac_svc.update_status(t1_lv_ac_db, ac_const.LIVEACTION_STATUS_FAILED)
t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED)
notifier.get_notifier().process(t1_ac_ex_db)
workflows.get_engine().process(t1_ac_ex_db)
@@ -140,32 +159,48 @@ def test_retry_policy_applied_on_workflow_failure(self):
self.assertEqual(len(lv_db_access.LiveAction.query(action=wf_ac_ref)), 2)
def test_no_retry_policy_applied_on_task_failure(self):
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
- self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result)
+ self.assertEqual(
+ lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result
+ )
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
# Identify the records for the tasks.
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0]
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )[0]
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING)
self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING)
# Ensure there is only one execution for the task.
- tk_ac_ref = TEST_PACK + '.' + 'sequential'
+ tk_ac_ref = TEST_PACK + "." + "sequential"
self.assertEqual(len(lv_db_access.LiveAction.query(action=tk_ac_ref)), 1)
# Fail the subtask of the subworkflow.
- t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0]
- t1_t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ex_db = wf_db_access.TaskExecution.query(
+ workflow_execution=str(t1_wf_ex_db.id)
+ )[0]
+ t1_t1_lv_ac_db = lv_db_access.LiveAction.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
ac_svc.update_status(t1_t1_lv_ac_db, ac_const.LIVEACTION_STATUS_FAILED)
- t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0]
+ t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_t1_ex_db.id)
+ )[0]
self.assertEqual(t1_t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED)
notifier.get_notifier().process(t1_t1_ac_ex_db)
workflows.get_engine().process(t1_t1_ac_ex_db)
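The decorator stacks reformatted throughout these modules use mock.patch.object with an explicit replacement object, which is why the test methods gain no extra mock argument. A runnable sketch of that behavior under invented names (Dummy and publish are placeholders, not StackStorm transport classes):

from unittest import mock


class Dummy:
    @staticmethod
    def publish(payload):
        raise RuntimeError("would hit a real message bus")


@mock.patch.object(
    Dummy,
    "publish",
    mock.MagicMock(return_value=None),  # explicit new= object: no mock arg is injected
)
def run_test():
    # Inside the decorated call the attribute is replaced by the MagicMock.
    assert Dummy.publish("event") is None


run_test()
# After the call the original attribute is restored;
# Dummy.publish("event") here would raise RuntimeError again.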
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_rerun.py b/contrib/runners/orquesta_runner/tests/unit/test_rerun.py
index 59f2f94d08..191f3a0681 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_rerun.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_rerun.py
@@ -20,6 +20,7 @@
import st2tests
import st2tests.config as tests_config
+
tests_config.parse_args()
from local_runner import local_shell_command_runner
@@ -41,41 +42,57 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
RUNNER_RESULT_FAILED = (action_constants.LIVEACTION_STATUS_FAILED, {}, {})
-RUNNER_RESULT_RUNNING = (action_constants.LIVEACTION_STATUS_RUNNING, {'stdout': '...'}, {})
-RUNNER_RESULT_SUCCEEDED = (action_constants.LIVEACTION_STATUS_SUCCEEDED, {'stdout': 'foobar'}, {})
+RUNNER_RESULT_RUNNING = (
+ action_constants.LIVEACTION_STATUS_RUNNING,
+ {"stdout": "..."},
+ {},
+)
+RUNNER_RESULT_SUCCEEDED = (
+ action_constants.LIVEACTION_STATUS_SUCCEEDED,
+ {"stdout": "foobar"},
+ {},
+)
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestRunnerTest(st2tests.WorkflowTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestRunnerTest, cls).setUpClass()
@@ -85,28 +102,35 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]),
+ )
def test_rerun_workflow(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "Thanos"}
+ lv_ac_db1 = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db1.id)
+ )[0]
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED)
workflow_service.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
@@ -121,18 +145,15 @@ def test_rerun_workflow(self):
self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_FAILED)
# Rerun the execution.
- context = {
- 're-run': {
- 'ref': str(ac_ex_db1.id),
- 'tasks': ['task1']
- }
- }
-
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context)
+ context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}}
+
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context)
lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2)
# Assert the workflow reran ok and is running.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db2.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db2.id)
+ )[0]
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id))
self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_RUNNING)
@@ -140,33 +161,45 @@ def test_rerun_workflow(self):
self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process task1 and make sure it succeeds.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_dbs = wf_db_access.TaskExecution.query(**query_filters)
self.assertEqual(len(tk1_ex_dbs), 2)
tk1_ex_dbs = sorted(tk1_ex_dbs, key=lambda x: x.start_timestamp)
tk1_ex_db = tk1_ex_dbs[-1]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
- self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflow_service.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]),
+ )
def test_rerun_with_missing_workflow_execution_id(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "Thanos"}
+ lv_ac_db1 = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db1.id)
+ )[0]
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED)
workflow_service.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
@@ -184,49 +217,52 @@ def test_rerun_with_missing_workflow_execution_id(self):
wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False)
        # Manually delete the workflow_execution_id from the context of the action execution.
- lv_ac_db1.context.pop('workflow_execution')
+ lv_ac_db1.context.pop("workflow_execution")
lv_ac_db1 = lv_db_access.LiveAction.add_or_update(lv_ac_db1, publish=False)
ac_ex_db1 = execution_service.update_execution(lv_ac_db1, publish=False)
# Rerun the execution.
- context = {
- 're-run': {
- 'ref': str(ac_ex_db1.id),
- 'tasks': ['task1']
- }
- }
-
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context)
+ context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}}
+
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context)
lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2)
expected_error = (
- 'Unable to rerun workflow execution because '
- 'workflow_execution_id is not provided.'
+ "Unable to rerun workflow execution because "
+ "workflow_execution_id is not provided."
)
        # Assert the workflow rerun fails.
lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id))
self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message'])
+ self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"])
ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id))
self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message'])
+ self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"])
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]),
+ )
def test_rerun_with_invalid_workflow_execution(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "Thanos"}
+ lv_ac_db1 = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db1.id)
+ )[0]
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED)
workflow_service.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
@@ -244,45 +280,50 @@ def test_rerun_with_invalid_workflow_execution(self):
wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False)
# Rerun the execution.
- context = {
- 're-run': {
- 'ref': str(ac_ex_db1.id),
- 'tasks': ['task1']
- }
- }
-
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context)
+ context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}}
+
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context)
lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2)
expected_error = (
'Unable to rerun workflow execution "%s" because '
- 'it does not exist.' % str(wf_ex_db.id)
+ "it does not exist." % str(wf_ex_db.id)
)
        # Assert the workflow rerun fails.
lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id))
self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message'])
+ self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"])
ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id))
self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message'])
+ self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"])
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(side_effect=[RUNNER_RESULT_RUNNING]))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(side_effect=[RUNNER_RESULT_RUNNING]),
+ )
def test_rerun_workflow_still_running(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "Thanos"}
+ lv_ac_db1 = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db1.id)
+ )[0]
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
- self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING
+ )
# Assert workflow is still running.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
@@ -293,47 +334,52 @@ def test_rerun_workflow_still_running(self):
self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Rerun the execution.
- context = {
- 're-run': {
- 'ref': str(ac_ex_db1.id),
- 'tasks': ['task1']
- }
- }
-
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context)
+ context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}}
+
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context)
lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2)
expected_error = (
'Unable to rerun workflow execution "%s" because '
- 'it is not in a completed state.' % str(wf_ex_db.id)
+ "it is not in a completed state." % str(wf_ex_db.id)
)
        # Assert the workflow rerun fails.
lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id))
self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message'])
+ self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"])
ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id))
self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message'])
+ self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"])
@mock.patch.object(
- workflow_service, 'request_rerun',
- mock.MagicMock(side_effect=Exception('Unexpected.')))
+ workflow_service,
+ "request_rerun",
+ mock.MagicMock(side_effect=Exception("Unexpected.")),
+ )
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]),
+ )
def test_rerun_with_unexpected_error(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "Thanos"}
+ lv_ac_db1 = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db1.id)
+ )[0]
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED)
workflow_service.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
@@ -351,62 +397,75 @@ def test_rerun_with_unexpected_error(self):
wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False)
# Rerun the execution.
- context = {
- 're-run': {
- 'ref': str(ac_ex_db1.id),
- 'tasks': ['task1']
- }
- }
-
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context)
+ context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}}
+
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context)
lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2)
- expected_error = 'Unexpected.'
+ expected_error = "Unexpected."
        # Assert the workflow rerun fails.
lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id))
self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message'])
+ self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"])
ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id))
self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message'])
+ self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"])
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(return_value=RUNNER_RESULT_SUCCEEDED))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(return_value=RUNNER_RESULT_SUCCEEDED),
+ )
def test_rerun_workflow_already_succeeded(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- wf_input = {'who': 'Thanos'}
- lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ wf_input = {"who": "Thanos"}
+ lv_ac_db1 = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db1.id)
+ )[0]
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
- self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflow_service.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED)
# Process task2.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
- self.assertEqual(tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflow_service.handle_action_execution_completion(tk2_ac_ex_db)
tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id)
self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED)
# Process task3.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
- tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id'])
- self.assertEqual(tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
+ tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflow_service.handle_action_execution_completion(tk3_ac_ex_db)
tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id)
self.assertEqual(tk3_ex_db.status, wf_statuses.SUCCEEDED)
@@ -420,18 +479,15 @@ def test_rerun_workflow_already_succeeded(self):
self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
# Rerun the execution.
- context = {
- 're-run': {
- 'ref': str(ac_ex_db1.id),
- 'tasks': ['task1']
- }
- }
-
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context)
+ context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}}
+
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context)
lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2)
# Assert the workflow reran ok and is running.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db2.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db2.id)
+ )[0]
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id))
self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_RUNNING)
@@ -439,40 +495,52 @@ def test_rerun_workflow_already_succeeded(self):
self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING)
        # Assert there are two task1 executions and the last entry succeeded.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_dbs = wf_db_access.TaskExecution.query(**query_filters)
self.assertEqual(len(tk1_ex_dbs), 2)
tk1_ex_dbs = sorted(tk1_ex_dbs, key=lambda x: x.start_timestamp)
tk1_ex_db = tk1_ex_dbs[-1]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
- self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflow_service.handle_action_execution_completion(tk1_ac_ex_db)
tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id)
self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED)
        # Assert there are two task2 executions and the last entry succeeded.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
tk2_ex_dbs = wf_db_access.TaskExecution.query(**query_filters)
self.assertEqual(len(tk2_ex_dbs), 2)
tk2_ex_dbs = sorted(tk2_ex_dbs, key=lambda x: x.start_timestamp)
tk2_ex_db = tk2_ex_dbs[-1]
- tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0]
- tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id'])
- self.assertEqual(tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ tk2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk2_ex_db.id)
+ )[0]
+ tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflow_service.handle_action_execution_completion(tk2_ac_ex_db)
tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id)
self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED)
        # Assert there are two task3 executions and the last entry succeeded.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
tk3_ex_dbs = wf_db_access.TaskExecution.query(**query_filters)
self.assertEqual(len(tk3_ex_dbs), 2)
tk3_ex_dbs = sorted(tk3_ex_dbs, key=lambda x: x.start_timestamp)
tk3_ex_db = tk3_ex_dbs[-1]
- tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0]
- tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id'])
- self.assertEqual(tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ tk3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk3_ex_db.id)
+ )[0]
+ tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"])
+ self.assertEqual(
+ tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
workflow_service.handle_action_execution_completion(tk3_ac_ex_db)
tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id)
self.assertEqual(tk3_ex_db.status, wf_statuses.SUCCEEDED)
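
A note on the mocking idiom the rerun tests above lean on: a list passed as side_effect makes each successive call on the mock return the next element, which is how a single test scripts a failed first run followed by a successful rerun. A self-contained sketch (the return values are illustrative placeholders, not the runner result tuples used above):

    # Sketch of list-valued side_effect: successive calls yield successive
    # elements, so one mock covers both the original run and the rerun.
    from unittest import mock

    run = mock.MagicMock(side_effect=["failed", "succeeded"])
    assert run() == "failed"      # original execution
    assert run() == "succeeded"   # rerun
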
diff --git a/contrib/runners/orquesta_runner/tests/unit/test_with_items.py b/contrib/runners/orquesta_runner/tests/unit/test_with_items.py
index cc0846d733..6676874586 100644
--- a/contrib/runners/orquesta_runner/tests/unit/test_with_items.py
+++ b/contrib/runners/orquesta_runner/tests/unit/test_with_items.py
@@ -25,6 +25,7 @@
# XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from tests.unit import base
@@ -48,37 +49,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaWithItemsTest(st2tests.ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(OrquestaWithItemsTest, cls).setUpClass()
@@ -88,8 +97,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -101,35 +109,34 @@ def get_runner_class(cls, runner_name):
def set_execution_status(self, lv_ac_db_id, status):
lv_ac_db = action_utils.update_liveaction_status(
- status=status,
- liveaction_id=lv_ac_db_id,
- publish=False
+ status=status, liveaction_id=lv_ac_db_id, publish=False
)
- ac_ex_db = execution_service.update_execution(
- lv_ac_db,
- publish=False
- )
+ ac_ex_db = execution_service.update_execution(lv_ac_db, publish=False)
return lv_ac_db, ac_ex_db
def test_with_items(self):
num_items = 3
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process the with items task.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(len(t1_ac_ex_dbs), num_items)
@@ -155,20 +162,26 @@ def test_with_items(self):
def test_with_items_failure(self):
num_items = 10
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-failure.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "with-items-failure.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process the with items task.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(len(t1_ac_ex_dbs), num_items)
@@ -195,52 +208,68 @@ def test_with_items_failure(self):
def test_with_items_empty_list(self):
items = []
num_items = len(items)
- wf_input = {'members': items}
+ wf_input = {"members": items}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Wait for the liveaction to complete.
- lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ lv_ac_db = self._wait_on_status(
+ lv_ac_db, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
# Retrieve records from database.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
        # Ensure there are no action executions for the task and the task is already completed.
self.assertEqual(len(t1_ac_ex_dbs), num_items)
self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED)
- self.assertDictEqual(t1_ex_db.result, {'items': []})
+ self.assertDictEqual(t1_ex_db.result, {"items": []})
# Assert the main workflow is completed.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED)
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertDictEqual(lv_ac_db.result, {'output': {'items': []}})
+ self.assertDictEqual(lv_ac_db.result, {"output": {"items": []}})
def test_with_items_concurrency(self):
num_items = 3
concurrency = 2
- wf_input = {'concurrency': concurrency}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"concurrency": concurrency}
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "with-items-concurrency.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process the first set of action executions from with items concurrency.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(len(t1_ac_ex_dbs), concurrency)
@@ -261,7 +290,9 @@ def test_with_items_concurrency(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Process the second set of action executions from with items concurrency.
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(len(t1_ac_ex_dbs), num_items)
@@ -287,30 +318,37 @@ def test_with_items_concurrency(self):
def test_with_items_cancellation(self):
num_items = 3
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "with-items-concurrency.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert the workflow execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING)
self.assertEqual(len(t1_ac_ex_dbs), num_items)
# Reset the action executions to running status.
for ac_ex in t1_ac_ex_dbs:
self.set_execution_status(
- ac_ex.liveaction['id'],
- action_constants.LIVEACTION_STATUS_RUNNING
+ ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING
)
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
status = [
ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING
@@ -328,11 +366,12 @@ def test_with_items_cancellation(self):
# Manually succeed the action executions and process completion.
for ac_ex in t1_ac_ex_dbs:
self.set_execution_status(
- ac_ex.liveaction['id'],
- action_constants.LIVEACTION_STATUS_SUCCEEDED
+ ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED
)
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
status = [
ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED
@@ -353,31 +392,40 @@ def test_with_items_cancellation(self):
def test_with_items_concurrency_cancellation(self):
concurrency = 2
- wf_input = {'concurrency': concurrency}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"concurrency": concurrency}
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "with-items-concurrency.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert the workflow execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING)
self.assertEqual(len(t1_ac_ex_dbs), concurrency)
# Reset the action executions to running status.
for ac_ex in t1_ac_ex_dbs:
self.set_execution_status(
- ac_ex.liveaction['id'],
- action_constants.LIVEACTION_STATUS_RUNNING
+ ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING
)
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
status = [
ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING
@@ -395,11 +443,12 @@ def test_with_items_concurrency_cancellation(self):
# Manually succeed the action executions and process completion.
for ac_ex in t1_ac_ex_dbs:
self.set_execution_status(
- ac_ex.liveaction['id'],
- action_constants.LIVEACTION_STATUS_SUCCEEDED
+ ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED
)
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
status = [
ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED
@@ -420,30 +469,37 @@ def test_with_items_concurrency_cancellation(self):
def test_with_items_pause_and_resume(self):
num_items = 3
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "with-items-concurrency.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert the workflow execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING)
self.assertEqual(len(t1_ac_ex_dbs), num_items)
# Reset the action executions to running status.
for ac_ex in t1_ac_ex_dbs:
self.set_execution_status(
- ac_ex.liveaction['id'],
- action_constants.LIVEACTION_STATUS_RUNNING
+ ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING
)
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
status = [
ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING
@@ -461,11 +517,12 @@ def test_with_items_pause_and_resume(self):
# Manually succeed the action executions and process completion.
for ac_ex in t1_ac_ex_dbs:
self.set_execution_status(
- ac_ex.liveaction['id'],
- action_constants.LIVEACTION_STATUS_SUCCEEDED
+ ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED
)
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
status = [
ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED
@@ -498,31 +555,40 @@ def test_with_items_concurrency_pause_and_resume(self):
num_items = 3
concurrency = 2
- wf_input = {'concurrency': concurrency}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"concurrency": concurrency}
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "with-items-concurrency.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert the workflow execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING)
self.assertEqual(len(t1_ac_ex_dbs), concurrency)
# Reset the action executions to running status.
for ac_ex in t1_ac_ex_dbs:
self.set_execution_status(
- ac_ex.liveaction['id'],
- action_constants.LIVEACTION_STATUS_RUNNING
+ ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING
)
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
status = [
ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING
@@ -540,11 +606,12 @@ def test_with_items_concurrency_pause_and_resume(self):
# Manually succeed the action executions and process completion.
for ac_ex in t1_ac_ex_dbs:
self.set_execution_status(
- ac_ex.liveaction['id'],
- action_constants.LIVEACTION_STATUS_SUCCEEDED
+ ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED
)
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
status = [
ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED
@@ -572,7 +639,9 @@ def test_with_items_concurrency_pause_and_resume(self):
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
        # Check that a new set of action executions is scheduled.
- t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))
+ t1_ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )
self.assertEqual(len(t1_ac_ex_dbs), num_items)
# Manually process the last action execution.
@@ -585,20 +654,34 @@ def test_with_items_concurrency_pause_and_resume(self):
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
def test_subworkflow_with_items_empty_list(self):
- wf_input = {'members': []}
- wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-empty-parent.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input)
+ wf_input = {"members": []}
+ wf_meta = base.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "with-items-empty-parent.yaml"
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters=wf_input
+ )
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Identify the records for the main workflow.
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
- tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
+ tk_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(tk_ex_dbs), 1)
# Identify the records for the tasks.
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0]
- t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0]
- self.assertEqual(t1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk_ex_dbs[0].id)
+ )[0]
+ t1_wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(t1_ac_ex_db.id)
+ )[0]
+ self.assertEqual(
+ t1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
self.assertEqual(t1_wf_ex_db.status, wf_statuses.SUCCEEDED)
        # Manually process completion of the subworkflow in task1.
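
The with-items tests above repeatedly build one boolean per action execution and then assert over the whole list. A self-contained sketch of that idiom (FakeExecution is a hypothetical stand-in, not a class from this repo):

    # Sketch of the per-item status check idiom used in the tests above.
    from dataclasses import dataclass

    @dataclass
    class FakeExecution:  # hypothetical stand-in for an ActionExecution record
        status: str

    t1_ac_ex_dbs = [FakeExecution("running"), FakeExecution("running")]
    status = [ac_ex.status == "running" for ac_ex in t1_ac_ex_dbs]
    assert all(status)
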
diff --git a/contrib/runners/python_runner/dist_utils.py b/contrib/runners/python_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/python_runner/dist_utils.py
+++ b/contrib/runners/python_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
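
For reference, the pip check reformatted above relies on distutils' StrictVersion ordering; a quick standalone sketch (the version strings are illustrative):

    # Sketch of the StrictVersion comparison behind check_pip_version above.
    from distutils.version import StrictVersion

    assert StrictVersion("5.1.0") < StrictVersion("6.0.0")     # outdated pip
    assert not StrictVersion("21.0") < StrictVersion("6.0.0")  # new enough
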
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
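
The __version__ regex consolidated onto one line above can be sanity-checked in isolation; a hypothetical usage sketch (the content string is made up, not read from a real __init__.py):

    # Sketch: the regex from get_version_string above applied to sample text.
    import re

    content = '__version__ = "3.5dev"\n'  # illustrative file contents
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
    assert match is not None and match.group(1) == "3.5dev"
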
diff --git a/contrib/runners/python_runner/python_runner/__init__.py b/contrib/runners/python_runner/python_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/python_runner/python_runner/__init__.py
+++ b/contrib/runners/python_runner/python_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/python_runner/python_runner/python_action_wrapper.py b/contrib/runners/python_runner/python_runner/python_action_wrapper.py
index 119f6bdf84..b9ae0757b3 100644
--- a/contrib/runners/python_runner/python_runner/python_action_wrapper.py
+++ b/contrib/runners/python_runner/python_runner/python_action_wrapper.py
@@ -18,7 +18,8 @@
# Ignore CryptographyDeprecationWarning warnings which appear on older versions of Python 2.7
import warnings
from cryptography.utils import CryptographyDeprecationWarning
-warnings.filterwarnings('ignore', category=CryptographyDeprecationWarning)
+
+warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)
import os
import sys
@@ -33,8 +34,8 @@
# lives gets added to sys.path and we don't want that.
# Note: We need to use just the suffix, because the full path is different depending on
# whether the process is run in a virtualenv or not
-RUNNERS_PATH_SUFFIX = 'st2common/runners'
-if __name__ == '__main__':
+RUNNERS_PATH_SUFFIX = "st2common/runners"
+if __name__ == "__main__":
script_path = sys.path[0]
if RUNNERS_PATH_SUFFIX in script_path:
sys.path.pop(0)
@@ -61,10 +62,7 @@
from st2common.constants.runners import PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE
from st2common.constants.runners import PYTHON_RUNNER_DEFAULT_LOG_LEVEL
-__all__ = [
- 'PythonActionWrapper',
- 'ActionService'
-]
+__all__ = ["PythonActionWrapper", "ActionService"]
LOG = logging.getLogger(__name__)
@@ -104,15 +102,18 @@ def datastore_service(self):
# duration of the action lifetime
action_name = self._action_wrapper._class_name
log_level = self._action_wrapper._log_level
- logger = get_logger_for_python_runner_action(action_name=action_name,
- log_level=log_level)
+ logger = get_logger_for_python_runner_action(
+ action_name=action_name, log_level=log_level
+ )
pack_name = self._action_wrapper._pack
class_name = self._action_wrapper._class_name
- auth_token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None)
- self._datastore_service = ActionDatastoreService(logger=logger,
- pack_name=pack_name,
- class_name=class_name,
- auth_token=auth_token)
+ auth_token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None)
+ self._datastore_service = ActionDatastoreService(
+ logger=logger,
+ pack_name=pack_name,
+ class_name=class_name,
+ auth_token=auth_token,
+ )
return self._datastore_service
##################################
@@ -130,20 +131,32 @@ def list_values(self, local=True, prefix=None):
return self.datastore_service.list_values(local=local, prefix=prefix)
def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False):
- return self.datastore_service.get_value(name=name, local=local, scope=scope,
- decrypt=decrypt)
+ return self.datastore_service.get_value(
+ name=name, local=local, scope=scope, decrypt=decrypt
+ )
- def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False):
- return self.datastore_service.set_value(name=name, value=value, ttl=ttl, local=local,
- scope=scope, encrypt=encrypt)
+ def set_value(
+ self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False
+ ):
+ return self.datastore_service.set_value(
+ name=name, value=value, ttl=ttl, local=local, scope=scope, encrypt=encrypt
+ )
def delete_value(self, name, local=True, scope=SYSTEM_SCOPE):
return self.datastore_service.delete_value(name=name, local=local, scope=scope)
class PythonActionWrapper(object):
- def __init__(self, pack, file_path, config=None, parameters=None, user=None, parent_args=None,
- log_level=PYTHON_RUNNER_DEFAULT_LOG_LEVEL):
+ def __init__(
+ self,
+ pack,
+ file_path,
+ config=None,
+ parameters=None,
+ user=None,
+ parent_args=None,
+ log_level=PYTHON_RUNNER_DEFAULT_LOG_LEVEL,
+ ):
"""
:param pack: Name of the pack this action belongs to.
:type pack: ``str``
@@ -173,19 +186,22 @@ def __init__(self, pack, file_path, config=None, parameters=None, user=None, par
self._log_level = log_level
self._class_name = None
- self._logger = logging.getLogger('PythonActionWrapper')
+ self._logger = logging.getLogger("PythonActionWrapper")
try:
st2common_config.parse_args(args=self._parent_args)
except Exception as e:
- LOG.debug('Failed to parse config using parent args (parent_args=%s): %s' %
- (str(self._parent_args), six.text_type(e)))
+ LOG.debug(
+ "Failed to parse config using parent args (parent_args=%s): %s"
+ % (str(self._parent_args), six.text_type(e))
+ )
# Note: We can only set a default user value if one is not provided after parsing the
# config
if not self._user:
# Note: We use late import to avoid performance overhead
from oslo_config import cfg
+
self._user = cfg.CONF.system_user.user
def run(self):
@@ -201,26 +217,25 @@ def run(self):
action_status = None
action_result = output
- action_output = {
- 'result': action_result,
- 'status': None
- }
+ action_output = {"result": action_result, "status": None}
if action_status is not None and not isinstance(action_status, bool):
- sys.stderr.write('Status returned from the action run() method must either be '
- 'True or False, got: %s\n' % (action_status))
+ sys.stderr.write(
+ "Status returned from the action run() method must either be "
+ "True or False, got: %s\n" % (action_status)
+ )
sys.stderr.write(INVALID_STATUS_ERROR_MESSAGE)
sys.exit(PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE)
if action_status is not None and isinstance(action_status, bool):
- action_output['status'] = action_status
+ action_output["status"] = action_status
# Special case if result object is not JSON serializable - aka user wanted to return a
# non-simple type (e.g. class instance or other non-JSON serializable type)
try:
- json.dumps(action_output['result'])
+ json.dumps(action_output["result"])
except TypeError:
- action_output['result'] = str(action_output['result'])
+ action_output["result"] = str(action_output["result"])
try:
print_output = json.dumps(action_output)
@@ -229,7 +244,7 @@ def run(self):
# Print output to stdout so the parent can capture it
sys.stdout.write(ACTION_OUTPUT_RESULT_DELIMITER)
- sys.stdout.write(print_output + '\n')
+ sys.stdout.write(print_output + "\n")
sys.stdout.write(ACTION_OUTPUT_RESULT_DELIMITER)
sys.stdout.flush()
@@ -238,17 +253,22 @@ def _get_action_instance(self):
actions_cls = action_loader.register_plugin(Action, self._file_path)
except Exception as e:
tb_msg = traceback.format_exc()
- msg = ('Failed to load action class from file "%s" (action file most likely doesn\'t '
- 'exist or contains invalid syntax): %s' % (self._file_path, six.text_type(e)))
- msg += '\n\n' + tb_msg
+ msg = (
+ 'Failed to load action class from file "%s" (action file most likely doesn\'t '
+ "exist or contains invalid syntax): %s"
+ % (self._file_path, six.text_type(e))
+ )
+ msg += "\n\n" + tb_msg
exc_cls = type(e)
raise exc_cls(msg)
action_cls = actions_cls[0] if actions_cls and len(actions_cls) > 0 else None
if not action_cls:
- raise Exception('File "%s" has no action class or the file doesn\'t exist.' %
- (self._file_path))
+ raise Exception(
+ 'File "%s" has no action class or the file doesn\'t exist.'
+ % (self._file_path)
+ )
# Retrieve name of the action class
# Note - we need to either use cls.__name_ or inspect.getmro(cls)[0].__name__ to
@@ -256,31 +276,45 @@ def _get_action_instance(self):
self._class_name = action_cls.__name__
action_service = ActionService(action_wrapper=self)
- action_instance = get_action_class_instance(action_cls=action_cls,
- config=self._config,
- action_service=action_service)
+ action_instance = get_action_class_instance(
+ action_cls=action_cls, config=self._config, action_service=action_service
+ )
return action_instance
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='Python action runner process wrapper')
- parser.add_argument('--pack', required=True,
- help='Name of the pack this action belongs to')
- parser.add_argument('--file-path', required=True,
- help='Path to the action module')
- parser.add_argument('--config', required=False,
- help='Pack config serialized as JSON')
- parser.add_argument('--parameters', required=False,
- help='Serialized action parameters')
- parser.add_argument('--stdin-parameters', required=False, action='store_true',
- help='Serialized action parameters via stdin')
- parser.add_argument('--user', required=False,
- help='User who triggered the action execution')
- parser.add_argument('--parent-args', required=False,
- help='Command line arguments passed to the parent process serialized as '
- ' JSON')
- parser.add_argument('--log-level', required=False, default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL,
- help='Log level for actions')
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Python action runner process wrapper")
+ parser.add_argument(
+ "--pack", required=True, help="Name of the pack this action belongs to"
+ )
+ parser.add_argument("--file-path", required=True, help="Path to the action module")
+ parser.add_argument(
+ "--config", required=False, help="Pack config serialized as JSON"
+ )
+ parser.add_argument(
+ "--parameters", required=False, help="Serialized action parameters"
+ )
+ parser.add_argument(
+ "--stdin-parameters",
+ required=False,
+ action="store_true",
+ help="Serialized action parameters via stdin",
+ )
+ parser.add_argument(
+ "--user", required=False, help="User who triggered the action execution"
+ )
+ parser.add_argument(
+ "--parent-args",
+ required=False,
+ help="Command line arguments passed to the parent process serialized as "
+ " JSON",
+ )
+ parser.add_argument(
+ "--log-level",
+ required=False,
+ default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL,
+ help="Log level for actions",
+ )
args = parser.parse_args()
config = json.loads(args.config) if args.config else {}
@@ -289,46 +323,54 @@ def _get_action_instance(self):
log_level = args.log_level
if not isinstance(config, dict):
- raise ValueError('Pack config needs to be a dictionary')
+ raise ValueError("Pack config needs to be a dictionary")
parameters = {}
if args.parameters:
- LOG.debug('Getting parameters from argument')
+ LOG.debug("Getting parameters from argument")
args_parameters = args.parameters
args_parameters = json.loads(args_parameters) if args_parameters else {}
parameters.update(args_parameters)
if args.stdin_parameters:
- LOG.debug('Getting parameters from stdin')
+ LOG.debug("Getting parameters from stdin")
i, _, _ = select.select([sys.stdin], [], [], READ_STDIN_INPUT_TIMEOUT)
if not i:
- raise ValueError(('No input received and timed out while waiting for '
- 'parameters from stdin'))
+ raise ValueError(
+ (
+ "No input received and timed out while waiting for "
+ "parameters from stdin"
+ )
+ )
stdin_data = sys.stdin.readline().strip()
try:
stdin_parameters = json.loads(stdin_data)
- stdin_parameters = stdin_parameters.get('parameters', {})
+ stdin_parameters = stdin_parameters.get("parameters", {})
except Exception as e:
- msg = ('Failed to parse parameters from stdin. Expected a JSON object with '
- '"parameters" attribute: %s' % (six.text_type(e)))
+ msg = (
+ "Failed to parse parameters from stdin. Expected a JSON object with "
+ '"parameters" attribute: %s' % (six.text_type(e))
+ )
raise ValueError(msg)
parameters.update(stdin_parameters)
- LOG.debug('Received parameters: %s', parameters)
+ LOG.debug("Received parameters: %s", parameters)
assert isinstance(parent_args, list)
- obj = PythonActionWrapper(pack=args.pack,
- file_path=args.file_path,
- config=config,
- parameters=parameters,
- user=user,
- parent_args=parent_args,
- log_level=log_level)
+ obj = PythonActionWrapper(
+ pack=args.pack,
+ file_path=args.file_path,
+ config=config,
+ parameters=parameters,
+ user=user,
+ parent_args=parent_args,
+ log_level=log_level,
+ )
obj.run()
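The stdin branch above implements a small handshake between the runner and this wrapper: the parent writes a single JSON line of the form {"parameters": {...}} and the wrapper polls stdin with select() so a silent parent fails fast instead of hanging forever. A minimal standalone sketch of that protocol, assuming an illustrative 2-second timeout in place of READ_STDIN_INPUT_TIMEOUT:

import json
import select
import sys

READ_TIMEOUT = 2  # stand-in for st2's READ_STDIN_INPUT_TIMEOUT constant


def read_stdin_parameters(timeout=READ_TIMEOUT):
    # Wait for stdin to become readable instead of blocking indefinitely
    ready, _, _ = select.select([sys.stdin], [], [], timeout)
    if not ready:
        raise ValueError(
            "No input received and timed out while waiting for parameters from stdin"
        )
    data = sys.stdin.readline().strip()
    try:
        # The payload is a JSON object carrying a "parameters" attribute
        return json.loads(data).get("parameters", {})
    except Exception as e:
        raise ValueError(
            "Failed to parse parameters from stdin. Expected a JSON object "
            'with "parameters" attribute: %s' % (e)
        )

Fed e.g. echo '{"parameters": {"row_index": 5}}' into such a script, it returns {'row_index': 5}; an empty stdin trips the timeout error instead.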
diff --git a/contrib/runners/python_runner/python_runner/python_runner.py b/contrib/runners/python_runner/python_runner/python_runner.py
index fd412c890e..b11668e000 100644
--- a/contrib/runners/python_runner/python_runner/python_runner.py
+++ b/contrib/runners/python_runner/python_runner/python_runner.py
@@ -58,34 +58,39 @@
from python_runner import python_action_wrapper
__all__ = [
- 'PythonRunner',
-
- 'get_runner',
- 'get_metadata',
+ "PythonRunner",
+ "get_runner",
+ "get_metadata",
]
LOG = logging.getLogger(__name__)
# constants to lookup in runner_parameters.
-RUNNER_ENV = 'env'
-RUNNER_TIMEOUT = 'timeout'
-RUNNER_LOG_LEVEL = 'log_level'
+RUNNER_ENV = "env"
+RUNNER_TIMEOUT = "timeout"
+RUNNER_LOG_LEVEL = "log_level"
# Environment variables which can't be specified by the user
BLACKLISTED_ENV_VARS = [
# We don't allow user to override PYTHONPATH since this would break things
- 'pythonpath'
+ "pythonpath"
]
BASE_DIR = os.path.dirname(os.path.abspath(python_action_wrapper.__file__))
-WRAPPER_SCRIPT_NAME = 'python_action_wrapper.py'
+WRAPPER_SCRIPT_NAME = "python_action_wrapper.py"
WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, WRAPPER_SCRIPT_NAME)
class PythonRunner(GitWorktreeActionRunner):
-
- def __init__(self, runner_id, config=None, timeout=PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT,
- log_level=None, sandbox=True, use_parent_args=True):
+ def __init__(
+ self,
+ runner_id,
+ config=None,
+ timeout=PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT,
+ log_level=None,
+ sandbox=True,
+ use_parent_args=True,
+ ):
"""
:param timeout: Action execution timeout in seconds.
@@ -123,36 +128,42 @@ def pre_run(self):
self._log_level = cfg.CONF.actionrunner.python_runner_log_level
def run(self, action_parameters):
- LOG.debug('Running pythonrunner.')
- LOG.debug('Getting pack name.')
+ LOG.debug("Running pythonrunner.")
+ LOG.debug("Getting pack name.")
pack = self.get_pack_ref()
- LOG.debug('Getting user.')
+ LOG.debug("Getting user.")
user = self.get_user()
- LOG.debug('Serializing parameters.')
- serialized_parameters = json.dumps(action_parameters if action_parameters else {})
- LOG.debug('Getting virtualenv_path.')
+ LOG.debug("Serializing parameters.")
+ serialized_parameters = json.dumps(
+ action_parameters if action_parameters else {}
+ )
+ LOG.debug("Getting virtualenv_path.")
virtualenv_path = get_sandbox_virtualenv_path(pack=pack)
- LOG.debug('Getting python path.')
+ LOG.debug("Getting python path.")
if self._sandbox:
python_path = get_sandbox_python_binary_path(pack=pack)
else:
python_path = sys.executable
- LOG.debug('Checking virtualenv path.')
+ LOG.debug("Checking virtualenv path.")
if virtualenv_path and not os.path.isdir(virtualenv_path):
- format_values = {'pack': pack, 'virtualenv_path': virtualenv_path}
+ format_values = {"pack": pack, "virtualenv_path": virtualenv_path}
msg = PACK_VIRTUALENV_DOESNT_EXIST % format_values
- LOG.error('virtualenv_path set but not a directory: %s', msg)
+ LOG.error("virtualenv_path set but not a directory: %s", msg)
raise Exception(msg)
- LOG.debug('Checking entry_point.')
+ LOG.debug("Checking entry_point.")
if not self.entry_point:
- LOG.error('Action "%s" is missing entry_point attribute' % (self.action.name))
- raise Exception('Action "%s" is missing entry_point attribute' % (self.action.name))
+ LOG.error(
+ 'Action "%s" is missing entry_point attribute' % (self.action.name)
+ )
+ raise Exception(
+ 'Action "%s" is missing entry_point attribute' % (self.action.name)
+ )
# Note: We pass config as command line args so the actual wrapper process is standalone
# and doesn't need access to db
- LOG.debug('Setting args.')
+ LOG.debug("Setting args.")
if self._use_parent_args:
parent_args = json.dumps(sys.argv[1:])
@@ -161,12 +172,12 @@ def run(self, action_parameters):
args = [
python_path,
- '-u', # unbuffered mode so streaming mode works as expected
+ "-u", # unbuffered mode so streaming mode works as expected
WRAPPER_SCRIPT_PATH,
- '--pack=%s' % (pack),
- '--file-path=%s' % (self.entry_point),
- '--user=%s' % (user),
- '--parent-args=%s' % (parent_args),
+ "--pack=%s" % (pack),
+ "--file-path=%s" % (self.entry_point),
+ "--user=%s" % (user),
+ "--parent-args=%s" % (parent_args),
]
subprocess = concurrency.get_subprocess_module()
@@ -178,35 +189,36 @@ def run(self, action_parameters):
stdin_params = None
if len(serialized_parameters) >= MAX_PARAM_LENGTH:
stdin = subprocess.PIPE
- LOG.debug('Parameters are too big...changing to stdin')
+ LOG.debug("Parameters are too big...changing to stdin")
stdin_params = '{"parameters": %s}\n' % (serialized_parameters)
- args.append('--stdin-parameters')
+ args.append("--stdin-parameters")
else:
- LOG.debug('Parameters are just right...adding them to arguments')
- args.append('--parameters=%s' % (serialized_parameters))
+ LOG.debug("Parameters are just right...adding them to arguments")
+ args.append("--parameters=%s" % (serialized_parameters))
if self._config:
- args.append('--config=%s' % (json.dumps(self._config)))
+ args.append("--config=%s" % (json.dumps(self._config)))
if self._log_level != PYTHON_RUNNER_DEFAULT_LOG_LEVEL:
        # We only pass the --log-level parameter if a non-default log level value is specified
- args.append('--log-level=%s' % (self._log_level))
+ args.append("--log-level=%s" % (self._log_level))
# We need to ensure all the st2 dependencies are also available to the subprocess
- LOG.debug('Setting env.')
+ LOG.debug("Setting env.")
env = os.environ.copy()
- env['PATH'] = get_sandbox_path(virtualenv_path=virtualenv_path)
+ env["PATH"] = get_sandbox_path(virtualenv_path=virtualenv_path)
sandbox_python_path = get_sandbox_python_path_for_python_action(
- pack=pack,
- inherit_from_parent=True,
- inherit_parent_virtualenv=True)
+ pack=pack, inherit_from_parent=True, inherit_parent_virtualenv=True
+ )
if self._enable_common_pack_libs:
try:
pack_common_libs_path = self._get_pack_common_libs_path(pack_ref=pack)
except Exception as e:
- LOG.debug('Failed to retrieve pack common lib path: %s' % (six.text_type(e)))
+ LOG.debug(
+ "Failed to retrieve pack common lib path: %s" % (six.text_type(e))
+ )
# There is no MongoDB connection available in Lambda and pack common lib
                # functionality is also not mandatory for Lambda so we simply ignore those errors.
# Note: We should eventually refactor this code to make runner standalone and not
@@ -217,13 +229,13 @@ def run(self, action_parameters):
pack_common_libs_path = None
# Remove leading : (if any)
- if sandbox_python_path.startswith(':'):
+ if sandbox_python_path.startswith(":"):
sandbox_python_path = sandbox_python_path[1:]
if self._enable_common_pack_libs and pack_common_libs_path:
- sandbox_python_path = pack_common_libs_path + ':' + sandbox_python_path
+ sandbox_python_path = pack_common_libs_path + ":" + sandbox_python_path
- env['PYTHONPATH'] = sandbox_python_path
+ env["PYTHONPATH"] = sandbox_python_path
# Include user provided environment variables (if any)
user_env_vars = self._get_env_vars()
@@ -238,40 +250,53 @@ def run(self, action_parameters):
stdout = StringIO()
stderr = StringIO()
- store_execution_stdout_line = functools.partial(store_execution_output_data,
- output_type='stdout')
- store_execution_stderr_line = functools.partial(store_execution_output_data,
- output_type='stderr')
-
- read_and_store_stdout = make_read_and_store_stream_func(execution_db=self.execution,
- action_db=self.action, store_data_func=store_execution_stdout_line)
- read_and_store_stderr = make_read_and_store_stream_func(execution_db=self.execution,
- action_db=self.action, store_data_func=store_execution_stderr_line)
+ store_execution_stdout_line = functools.partial(
+ store_execution_output_data, output_type="stdout"
+ )
+ store_execution_stderr_line = functools.partial(
+ store_execution_output_data, output_type="stderr"
+ )
+
+ read_and_store_stdout = make_read_and_store_stream_func(
+ execution_db=self.execution,
+ action_db=self.action,
+ store_data_func=store_execution_stdout_line,
+ )
+ read_and_store_stderr = make_read_and_store_stream_func(
+ execution_db=self.execution,
+ action_db=self.action,
+ store_data_func=store_execution_stderr_line,
+ )
command_string = list2cmdline(args)
if stdin_params:
- command_string = 'echo %s | %s' % (quote_unix(stdin_params), command_string)
+ command_string = "echo %s | %s" % (quote_unix(stdin_params), command_string)
bufsize = cfg.CONF.actionrunner.stream_output_buffer_size
- LOG.debug('Running command (bufsize=%s): PATH=%s PYTHONPATH=%s %s' % (bufsize, env['PATH'],
- env['PYTHONPATH'],
- command_string))
- exit_code, stdout, stderr, timed_out = run_command(cmd=args,
- stdin=stdin,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- shell=False,
- env=env,
- timeout=self._timeout,
- read_stdout_func=read_and_store_stdout,
- read_stderr_func=read_and_store_stderr,
- read_stdout_buffer=stdout,
- read_stderr_buffer=stderr,
- stdin_value=stdin_params,
- bufsize=bufsize)
- LOG.debug('Returning values: %s, %s, %s, %s', exit_code, stdout, stderr, timed_out)
- LOG.debug('Returning.')
+ LOG.debug(
+ "Running command (bufsize=%s): PATH=%s PYTHONPATH=%s %s"
+ % (bufsize, env["PATH"], env["PYTHONPATH"], command_string)
+ )
+ exit_code, stdout, stderr, timed_out = run_command(
+ cmd=args,
+ stdin=stdin,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=False,
+ env=env,
+ timeout=self._timeout,
+ read_stdout_func=read_and_store_stdout,
+ read_stderr_func=read_and_store_stderr,
+ read_stdout_buffer=stdout,
+ read_stderr_buffer=stderr,
+ stdin_value=stdin_params,
+ bufsize=bufsize,
+ )
+ LOG.debug(
+ "Returning values: %s, %s, %s, %s", exit_code, stdout, stderr, timed_out
+ )
+ LOG.debug("Returning.")
return self._get_output_values(exit_code, stdout, stderr, timed_out)
def _get_pack_common_libs_path(self, pack_ref):
@@ -280,7 +305,9 @@ def _get_pack_common_libs_path(self, pack_ref):
(if used).
"""
worktree_path = self.git_worktree_path
- pack_common_libs_path = get_pack_common_libs_path_for_pack_ref(pack_ref=pack_ref)
+ pack_common_libs_path = get_pack_common_libs_path_for_pack_ref(
+ pack_ref=pack_ref
+ )
if not worktree_path:
return pack_common_libs_path
@@ -288,18 +315,20 @@ def _get_pack_common_libs_path(self, pack_ref):
# Modify the path so it uses git worktree directory
pack_base_path = get_pack_base_path(pack_name=pack_ref)
- new_pack_common_libs_path = pack_common_libs_path.replace(pack_base_path, '')
+ new_pack_common_libs_path = pack_common_libs_path.replace(pack_base_path, "")
# Remove leading slash (if any)
- if new_pack_common_libs_path.startswith('/'):
+ if new_pack_common_libs_path.startswith("/"):
new_pack_common_libs_path = new_pack_common_libs_path[1:]
- new_pack_common_libs_path = os.path.join(worktree_path, new_pack_common_libs_path)
+ new_pack_common_libs_path = os.path.join(
+ worktree_path, new_pack_common_libs_path
+ )
# Check to prevent directory traversal
common_prefix = os.path.commonprefix([worktree_path, new_pack_common_libs_path])
if common_prefix != worktree_path:
- raise ValueError('pack libs path is not located inside the pack directory')
+ raise ValueError("pack libs path is not located inside the pack directory")
return new_pack_common_libs_path
@@ -312,7 +341,7 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out):
:rtype: ``tuple``
"""
if timed_out:
- error = 'Action failed to complete in %s seconds' % (self._timeout)
+ error = "Action failed to complete in %s seconds" % (self._timeout)
else:
error = None
@@ -335,16 +364,18 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out):
action_result = json.loads(action_result)
except Exception as e:
            # Failed to de-serialize the result; it probably contains a non-simple type or similar
- LOG.warning('Failed to de-serialize result "%s": %s' % (str(action_result),
- six.text_type(e)))
+ LOG.warning(
+ 'Failed to de-serialize result "%s": %s'
+ % (str(action_result), six.text_type(e))
+ )
if action_result:
if isinstance(action_result, dict):
- result = action_result.get('result', None)
- status = action_result.get('status', None)
+ result = action_result.get("result", None)
+ status = action_result.get("status", None)
else:
# Failed to de-serialize action result aka result is a string
- match = re.search("'result': (.*?)$", action_result or '')
+ match = re.search("'result': (.*?)$", action_result or "")
if match:
action_result = match.groups()[0]
@@ -352,21 +383,22 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out):
result = action_result
status = None
else:
- result = 'None'
+ result = "None"
status = None
output = {
- 'stdout': stdout,
- 'stderr': stderr,
- 'exit_code': exit_code,
- 'result': result
+ "stdout": stdout,
+ "stderr": stderr,
+ "exit_code": exit_code,
+ "result": result,
}
if error:
- output['error'] = error
+ output["error"] = error
- status = self._get_final_status(action_status=status, timed_out=timed_out,
- exit_code=exit_code)
+ status = self._get_final_status(
+ action_status=status, timed_out=timed_out, exit_code=exit_code
+ )
return (status, output, None)
def _get_final_status(self, action_status, timed_out, exit_code):
@@ -415,8 +447,10 @@ def _get_env_vars(self):
to_delete.append(key)
for key in to_delete:
- LOG.debug('User specified environment variable "%s" which is being ignored...' %
- (key))
+ LOG.debug(
+ 'User specified environment variable "%s" which is being ignored...'
+ % (key)
+ )
del env_vars[key]
return env_vars
@@ -441,4 +475,4 @@ def get_runner(config=None):
def get_metadata():
- return get_runner_metadata('python_runner')[0]
+ return get_runner_metadata("python_runner")[0]
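The run() hunk above switches between handing parameters to the wrapper as an argv entry and streaming them over stdin once the serialized payload reaches MAX_PARAM_LENGTH, since command lines have a practical size ceiling. A hedged sketch of just that decision, with the threshold assumed rather than imported from st2common.constants.action:

import json

MAX_PARAM_LENGTH = 100000  # assumption; the real constant lives in st2common


def build_wrapper_invocation(base_args, parameters):
    """Return (argv, stdin_payload) for launching the action wrapper."""
    serialized = json.dumps(parameters or {})
    args = list(base_args)
    stdin_payload = None
    if len(serialized) >= MAX_PARAM_LENGTH:
        # Too large for a single argv entry - tell the wrapper to read stdin
        args.append("--stdin-parameters")
        stdin_payload = '{"parameters": %s}\n' % (serialized)
    else:
        args.append("--parameters=%s" % (serialized))
    return args, stdin_payload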
diff --git a/contrib/runners/python_runner/setup.py b/contrib/runners/python_runner/setup.py
index c1a5d6c20a..04e55a31c0 100644
--- a/contrib/runners/python_runner/setup.py
+++ b/contrib/runners/python_runner/setup.py
@@ -26,30 +26,30 @@
from python_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-python',
+ name="stackstorm-runner-python",
version=__version__,
- description='Python action runner for StackStorm event-driven automation platform',
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="Python action runner for StackStorm event-driven automation platform",
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'python_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"python_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'python-script = python_runner.python_runner',
+ "st2common.runners.runner": [
+ "python-script = python_runner.python_runner",
],
- }
+ },
)
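The entry_points stanza in this setup.py is what makes the runner discoverable: installing the distribution registers "python-script" under the "st2common.runners.runner" group. A sketch of reading that registry with the stdlib importlib.metadata; how st2 itself resolves runners may differ:

from importlib.metadata import entry_points


def find_runner_modules(group="st2common.runners.runner"):
    eps = entry_points()
    # Python 3.10+ exposes .select(); 3.8/3.9 return a dict keyed by group
    selected = eps.select(group=group) if hasattr(eps, "select") else eps.get(group, [])
    return {ep.name: ep.value for ep in selected}


# e.g. {'python-script': 'python_runner.python_runner'} once this dist is installed
print(find_runner_modules())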
diff --git a/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py b/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py
index 27e42ecc5a..e1d39361a2 100644
--- a/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py
+++ b/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py
@@ -42,49 +42,53 @@
from st2common.util.shell import run_command
from six.moves import range
-__all__ = [
- 'PythonRunnerActionWrapperProcessTestCase'
-]
+__all__ = ["PythonRunnerActionWrapperProcessTestCase"]
# Maximum limit for the process wrapper script execution time (in seconds)
WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT = 0.31
-ASSERTION_ERROR_MESSAGE = ("""
+ASSERTION_ERROR_MESSAGE = """
Python wrapper process script took more than %s seconds to execute (%s). This most likely means
that a direct or indirect import of a module which takes a long time to load has been added (e.g.
jsonschema, pecan, kombu, etc).
Please review recently changed and added code for potential slow import issues and refactor /
re-organize code if possible.
-""".strip())
+""".strip()
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR,
- '../../../python_runner/python_runner/python_action_wrapper.py')
+WRAPPER_SCRIPT_PATH = os.path.join(
+ BASE_DIR, "../../../python_runner/python_runner/python_action_wrapper.py"
+)
WRAPPER_SCRIPT_PATH = os.path.abspath(WRAPPER_SCRIPT_PATH)
-TIME_BINARY_PATH = find_executable('time')
+TIME_BINARY_PATH = find_executable("time")
TIME_BINARY_AVAILABLE = TIME_BINARY_PATH is not None
-@unittest2.skipIf(not TIME_BINARY_PATH, 'time binary not available')
+@unittest2.skipIf(not TIME_BINARY_PATH, "time binary not available")
class PythonRunnerActionWrapperProcessTestCase(unittest2.TestCase):
def test_process_wrapper_exits_in_reasonable_timeframe(self):
# 1. Verify wrapper script path is correct and file exists
self.assertTrue(os.path.isfile(WRAPPER_SCRIPT_PATH))
# 2. First run it without time to verify path is valid
- command_string = 'python %s --file-path=foo.py' % (WRAPPER_SCRIPT_PATH)
+ command_string = "python %s --file-path=foo.py" % (WRAPPER_SCRIPT_PATH)
_, _, stderr = run_command(command_string, shell=True)
- self.assertIn('usage: python_action_wrapper.py', stderr)
+ self.assertIn("usage: python_action_wrapper.py", stderr)
- expected_msg_1 = 'python_action_wrapper.py: error: argument --pack is required'
- expected_msg_2 = ('python_action_wrapper.py: error: the following arguments are '
- 'required: --pack')
+ expected_msg_1 = "python_action_wrapper.py: error: argument --pack is required"
+ expected_msg_2 = (
+ "python_action_wrapper.py: error: the following arguments are "
+ "required: --pack"
+ )
self.assertTrue(expected_msg_1 in stderr or expected_msg_2 in stderr)
# 3. Now time it
- command_string = '%s -f "%%e" python %s' % (TIME_BINARY_PATH, WRAPPER_SCRIPT_PATH)
+ command_string = '%s -f "%%e" python %s' % (
+ TIME_BINARY_PATH,
+ WRAPPER_SCRIPT_PATH,
+ )
# Do multiple runs and average it
run_times = []
@@ -92,14 +96,18 @@ def test_process_wrapper_exits_in_reasonable_timeframe(self):
count = 8
for i in range(0, count):
_, _, stderr = run_command(command_string, shell=True)
- stderr = stderr.strip().split('\n')[-1]
+ stderr = stderr.strip().split("\n")[-1]
run_time_seconds = float(stderr)
run_times.append(run_time_seconds)
- avg_run_time_seconds = (sum(run_times) / count)
- assertion_msg = ASSERTION_ERROR_MESSAGE % (WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT,
- avg_run_time_seconds)
- self.assertTrue(avg_run_time_seconds <= WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, assertion_msg)
+ avg_run_time_seconds = sum(run_times) / count
+ assertion_msg = ASSERTION_ERROR_MESSAGE % (
+ WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT,
+ avg_run_time_seconds,
+ )
+ self.assertTrue(
+ avg_run_time_seconds <= WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, assertion_msg
+ )
def test_config_with_a_lot_of_items_and_a_lot_of_parameters_work_fine(self):
# Test case which verifies that actions with large configs and a lot of parameters work
@@ -107,48 +115,55 @@ def test_config_with_a_lot_of_items_and_a_lot_of_parameters_work_fine(self):
# upper limit on the size.
config = {}
for index in range(0, 50):
- config['key_%s' % (index)] = 'value value foo %s' % (index)
+ config["key_%s" % (index)] = "value value foo %s" % (index)
config = json.dumps(config)
parameters = {}
for index in range(0, 30):
- parameters['param_foo_%s' % (index)] = 'some param value %s' % (index)
+ parameters["param_foo_%s" % (index)] = "some param value %s" % (index)
parameters = json.dumps(parameters)
- file_path = os.path.join(BASE_DIR, '../../../../examples/actions/noop.py')
+ file_path = os.path.join(BASE_DIR, "../../../../examples/actions/noop.py")
- command_string = ('python %s --pack=dummy --file-path=%s --config=\'%s\' '
- '--parameters=\'%s\'' %
- (WRAPPER_SCRIPT_PATH, file_path, config, parameters))
+ command_string = (
+ "python %s --pack=dummy --file-path=%s --config='%s' "
+ "--parameters='%s'" % (WRAPPER_SCRIPT_PATH, file_path, config, parameters)
+ )
exit_code, stdout, stderr = run_command(command_string, shell=True)
self.assertEqual(exit_code, 0)
self.assertIn('"status"', stdout)
def test_stdin_params_timeout_no_stdin_data_provided(self):
config = {}
- file_path = os.path.join(BASE_DIR, '../../../../examples/actions/noop.py')
+ file_path = os.path.join(BASE_DIR, "../../../../examples/actions/noop.py")
# try running in a sub-shell to ensure that the stdin is empty
- command_string = ('python %s --pack=dummy --file-path=%s --config=\'%s\' '
- '--stdin-parameters' %
- (WRAPPER_SCRIPT_PATH, file_path, config))
+ command_string = (
+ "python %s --pack=dummy --file-path=%s --config='%s' "
+ "--stdin-parameters" % (WRAPPER_SCRIPT_PATH, file_path, config)
+ )
exit_code, stdout, stderr = run_command(command_string, shell=True)
- expected_msg = ('ValueError: No input received and timed out while waiting for parameters '
- 'from stdin')
+ expected_msg = (
+ "ValueError: No input received and timed out while waiting for parameters "
+ "from stdin"
+ )
self.assertEqual(exit_code, 1)
self.assertIn(expected_msg, stderr)
def test_stdin_params_invalid_format_friendly_error(self):
config = {}
- file_path = os.path.join(BASE_DIR, '../../../contrib/examples/actions/noop.py')
+ file_path = os.path.join(BASE_DIR, "../../../contrib/examples/actions/noop.py")
# Not a valid JSON string
- command_string = ('echo "invalid" | python %s --pack=dummy --file-path=%s --config=\'%s\' '
- '--stdin-parameters' %
- (WRAPPER_SCRIPT_PATH, file_path, config))
+ command_string = (
+ "echo \"invalid\" | python %s --pack=dummy --file-path=%s --config='%s' "
+ "--stdin-parameters" % (WRAPPER_SCRIPT_PATH, file_path, config)
+ )
exit_code, stdout, stderr = run_command(command_string, shell=True)
- expected_msg = ('ValueError: Failed to parse parameters from stdin. Expected a JSON '
- 'object with "parameters" attribute')
+ expected_msg = (
+ "ValueError: Failed to parse parameters from stdin. Expected a JSON "
+ 'object with "parameters" attribute'
+ )
self.assertEqual(exit_code, 1)
self.assertIn(expected_msg, stderr)
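The startup-time test above shells out to the `time` binary and averages several runs to smooth out noise before asserting an upper bound. The same strategy in pure Python, using time.monotonic() rather than parsing `time -f "%e"` output (command and bound are illustrative):

import subprocess
import time


def average_run_time(cmd, count=8):
    durations = []
    for _ in range(count):
        start = time.monotonic()
        subprocess.run(cmd, capture_output=True)
        durations.append(time.monotonic() - start)
    return sum(durations) / count


# e.g.: assert average_run_time(["python", "-c", "pass"]) <= 0.31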
diff --git a/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py b/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py
index 328a4a0fc0..a6d300be23 100644
--- a/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py
+++ b/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py
@@ -30,13 +30,12 @@
from st2tests.base import CleanDbTestCase
from st2tests.fixturesloader import get_fixtures_base_path
-__all__ = [
- 'PythonRunnerBehaviorTestCase'
-]
+__all__ = ["PythonRunnerBehaviorTestCase"]
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR,
- '../../../python_runner/python_runner/python_action_wrapper.py')
+WRAPPER_SCRIPT_PATH = os.path.join(
+ BASE_DIR, "../../../python_runner/python_runner/python_action_wrapper.py"
+)
WRAPPER_SCRIPT_PATH = os.path.abspath(WRAPPER_SCRIPT_PATH)
@@ -46,24 +45,24 @@ def setUp(self):
config.parse_args()
dir_path = tempfile.mkdtemp()
- cfg.CONF.set_override(name='base_path', override=dir_path, group='system')
+ cfg.CONF.set_override(name="base_path", override=dir_path, group="system")
self.base_path = dir_path
- self.virtualenvs_path = os.path.join(self.base_path, 'virtualenvs/')
+ self.virtualenvs_path = os.path.join(self.base_path, "virtualenvs/")
# Make sure dir is deleted on tearDown
self.to_delete_directories.append(self.base_path)
def test_priority_of_loading_library_after_setup_pack_virtualenv(self):
- '''
+ """
        This test checks library load priority: whether the library specified in a pack's
        'requirements.txt' is loaded when a module with the same name is also specified in
        st2's 'requirements.txt', in a subprocess of the ActionRunner.
        To test this, it uses the 'get_library_path.py' action in the 'test_library_dependencies'
        pack, which returns the file path of the module named by its 'module' parameter.
- '''
- pack_name = 'test_library_dependencies'
+ """
+ pack_name = "test_library_dependencies"
        # Before calling the action, this sets up a virtualenv for the test pack. This pack has a
        # requirements.txt which only lists the 'six' module.
@@ -72,20 +71,25 @@ def test_priority_of_loading_library_after_setup_pack_virtualenv(self):
        # This test expects the loaded 'six' module to be located under the virtualenv library,
        # because 'six' is listed in the requirements.txt of the 'test_library_dependencies' pack.
- (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'six'})
- self.assertEqual(output['result'].find(self.virtualenvs_path), 0)
+ (_, output, _) = self._run_action(
+ pack_name, "get_library_path.py", {"module": "six"}
+ )
+ self.assertEqual(output["result"].find(self.virtualenvs_path), 0)
        # Conversely, this expects the 'mock' module's file path to be under the parent process's
        # library path rather than the sandbox library, because 'mock' is not in the pack's virtualenv.
- (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'mock'})
- self.assertEqual(output['result'].find(self.virtualenvs_path), -1)
+ (_, output, _) = self._run_action(
+ pack_name, "get_library_path.py", {"module": "mock"}
+ )
+ self.assertEqual(output["result"].find(self.virtualenvs_path), -1)
        # Even when a module that exists in the pack's virtualenv library is specified via the
        # action's 'module' parameter, this test expects the file path under the parent's
        # library to be returned when the 'sandbox' parameter of PythonRunner is False.
- (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'six'},
- {'_sandbox': False})
- self.assertEqual(output['result'].find(self.virtualenvs_path), -1)
+ (_, output, _) = self._run_action(
+ pack_name, "get_library_path.py", {"module": "six"}, {"_sandbox": False}
+ )
+ self.assertEqual(output["result"].find(self.virtualenvs_path), -1)
def _run_action(self, pack, action, params, runner_params={}):
action_db = mock.Mock()
@@ -99,7 +103,8 @@ def _run_action(self, pack, action, params, runner_params={}):
for key, value in runner_params.items():
setattr(runner, key, value)
- runner.entry_point = os.path.join(get_fixtures_base_path(),
- 'packs/%s/actions/%s' % (pack, action))
+ runner.entry_point = os.path.join(
+ get_fixtures_base_path(), "packs/%s/actions/%s" % (pack, action)
+ )
runner.pre_run()
return runner.run(params)
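The behavior exercised above reduces to sys.path ordering: whichever copy of a module appears earlier on the path wins, which is why the pack virtualenv's site-packages is prepended for sandboxed actions. A minimal illustration of that mechanism (the preferred_dir argument is hypothetical):

import importlib
import sys


def import_from_preferred_path(module_name, preferred_dir):
    """Import module_name, preferring preferred_dir, and return its file path."""
    sys.modules.pop(module_name, None)  # drop any cached copy first
    sys.path.insert(0, preferred_dir)  # e.g. a pack virtualenv's site-packages
    importlib.invalidate_caches()
    try:
        return importlib.import_module(module_name).__file__
    finally:
        sys.path.remove(preferred_dir)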
diff --git a/contrib/runners/python_runner/tests/unit/test_output_schema.py b/contrib/runners/python_runner/tests/unit/test_output_schema.py
index 218ba669a6..218a8f0732 100644
--- a/contrib/runners/python_runner/tests/unit/test_output_schema.py
+++ b/contrib/runners/python_runner/tests/unit/test_output_schema.py
@@ -33,15 +33,16 @@
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-PASCAL_ROW_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs',
- 'pythonactions/actions/pascal_row.py')
+PASCAL_ROW_ACTION_PATH = os.path.join(
+ tests_base.get_resources_path(), "packs", "pythonactions/actions/pascal_row.py"
+)
MOCK_SYS = mock.Mock()
MOCK_SYS.argv = []
MOCK_SYS.executable = sys.executable
MOCK_EXECUTION = mock.Mock()
-MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b'
+MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b"
FAIL_SCHEMA = {
"notvalid": {
@@ -50,7 +51,7 @@
}
-@mock.patch('python_runner.python_runner.sys', MOCK_SYS)
+@mock.patch("python_runner.python_runner.sys", MOCK_SYS)
class PythonRunnerTestCase(RunnerTestCase, CleanDbTestCase):
register_packs = True
register_pack_configs = True
@@ -61,29 +62,23 @@ def setUpClass(cls):
assert_submodules_are_checked_out()
def test_adherence_to_output_schema(self):
- config = self.loader(os.path.join(BASE_DIR, '../../runner.yaml'))
+ config = self.loader(os.path.join(BASE_DIR, "../../runner.yaml"))
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
runner.pre_run()
- (status, output, _) = runner.run({'row_index': 5})
- output_schema._validate_runner(
- config[0]['output_schema'],
- output
- )
+ (status, output, _) = runner.run({"row_index": 5})
+ output_schema._validate_runner(config[0]["output_schema"], output)
self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED)
self.assertIsNotNone(output)
- self.assertEqual(output['result'], [1, 5, 10, 10, 5, 1])
+ self.assertEqual(output["result"], [1, 5, 10, 10, 5, 1])
def test_fail_incorrect_output_schema(self):
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
runner.pre_run()
- (status, output, _) = runner.run({'row_index': 5})
+ (status, output, _) = runner.run({"row_index": 5})
with self.assertRaises(jsonschema.ValidationError):
- output_schema._validate_runner(
- FAIL_SCHEMA,
- output
- )
+ output_schema._validate_runner(FAIL_SCHEMA, output)
def _get_mock_runner_obj(self, pack=None, sandbox=None):
runner = python_runner.get_runner()
@@ -106,10 +101,8 @@ def _get_mock_action_obj(self):
Pack gets set to the system pack so the action doesn't require a separate virtualenv.
"""
action = mock.Mock()
- action.ref = 'dummy.action'
+ action.ref = "dummy.action"
action.pack = SYSTEM_PACK_NAME
- action.entry_point = 'foo.py'
- action.runner_type = {
- 'name': 'python-script'
- }
+ action.entry_point = "foo.py"
+ action.runner_type = {"name": "python-script"}
return action
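What output_schema._validate_runner() amounts to in these tests is jsonschema validation of the runner's output dict, with ValidationError raised on mismatch, which is exactly what test_fail_incorrect_output_schema expects. A self-contained sketch against an illustrative schema (not the one from runner.yaml):

import jsonschema

OUTPUT_SCHEMA = {
    "type": "object",
    "properties": {
        "result": {"type": "array"},
        "exit_code": {"type": "integer"},
    },
    "required": ["result"],
}

# A well-formed runner output validates silently
jsonschema.validate({"result": [1, 5, 10, 10, 5, 1], "exit_code": 0}, OUTPUT_SCHEMA)

# Output missing a required key is rejected
try:
    jsonschema.validate({"notvalid": True}, OUTPUT_SCHEMA)
except jsonschema.ValidationError as e:
    print("rejected: %s" % (e.message))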
diff --git a/contrib/runners/python_runner/tests/unit/test_pythonrunner.py b/contrib/runners/python_runner/tests/unit/test_pythonrunner.py
index 8d55f8262d..940d087af6 100644
--- a/contrib/runners/python_runner/tests/unit/test_pythonrunner.py
+++ b/contrib/runners/python_runner/tests/unit/test_pythonrunner.py
@@ -29,7 +29,10 @@
from st2common.runners.utils import get_action_class_instance
from st2common.services import config as config_service
from st2common.constants.action import ACTION_OUTPUT_RESULT_DELIMITER
-from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED
+from st2common.constants.action import (
+ LIVEACTION_STATUS_SUCCEEDED,
+ LIVEACTION_STATUS_FAILED,
+)
from st2common.constants.action import LIVEACTION_STATUS_TIMED_OUT
from st2common.constants.action import MAX_PARAM_LENGTH
from st2common.constants.pack import SYSTEM_PACK_NAME
@@ -43,29 +46,49 @@
import st2tests.base as tests_base
-PASCAL_ROW_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs',
- 'pythonactions/actions/pascal_row.py')
-ECHOER_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs',
- 'pythonactions/actions/echoer.py')
-TEST_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs',
- 'pythonactions/actions/test.py')
-PATHS_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs',
- 'pythonactions/actions/python_paths.py')
-ACTION_1_PATH = os.path.join(tests_base.get_fixtures_path(),
- 'packs/dummy_pack_9/actions/list_repos_doesnt_exist.py')
-ACTION_2_PATH = os.path.join(tests_base.get_fixtures_path(),
- 'packs/dummy_pack_9/actions/invalid_syntax.py')
-NON_SIMPLE_TYPE_ACTION = os.path.join(tests_base.get_resources_path(), 'packs',
- 'pythonactions/actions/non_simple_type.py')
-PRINT_VERSION_ACTION = os.path.join(tests_base.get_fixtures_path(), 'packs',
- 'test_content_version/actions/print_version.py')
-PRINT_VERSION_LOCAL_MODULE_ACTION = os.path.join(tests_base.get_fixtures_path(), 'packs',
- 'test_content_version/actions/print_version_local_import.py')
-
-PRINT_CONFIG_ITEM_ACTION = os.path.join(tests_base.get_resources_path(), 'packs',
- 'pythonactions/actions/print_config_item_doesnt_exist.py')
-PRINT_TO_STDOUT_STDERR_ACTION = os.path.join(tests_base.get_resources_path(), 'packs',
- 'pythonactions/actions/print_to_stdout_and_stderr.py')
+PASCAL_ROW_ACTION_PATH = os.path.join(
+ tests_base.get_resources_path(), "packs", "pythonactions/actions/pascal_row.py"
+)
+ECHOER_ACTION_PATH = os.path.join(
+ tests_base.get_resources_path(), "packs", "pythonactions/actions/echoer.py"
+)
+TEST_ACTION_PATH = os.path.join(
+ tests_base.get_resources_path(), "packs", "pythonactions/actions/test.py"
+)
+PATHS_ACTION_PATH = os.path.join(
+ tests_base.get_resources_path(), "packs", "pythonactions/actions/python_paths.py"
+)
+ACTION_1_PATH = os.path.join(
+ tests_base.get_fixtures_path(),
+ "packs/dummy_pack_9/actions/list_repos_doesnt_exist.py",
+)
+ACTION_2_PATH = os.path.join(
+ tests_base.get_fixtures_path(), "packs/dummy_pack_9/actions/invalid_syntax.py"
+)
+NON_SIMPLE_TYPE_ACTION = os.path.join(
+ tests_base.get_resources_path(), "packs", "pythonactions/actions/non_simple_type.py"
+)
+PRINT_VERSION_ACTION = os.path.join(
+ tests_base.get_fixtures_path(),
+ "packs",
+ "test_content_version/actions/print_version.py",
+)
+PRINT_VERSION_LOCAL_MODULE_ACTION = os.path.join(
+ tests_base.get_fixtures_path(),
+ "packs",
+ "test_content_version/actions/print_version_local_import.py",
+)
+
+PRINT_CONFIG_ITEM_ACTION = os.path.join(
+ tests_base.get_resources_path(),
+ "packs",
+ "pythonactions/actions/print_config_item_doesnt_exist.py",
+)
+PRINT_TO_STDOUT_STDERR_ACTION = os.path.join(
+ tests_base.get_resources_path(),
+ "packs",
+ "pythonactions/actions/print_to_stdout_and_stderr.py",
+)
# Note: runner inherits parent args which doesn't work with tests since test pass additional
@@ -75,10 +98,10 @@
mock_sys.executable = sys.executable
MOCK_EXECUTION = mock.Mock()
-MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b'
+MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b"
-@mock.patch('python_runner.python_runner.sys', mock_sys)
+@mock.patch("python_runner.python_runner.sys", mock_sys)
class PythonRunnerTestCase(RunnerTestCase, CleanDbTestCase):
register_packs = True
register_pack_configs = True
@@ -90,8 +113,10 @@ def setUpClass(cls):
def test_runner_creation(self):
runner = python_runner.get_runner()
- self.assertIsNotNone(runner, 'Creation failed. No instance.')
- self.assertEqual(type(runner), python_runner.PythonRunner, 'Creation failed. No instance.')
+ self.assertIsNotNone(runner, "Creation failed. No instance.")
+ self.assertEqual(
+ type(runner), python_runner.PythonRunner, "Creation failed. No instance."
+ )
def test_action_returns_non_serializable_result(self):
# Actions returns non-simple type which can't be serialized, verify result is simple str()
@@ -105,33 +130,37 @@ def test_action_returns_non_serializable_result(self):
self.assertIsNotNone(output)
if six.PY2:
-            expected_result_re = (r"\[{'a': '1'}, {'h': 3, 'c': 2}, {'e': "
-                                  r"<non_simple_type.Test object at .*?>}\]")
+            expected_result_re = (
+                r"\[{'a': '1'}, {'h': 3, 'c': 2}, {'e': "
+                r"<non_simple_type.Test object at .*?>}\]"
+            )
else:
-            expected_result_re = (r"\[{'a': '1'}, {'c': 2, 'h': 3}, {'e': "
-                                  r"<non_simple_type.Test object at .*?>}\]")
+            expected_result_re = (
+                r"\[{'a': '1'}, {'c': 2, 'h': 3}, {'e': "
+                r"<non_simple_type.Test object at .*?>}\]"
+            )
- match = re.match(expected_result_re, output['result'])
+ match = re.match(expected_result_re, output["result"])
self.assertTrue(match)
def test_simple_action_with_result_no_status(self):
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
runner.pre_run()
- (status, output, _) = runner.run({'row_index': 5})
+ (status, output, _) = runner.run({"row_index": 5})
self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED)
self.assertIsNotNone(output)
- self.assertEqual(output['result'], [1, 5, 10, 10, 5, 1])
+ self.assertEqual(output["result"], [1, 5, 10, 10, 5, 1])
def test_simple_action_with_result_as_None_no_status(self):
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
runner.pre_run()
- (status, output, _) = runner.run({'row_index': 'b'})
+ (status, output, _) = runner.run({"row_index": "b"})
self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED)
self.assertIsNotNone(output)
- self.assertEqual(output['exit_code'], 0)
- self.assertEqual(output['result'], None)
+ self.assertEqual(output["exit_code"], 0)
+ self.assertEqual(output["result"], None)
def test_simple_action_timeout(self):
timeout = 0
@@ -139,30 +168,30 @@ def test_simple_action_timeout(self):
runner.runner_parameters = {python_runner.RUNNER_TIMEOUT: timeout}
runner.entry_point = PASCAL_ROW_ACTION_PATH
runner.pre_run()
- (status, output, _) = runner.run({'row_index': 4})
+ (status, output, _) = runner.run({"row_index": 4})
self.assertEqual(status, LIVEACTION_STATUS_TIMED_OUT)
self.assertIsNotNone(output)
- self.assertEqual(output['result'], 'None')
- self.assertEqual(output['error'], 'Action failed to complete in 0 seconds')
- self.assertEqual(output['exit_code'], -9)
+ self.assertEqual(output["result"], "None")
+ self.assertEqual(output["error"], "Action failed to complete in 0 seconds")
+ self.assertEqual(output["exit_code"], -9)
def test_simple_action_with_status_succeeded(self):
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
runner.pre_run()
- (status, output, _) = runner.run({'row_index': 4})
+ (status, output, _) = runner.run({"row_index": 4})
self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED)
self.assertIsNotNone(output)
- self.assertEqual(output['result'], [1, 4, 6, 4, 1])
+ self.assertEqual(output["result"], [1, 4, 6, 4, 1])
def test_simple_action_with_status_failed(self):
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
runner.pre_run()
- (status, output, _) = runner.run({'row_index': 'a'})
+ (status, output, _) = runner.run({"row_index": "a"})
self.assertEqual(status, LIVEACTION_STATUS_FAILED)
self.assertIsNotNone(output)
- self.assertEqual(output['result'], "This is suppose to fail don't worry!!")
+ self.assertEqual(output["result"], "This is suppose to fail don't worry!!")
def test_simple_action_with_status_complex_type_returned_for_result(self):
        # A result containing a complex type shouldn't break returning a tuple with status
@@ -170,78 +199,79 @@ def test_simple_action_with_status_complex_type_returned_for_result(self):
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
runner.pre_run()
- (status, output, _) = runner.run({'row_index': 'complex_type'})
+ (status, output, _) = runner.run({"row_index": "complex_type"})
self.assertEqual(status, LIVEACTION_STATUS_FAILED)
self.assertIsNotNone(output)
- self.assertIn('.*" %
- runner.git_worktree_path)
- self.assertRegexpMatches(output['stdout'].strip(), expected_stdout)
+ expected_stdout = (
+ ".*"
+ % runner.git_worktree_path
+ )
+ self.assertRegexpMatches(output["stdout"].strip(), expected_stdout)
- @mock.patch('st2common.runners.base.run_command')
+ @mock.patch("st2common.runners.base.run_command")
def test_content_version_old_git_version(self, mock_run_command):
- mock_stdout = ''
- mock_stderr = '''
+ mock_stdout = ""
+ mock_stderr = """
git: 'worktree' is not a git command. See 'git --help'.
-'''
+"""
mock_stderr = six.text_type(mock_stderr)
mock_run_command.return_value = 1, mock_stdout, mock_stderr, False
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
- runner.runner_parameters = {'content_version': 'v0.10.0'}
+ runner.runner_parameters = {"content_version": "v0.10.0"}
- expected_msg = (r'Failed to create git worktree for pack "core": Installed git version '
- 'doesn\'t support git worktree command. To be able to utilize this '
- 'functionality you need to use git >= 2.5.0.')
+ expected_msg = (
+ r'Failed to create git worktree for pack "core": Installed git version '
+ "doesn't support git worktree command. To be able to utilize this "
+ "functionality you need to use git >= 2.5.0."
+ )
self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run)
- @mock.patch('st2common.runners.base.run_command')
+ @mock.patch("st2common.runners.base.run_command")
def test_content_version_pack_repo_not_git_repository(self, mock_run_command):
- mock_stdout = ''
- mock_stderr = '''
+ mock_stdout = ""
+ mock_stderr = """
fatal: Not a git repository (or any parent up to mount point /home)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
-'''
+"""
mock_stderr = six.text_type(mock_stderr)
mock_run_command.return_value = 1, mock_stdout, mock_stderr, False
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
- runner.runner_parameters = {'content_version': 'v0.10.0'}
-
- expected_msg = (r'Failed to create git worktree for pack "core": Pack directory '
- '".*" is not a '
- 'git repository. To utilize this functionality, pack directory needs to '
- 'be a git repository.')
+ runner.runner_parameters = {"content_version": "v0.10.0"}
+
+ expected_msg = (
+ r'Failed to create git worktree for pack "core": Pack directory '
+ '".*" is not a '
+ "git repository. To utilize this functionality, pack directory needs to "
+ "be a git repository."
+ )
self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run)
- @mock.patch('st2common.runners.base.run_command')
+ @mock.patch("st2common.runners.base.run_command")
def test_content_version_invalid_git_revision(self, mock_run_command):
- mock_stdout = ''
- mock_stderr = '''
+ mock_stdout = ""
+ mock_stderr = """
fatal: invalid reference: vinvalid
-'''
+"""
mock_stderr = six.text_type(mock_stderr)
mock_run_command.return_value = 1, mock_stdout, mock_stderr, False
runner = self._get_mock_runner_obj()
runner.entry_point = PASCAL_ROW_ACTION_PATH
- runner.runner_parameters = {'content_version': 'vinvalid'}
+ runner.runner_parameters = {"content_version": "vinvalid"}
- expected_msg = (r'Failed to create git worktree for pack "core": Invalid content_version '
- '"vinvalid" provided. Make sure that git repository is up '
- 'to date and contains that revision.')
+ expected_msg = (
+ r'Failed to create git worktree for pack "core": Invalid content_version '
+ '"vinvalid" provided. Make sure that git repository is up '
+ "to date and contains that revision."
+ )
self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run)
def test_missing_config_item_user_friendly_error(self):
@@ -953,10 +1051,12 @@ def test_missing_config_item_user_friendly_error(self):
self.assertEqual(status, LIVEACTION_STATUS_FAILED)
self.assertIsNotNone(output)
- self.assertIn('{}', output['stdout'])
- self.assertIn('default_value', output['stdout'])
- self.assertIn('Config for pack "core" is missing key "key"', output['stderr'])
- self.assertIn('make sure you run "st2ctl reload --register-configs"', output['stderr'])
+ self.assertIn("{}", output["stdout"])
+ self.assertIn("default_value", output["stdout"])
+ self.assertIn('Config for pack "core" is missing key "key"', output["stderr"])
+ self.assertIn(
+ 'make sure you run "st2ctl reload --register-configs"', output["stderr"]
+ )
def _get_mock_runner_obj(self, pack=None, sandbox=None):
runner = python_runner.get_runner()
@@ -972,22 +1072,25 @@ def _get_mock_runner_obj(self, pack=None, sandbox=None):
return runner
- @mock.patch('st2actions.container.base.ActionExecution.get', mock.Mock())
+ @mock.patch("st2actions.container.base.ActionExecution.get", mock.Mock())
def _get_mock_runner_obj_from_container(self, pack, user, sandbox=None):
container = RunnerContainer()
runnertype_db = mock.Mock()
- runnertype_db.name = 'python-script'
- runnertype_db.runner_package = 'python_runner'
- runnertype_db.runner_module = 'python_runner'
+ runnertype_db.name = "python-script"
+ runnertype_db.runner_package = "python_runner"
+ runnertype_db.runner_module = "python_runner"
action_db = mock.Mock()
action_db.pack = pack
- action_db.entry_point = 'foo.py'
+ action_db.entry_point = "foo.py"
liveaction_db = mock.Mock()
- liveaction_db.id = '123'
- liveaction_db.context = {'user': user}
- runner = container._get_runner(runner_type_db=runnertype_db, action_db=action_db,
- liveaction_db=liveaction_db)
+ liveaction_db.id = "123"
+ liveaction_db.context = {"user": user}
+ runner = container._get_runner(
+ runner_type_db=runnertype_db,
+ action_db=action_db,
+ liveaction_db=liveaction_db,
+ )
runner.execution = MOCK_EXECUTION
runner.action = action_db
runner.runner_parameters = {}
@@ -1004,10 +1107,8 @@ def _get_mock_action_obj(self):
Pack gets set to the system pack so the action doesn't require a separate virtualenv.
"""
action = mock.Mock()
- action.ref = 'dummy.action'
+ action.ref = "dummy.action"
action.pack = SYSTEM_PACK_NAME
- action.entry_point = 'foo.py'
- action.runner_type = {
- 'name': 'python-script'
- }
+ action.entry_point = "foo.py"
+ action.runner_type = {"name": "python-script"}
return action
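A recurring pattern in this test module is stubbing run_command with a canned (exit_code, stdout, stderr, timed_out) tuple so the git worktree error paths can be exercised without a real git checkout. The mechanics in isolation, with an invented command string:

from unittest import mock

mock_run_command = mock.Mock(
    return_value=(1, "", "fatal: invalid reference: vinvalid\n", False)
)

# Whatever the caller passes, the stub reports a failed git invocation
exit_code, stdout, stderr, timed_out = mock_run_command("git worktree add ...")
assert exit_code == 1 and "invalid reference" in stderr and not timed_out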
diff --git a/contrib/runners/remote_runner/dist_utils.py b/contrib/runners/remote_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/remote_runner/dist_utils.py
+++ b/contrib/runners/remote_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
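
# Hedged usage sketch for get_version_string above: given file content with a
# module-level __version__ assignment (e.g. the "3.5dev" string used by the
# runner packages in this patch), the regex captures the quoted value.
import re

content = '__version__ = "3.5dev"\n'
match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
assert match is not None and match.group(1) == "3.5dev"
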
diff --git a/contrib/runners/remote_runner/remote_runner/__init__.py b/contrib/runners/remote_runner/remote_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/remote_runner/remote_runner/__init__.py
+++ b/contrib/runners/remote_runner/remote_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/remote_runner/remote_runner/remote_command_runner.py b/contrib/runners/remote_runner/remote_runner/remote_command_runner.py
index 09382d9125..60880a0431 100644
--- a/contrib/runners/remote_runner/remote_runner/remote_command_runner.py
+++ b/contrib/runners/remote_runner/remote_runner/remote_command_runner.py
@@ -24,12 +24,7 @@
from st2common.runners.base import get_metadata as get_runner_metadata
from st2common.models.system.paramiko_command_action import ParamikoRemoteCommandAction
-__all__ = [
- 'ParamikoRemoteCommandRunner',
-
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["ParamikoRemoteCommandRunner", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
@@ -38,42 +33,52 @@ class ParamikoRemoteCommandRunner(BaseParallelSSHRunner):
def run(self, action_parameters):
remote_action = self._get_remote_action(action_parameters)
- LOG.debug('Executing remote command action.', extra={'_action_params': remote_action})
+ LOG.debug(
+ "Executing remote command action.", extra={"_action_params": remote_action}
+ )
result = self._run(remote_action)
- LOG.debug('Executed remote_action.', extra={'_result': result})
- status = self._get_result_status(result, cfg.CONF.ssh_runner.allow_partial_failure)
+ LOG.debug("Executed remote_action.", extra={"_result": result})
+ status = self._get_result_status(
+ result, cfg.CONF.ssh_runner.allow_partial_failure
+ )
return (status, result, None)
def _run(self, remote_action):
command = remote_action.get_full_command_string()
- return self._parallel_ssh_client.run(command, timeout=remote_action.get_timeout())
+ return self._parallel_ssh_client.run(
+ command, timeout=remote_action.get_timeout()
+ )
def _get_remote_action(self, action_paramaters):
        # remote command actions with an entry_point don't make sense; the user
        # probably wanted the "remote-shell-script" action
if self.entry_point:
- msg = ('Action "%s" specified "entry_point" attribute. Perhaps wanted to use '
- '"remote-shell-script" runner?' % (self.action_name))
+ msg = (
+ 'Action "%s" specified "entry_point" attribute. Perhaps wanted to use '
+ '"remote-shell-script" runner?' % (self.action_name)
+ )
raise Exception(msg)
command = self.runner_parameters.get(RUNNER_COMMAND, None)
env_vars = self._get_env_vars()
- return ParamikoRemoteCommandAction(self.action_name,
- str(self.liveaction_id),
- command,
- env_vars=env_vars,
- on_behalf_user=self._on_behalf_user,
- user=self._username,
- password=self._password,
- private_key=self._private_key,
- passphrase=self._passphrase,
- hosts=self._hosts,
- parallel=self._parallel,
- sudo=self._sudo,
- sudo_password=self._sudo_password,
- timeout=self._timeout,
- cwd=self._cwd)
+ return ParamikoRemoteCommandAction(
+ self.action_name,
+ str(self.liveaction_id),
+ command,
+ env_vars=env_vars,
+ on_behalf_user=self._on_behalf_user,
+ user=self._username,
+ password=self._password,
+ private_key=self._private_key,
+ passphrase=self._passphrase,
+ hosts=self._hosts,
+ parallel=self._parallel,
+ sudo=self._sudo,
+ sudo_password=self._sudo_password,
+ timeout=self._timeout,
+ cwd=self._cwd,
+ )
def get_runner():
@@ -81,7 +86,10 @@ def get_runner():
def get_metadata():
- metadata = get_runner_metadata('remote_runner')
- metadata = [runner for runner in metadata if
- runner['runner_module'] == __name__.split('.')[-1]][0]
+ metadata = get_runner_metadata("remote_runner")
+ metadata = [
+ runner
+ for runner in metadata
+ if runner["runner_module"] == __name__.split(".")[-1]
+ ][0]
return metadata
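
# Sketch of the runner_module filtering done in get_metadata above, using
# hypothetical metadata entries (the real ones come from the package's
# runner.yaml via get_runner_metadata). The module's last dotted name
# component selects the matching entry.
metadata = [
    {"runner_module": "remote_command_runner", "name": "remote-shell-cmd"},
    {"runner_module": "remote_script_runner", "name": "remote-shell-script"},
]
module_name = "remote_runner.remote_command_runner".split(".")[-1]
selected = [r for r in metadata if r["runner_module"] == module_name][0]
assert selected["name"] == "remote-shell-cmd"
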
diff --git a/contrib/runners/remote_runner/remote_runner/remote_script_runner.py b/contrib/runners/remote_runner/remote_runner/remote_script_runner.py
index 292f391850..e71e8f6314 100644
--- a/contrib/runners/remote_runner/remote_runner/remote_script_runner.py
+++ b/contrib/runners/remote_runner/remote_runner/remote_script_runner.py
@@ -27,12 +27,7 @@
from st2common.runners.base import get_metadata as get_runner_metadata
from st2common.models.system.paramiko_script_action import ParamikoRemoteScriptAction
-__all__ = [
- 'ParamikoRemoteScriptRunner',
-
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["ParamikoRemoteScriptRunner", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
@@ -41,10 +36,12 @@ class ParamikoRemoteScriptRunner(BaseParallelSSHRunner):
def run(self, action_parameters):
remote_action = self._get_remote_action(action_parameters)
- LOG.debug('Executing remote action.', extra={'_action_params': remote_action})
+ LOG.debug("Executing remote action.", extra={"_action_params": remote_action})
result = self._run(remote_action)
- LOG.debug('Executed remote action.', extra={'_result': result})
- status = self._get_result_status(result, cfg.CONF.ssh_runner.allow_partial_failure)
+ LOG.debug("Executed remote action.", extra={"_result": result})
+ status = self._get_result_status(
+ result, cfg.CONF.ssh_runner.allow_partial_failure
+ )
return (status, result, None)
@@ -54,109 +51,133 @@ def _run(self, remote_action):
except:
# If for whatever reason there is a top level exception,
# we just bail here.
- error = 'Failed copying content to remote boxes.'
+ error = "Failed copying content to remote boxes."
LOG.exception(error)
_, ex, tb = sys.exc_info()
- copy_results = self._generate_error_results(' '.join([error, str(ex)]), tb)
+ copy_results = self._generate_error_results(" ".join([error, str(ex)]), tb)
return copy_results
try:
exec_results = self._run_script_on_remote_host(remote_action)
try:
remote_dir = remote_action.get_remote_base_dir()
- LOG.debug('Deleting remote execution dir.', extra={'_remote_dir': remote_dir})
- delete_results = self._parallel_ssh_client.delete_dir(path=remote_dir,
- force=True)
- LOG.debug('Deleted remote execution dir.', extra={'_result': delete_results})
+ LOG.debug(
+ "Deleting remote execution dir.", extra={"_remote_dir": remote_dir}
+ )
+ delete_results = self._parallel_ssh_client.delete_dir(
+ path=remote_dir, force=True
+ )
+ LOG.debug(
+ "Deleted remote execution dir.", extra={"_result": delete_results}
+ )
except:
- LOG.exception('Failed deleting remote dir.', extra={'_remote_dir': remote_dir})
+ LOG.exception(
+ "Failed deleting remote dir.", extra={"_remote_dir": remote_dir}
+ )
finally:
return exec_results
except:
- error = 'Failed executing script on remote boxes.'
- LOG.exception(error, extra={'_action_params': remote_action})
+ error = "Failed executing script on remote boxes."
+ LOG.exception(error, extra={"_action_params": remote_action})
_, ex, tb = sys.exc_info()
- exec_results = self._generate_error_results(' '.join([error, str(ex)]), tb)
+ exec_results = self._generate_error_results(" ".join([error, str(ex)]), tb)
return exec_results
def _copy_artifacts(self, remote_action):
# First create remote execution directory.
remote_dir = remote_action.get_remote_base_dir()
- LOG.debug('Creating remote execution dir.', extra={'_path': remote_dir})
- mkdir_result = self._parallel_ssh_client.mkdir(path=remote_action.get_remote_base_dir())
+ LOG.debug("Creating remote execution dir.", extra={"_path": remote_dir})
+ mkdir_result = self._parallel_ssh_client.mkdir(
+ path=remote_action.get_remote_base_dir()
+ )
# Copy the script to remote dir in remote host.
local_script_abs_path = remote_action.get_local_script_abs_path()
remote_script_abs_path = remote_action.get_remote_script_abs_path()
file_mode = 0o744
- extra = {'_local_script': local_script_abs_path, '_remote_script': remote_script_abs_path,
- 'mode': file_mode}
- LOG.debug('Copying local script to remote box.', extra=extra)
- put_result_1 = self._parallel_ssh_client.put(local_path=local_script_abs_path,
- remote_path=remote_script_abs_path,
- mirror_local_mode=False, mode=file_mode)
+ extra = {
+ "_local_script": local_script_abs_path,
+ "_remote_script": remote_script_abs_path,
+ "mode": file_mode,
+ }
+ LOG.debug("Copying local script to remote box.", extra=extra)
+ put_result_1 = self._parallel_ssh_client.put(
+ local_path=local_script_abs_path,
+ remote_path=remote_script_abs_path,
+ mirror_local_mode=False,
+ mode=file_mode,
+ )
    # If `lib` exists for the script, copy it to the remote host.
local_libs_path = remote_action.get_local_libs_path_abs()
if os.path.exists(local_libs_path):
- extra = {'_local_libs': local_libs_path, '_remote_path': remote_dir}
- LOG.debug('Copying libs to remote host.', extra=extra)
- put_result_2 = self._parallel_ssh_client.put(local_path=local_libs_path,
- remote_path=remote_dir,
- mirror_local_mode=True)
+ extra = {"_local_libs": local_libs_path, "_remote_path": remote_dir}
+ LOG.debug("Copying libs to remote host.", extra=extra)
+ put_result_2 = self._parallel_ssh_client.put(
+ local_path=local_libs_path,
+ remote_path=remote_dir,
+ mirror_local_mode=True,
+ )
result = mkdir_result or put_result_1 or put_result_2
return result
def _run_script_on_remote_host(self, remote_action):
command = remote_action.get_full_command_string()
- LOG.info('Command to run: %s', command)
- results = self._parallel_ssh_client.run(command, timeout=remote_action.get_timeout())
- LOG.debug('Results from script: %s', results)
+ LOG.info("Command to run: %s", command)
+ results = self._parallel_ssh_client.run(
+ command, timeout=remote_action.get_timeout()
+ )
+ LOG.debug("Results from script: %s", results)
return results
def _get_remote_action(self, action_parameters):
        # remote script actions without an entry_point don't make sense; the user
        # probably wanted the "remote-shell-cmd" action
if not self.entry_point:
- msg = ('Action "%s" is missing "entry_point" attribute. Perhaps wanted to use '
- '"remote-shell-script" runner?' % (self.action_name))
+ msg = (
+ 'Action "%s" is missing "entry_point" attribute. Perhaps wanted to use '
+ '"remote-shell-script" runner?' % (self.action_name)
+ )
raise Exception(msg)
script_local_path_abs = self.entry_point
pos_args, named_args = self._get_script_args(action_parameters)
named_args = self._transform_named_args(named_args)
env_vars = self._get_env_vars()
- remote_dir = self.runner_parameters.get(RUNNER_REMOTE_DIR,
- cfg.CONF.ssh_runner.remote_dir)
+ remote_dir = self.runner_parameters.get(
+ RUNNER_REMOTE_DIR, cfg.CONF.ssh_runner.remote_dir
+ )
remote_dir = os.path.join(remote_dir, self.liveaction_id)
- return ParamikoRemoteScriptAction(self.action_name,
- str(self.liveaction_id),
- script_local_path_abs,
- self.libs_dir_path,
- named_args=named_args,
- positional_args=pos_args,
- env_vars=env_vars,
- on_behalf_user=self._on_behalf_user,
- user=self._username,
- password=self._password,
- private_key=self._private_key,
- remote_dir=remote_dir,
- hosts=self._hosts,
- parallel=self._parallel,
- sudo=self._sudo,
- sudo_password=self._sudo_password,
- timeout=self._timeout,
- cwd=self._cwd)
+ return ParamikoRemoteScriptAction(
+ self.action_name,
+ str(self.liveaction_id),
+ script_local_path_abs,
+ self.libs_dir_path,
+ named_args=named_args,
+ positional_args=pos_args,
+ env_vars=env_vars,
+ on_behalf_user=self._on_behalf_user,
+ user=self._username,
+ password=self._password,
+ private_key=self._private_key,
+ remote_dir=remote_dir,
+ hosts=self._hosts,
+ parallel=self._parallel,
+ sudo=self._sudo,
+ sudo_password=self._sudo_password,
+ timeout=self._timeout,
+ cwd=self._cwd,
+ )
@staticmethod
def _generate_error_results(error, tb):
error_dict = {
- 'error': error,
- 'traceback': ''.join(traceback.format_tb(tb, 20)) if tb else '',
- 'failed': True,
- 'succeeded': False,
- 'return_code': 255
+ "error": error,
+ "traceback": "".join(traceback.format_tb(tb, 20)) if tb else "",
+ "failed": True,
+ "succeeded": False,
+ "return_code": 255,
}
return error_dict
@@ -166,7 +187,10 @@ def get_runner():
def get_metadata():
- metadata = get_runner_metadata('remote_runner')
- metadata = [runner for runner in metadata if
- runner['runner_module'] == __name__.split('.')[-1]][0]
+ metadata = get_runner_metadata("remote_runner")
+ metadata = [
+ runner
+ for runner in metadata
+ if runner["runner_module"] == __name__.split(".")[-1]
+ ][0]
return metadata
diff --git a/contrib/runners/remote_runner/setup.py b/contrib/runners/remote_runner/setup.py
index cdd61b68b1..3e83437aff 100644
--- a/contrib/runners/remote_runner/setup.py
+++ b/contrib/runners/remote_runner/setup.py
@@ -26,32 +26,34 @@
from remote_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-remote',
+ name="stackstorm-runner-remote",
version=__version__,
- description=('Remote SSH shell command and script action runner for StackStorm event-driven '
- 'automation platform'),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description=(
+ "Remote SSH shell command and script action runner for StackStorm event-driven "
+ "automation platform"
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'remote_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"remote_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'remote-shell-cmd = remote_runner.remote_command_runner',
- 'remote-shell-script = remote_runner.remote_script_runner',
+ "st2common.runners.runner": [
+ "remote-shell-cmd = remote_runner.remote_command_runner",
+ "remote-shell-script = remote_runner.remote_script_runner",
],
- }
+ },
)
diff --git a/contrib/runners/winrm_runner/dist_utils.py b/contrib/runners/winrm_runner/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/contrib/runners/winrm_runner/dist_utils.py
+++ b/contrib/runners/winrm_runner/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+                raise ValueError('Line "%s" is missing "#egg="' % (line))
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
diff --git a/contrib/runners/winrm_runner/setup.py b/contrib/runners/winrm_runner/setup.py
index f3f014277b..53d7b952e1 100644
--- a/contrib/runners/winrm_runner/setup.py
+++ b/contrib/runners/winrm_runner/setup.py
@@ -26,33 +26,35 @@
from winrm_runner import __version__
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
setup(
- name='stackstorm-runner-winrm',
+ name="stackstorm-runner-winrm",
version=__version__,
- description=('WinRM shell command and PowerShell script action runner for'
- ' the StackStorm event-driven automation platform'),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description=(
+ "WinRM shell command and PowerShell script action runner for"
+ " the StackStorm event-driven automation platform"
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
- test_suite='tests',
+ test_suite="tests",
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- package_data={'winrm_runner': ['runner.yaml']},
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ package_data={"winrm_runner": ["runner.yaml"]},
scripts=[],
entry_points={
- 'st2common.runners.runner': [
- 'winrm-cmd = winrm_runner.winrm_command_runner',
- 'winrm-ps-cmd = winrm_runner.winrm_ps_command_runner',
- 'winrm-ps-script = winrm_runner.winrm_ps_script_runner',
+ "st2common.runners.runner": [
+ "winrm-cmd = winrm_runner.winrm_command_runner",
+ "winrm-ps-cmd = winrm_runner.winrm_ps_command_runner",
+ "winrm-ps-script = winrm_runner.winrm_ps_script_runner",
],
- }
+ },
)
diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py
index 0803b3e25a..1ff9f2ce1d 100644
--- a/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py
+++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py
@@ -32,157 +32,170 @@
class WinRmBaseTestCase(RunnerTestCase):
-
def setUp(self):
super(WinRmBaseTestCase, self).setUpClass()
self._runner = winrm_ps_command_runner.get_runner()
def _init_runner(self):
- runner_parameters = {'host': 'host@domain.tld',
- 'username': 'user@domain.tld',
- 'password': 'xyz987'}
+ runner_parameters = {
+ "host": "host@domain.tld",
+ "username": "user@domain.tld",
+ "password": "xyz987",
+ }
self._runner.runner_parameters = runner_parameters
self._runner.pre_run()
def test_win_rm_runner_timout_error(self):
- error = WinRmRunnerTimoutError('test_response')
+ error = WinRmRunnerTimoutError("test_response")
self.assertIsInstance(error, Exception)
- self.assertEqual(error.response, 'test_response')
+ self.assertEqual(error.response, "test_response")
with self.assertRaises(WinRmRunnerTimoutError):
- raise WinRmRunnerTimoutError('test raising')
+ raise WinRmRunnerTimoutError("test raising")
def test_init(self):
- runner = winrm_ps_command_runner.WinRmPsCommandRunner('abcdef')
+ runner = winrm_ps_command_runner.WinRmPsCommandRunner("abcdef")
self.assertIsInstance(runner, WinRmBaseRunner)
self.assertIsInstance(runner, ActionRunner)
self.assertEqual(runner.runner_id, "abcdef")
- @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run')
+ @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run")
def test_pre_run(self, mock_pre_run):
- runner_parameters = {'host': 'host@domain.tld',
- 'username': 'user@domain.tld',
- 'password': 'abc123',
- 'timeout': 99,
- 'port': 1234,
- 'scheme': 'http',
- 'transport': 'ntlm',
- 'verify_ssl_cert': False,
- 'cwd': 'C:\\Test',
- 'env': {'TEST_VAR': 'TEST_VALUE'},
- 'kwarg_op': '/'}
+ runner_parameters = {
+ "host": "host@domain.tld",
+ "username": "user@domain.tld",
+ "password": "abc123",
+ "timeout": 99,
+ "port": 1234,
+ "scheme": "http",
+ "transport": "ntlm",
+ "verify_ssl_cert": False,
+ "cwd": "C:\\Test",
+ "env": {"TEST_VAR": "TEST_VALUE"},
+ "kwarg_op": "/",
+ }
self._runner.runner_parameters = runner_parameters
self._runner.pre_run()
mock_pre_run.assert_called_with()
self.assertEqual(self._runner._session, None)
- self.assertEqual(self._runner._host, 'host@domain.tld')
- self.assertEqual(self._runner._username, 'user@domain.tld')
- self.assertEqual(self._runner._password, 'abc123')
+ self.assertEqual(self._runner._host, "host@domain.tld")
+ self.assertEqual(self._runner._username, "user@domain.tld")
+ self.assertEqual(self._runner._password, "abc123")
self.assertEqual(self._runner._timeout, 99)
self.assertEqual(self._runner._read_timeout, 100)
self.assertEqual(self._runner._port, 1234)
- self.assertEqual(self._runner._scheme, 'http')
- self.assertEqual(self._runner._transport, 'ntlm')
- self.assertEqual(self._runner._winrm_url, 'http://host@domain.tld:1234/wsman')
+ self.assertEqual(self._runner._scheme, "http")
+ self.assertEqual(self._runner._transport, "ntlm")
+ self.assertEqual(self._runner._winrm_url, "http://host@domain.tld:1234/wsman")
self.assertEqual(self._runner._verify_ssl, False)
- self.assertEqual(self._runner._server_cert_validation, 'ignore')
- self.assertEqual(self._runner._cwd, 'C:\\Test')
- self.assertEqual(self._runner._env, {'TEST_VAR': 'TEST_VALUE'})
- self.assertEqual(self._runner._kwarg_op, '/')
+ self.assertEqual(self._runner._server_cert_validation, "ignore")
+ self.assertEqual(self._runner._cwd, "C:\\Test")
+ self.assertEqual(self._runner._env, {"TEST_VAR": "TEST_VALUE"})
+ self.assertEqual(self._runner._kwarg_op, "/")
- @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run')
+ @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run")
def test_pre_run_defaults(self, mock_pre_run):
- runner_parameters = {'host': 'host@domain.tld',
- 'username': 'user@domain.tld',
- 'password': 'abc123'}
+ runner_parameters = {
+ "host": "host@domain.tld",
+ "username": "user@domain.tld",
+ "password": "abc123",
+ }
self._runner.runner_parameters = runner_parameters
self._runner.pre_run()
mock_pre_run.assert_called_with()
- self.assertEqual(self._runner._host, 'host@domain.tld')
- self.assertEqual(self._runner._username, 'user@domain.tld')
- self.assertEqual(self._runner._password, 'abc123')
+ self.assertEqual(self._runner._host, "host@domain.tld")
+ self.assertEqual(self._runner._username, "user@domain.tld")
+ self.assertEqual(self._runner._password, "abc123")
self.assertEqual(self._runner._timeout, 60)
self.assertEqual(self._runner._read_timeout, 61)
self.assertEqual(self._runner._port, 5986)
- self.assertEqual(self._runner._scheme, 'https')
- self.assertEqual(self._runner._transport, 'ntlm')
- self.assertEqual(self._runner._winrm_url, 'https://host@domain.tld:5986/wsman')
+ self.assertEqual(self._runner._scheme, "https")
+ self.assertEqual(self._runner._transport, "ntlm")
+ self.assertEqual(self._runner._winrm_url, "https://host@domain.tld:5986/wsman")
self.assertEqual(self._runner._verify_ssl, True)
- self.assertEqual(self._runner._server_cert_validation, 'validate')
+ self.assertEqual(self._runner._server_cert_validation, "validate")
self.assertEqual(self._runner._cwd, None)
self.assertEqual(self._runner._env, {})
- self.assertEqual(self._runner._kwarg_op, '-')
+ self.assertEqual(self._runner._kwarg_op, "-")
- @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run')
+ @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run")
def test_pre_run_5985_force_http(self, mock_pre_run):
- runner_parameters = {'host': 'host@domain.tld',
- 'username': 'user@domain.tld',
- 'password': 'abc123',
- 'port': 5985,
- 'scheme': 'https'}
+ runner_parameters = {
+ "host": "host@domain.tld",
+ "username": "user@domain.tld",
+ "password": "abc123",
+ "port": 5985,
+ "scheme": "https",
+ }
self._runner.runner_parameters = runner_parameters
self._runner.pre_run()
mock_pre_run.assert_called_with()
- self.assertEqual(self._runner._host, 'host@domain.tld')
- self.assertEqual(self._runner._username, 'user@domain.tld')
- self.assertEqual(self._runner._password, 'abc123')
+ self.assertEqual(self._runner._host, "host@domain.tld")
+ self.assertEqual(self._runner._username, "user@domain.tld")
+ self.assertEqual(self._runner._password, "abc123")
self.assertEqual(self._runner._timeout, 60)
self.assertEqual(self._runner._read_timeout, 61)
# ensure port is still 5985
self.assertEqual(self._runner._port, 5985)
# ensure scheme is set back to http
- self.assertEqual(self._runner._scheme, 'http')
- self.assertEqual(self._runner._transport, 'ntlm')
- self.assertEqual(self._runner._winrm_url, 'http://host@domain.tld:5985/wsman')
+ self.assertEqual(self._runner._scheme, "http")
+ self.assertEqual(self._runner._transport, "ntlm")
+ self.assertEqual(self._runner._winrm_url, "http://host@domain.tld:5985/wsman")
self.assertEqual(self._runner._verify_ssl, True)
- self.assertEqual(self._runner._server_cert_validation, 'validate')
+ self.assertEqual(self._runner._server_cert_validation, "validate")
self.assertEqual(self._runner._cwd, None)
self.assertEqual(self._runner._env, {})
- self.assertEqual(self._runner._kwarg_op, '-')
+ self.assertEqual(self._runner._kwarg_op, "-")
- @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run')
+ @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run")
def test_pre_run_none_env(self, mock_pre_run):
- runner_parameters = {'host': 'host@domain.tld',
- 'username': 'user@domain.tld',
- 'password': 'abc123',
- 'env': None}
+ runner_parameters = {
+ "host": "host@domain.tld",
+ "username": "user@domain.tld",
+ "password": "abc123",
+ "env": None,
+ }
self._runner.runner_parameters = runner_parameters
self._runner.pre_run()
mock_pre_run.assert_called_with()
# ensure that env is set to {} even though we passed in None
self.assertEqual(self._runner._env, {})
- @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run')
+ @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run")
def test_pre_run_ssl_verify_true(self, mock_pre_run):
- runner_parameters = {'host': 'host@domain.tld',
- 'username': 'user@domain.tld',
- 'password': 'abc123',
- 'verify_ssl_cert': True}
+ runner_parameters = {
+ "host": "host@domain.tld",
+ "username": "user@domain.tld",
+ "password": "abc123",
+ "verify_ssl_cert": True,
+ }
self._runner.runner_parameters = runner_parameters
self._runner.pre_run()
mock_pre_run.assert_called_with()
self.assertEqual(self._runner._verify_ssl, True)
- self.assertEqual(self._runner._server_cert_validation, 'validate')
+ self.assertEqual(self._runner._server_cert_validation, "validate")
- @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run')
+ @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run")
def test_pre_run_ssl_verify_false(self, mock_pre_run):
- runner_parameters = {'host': 'host@domain.tld',
- 'username': 'user@domain.tld',
- 'password': 'abc123',
- 'verify_ssl_cert': False}
+ runner_parameters = {
+ "host": "host@domain.tld",
+ "username": "user@domain.tld",
+ "password": "abc123",
+ "verify_ssl_cert": False,
+ }
self._runner.runner_parameters = runner_parameters
self._runner.pre_run()
mock_pre_run.assert_called_with()
self.assertEqual(self._runner._verify_ssl, False)
- self.assertEqual(self._runner._server_cert_validation, 'ignore')
+ self.assertEqual(self._runner._server_cert_validation, "ignore")
- @mock.patch('winrm_runner.winrm_base.Session')
+ @mock.patch("winrm_runner.winrm_base.Session")
def test_get_session(self, mock_session):
self._runner._session = None
- self._runner._winrm_url = 'https://host@domain.tld:5986/wsman'
- self._runner._username = 'user@domain.tld'
- self._runner._password = 'abc123'
- self._runner._transport = 'ntlm'
- self._runner._server_cert_validation = 'validate'
+ self._runner._winrm_url = "https://host@domain.tld:5986/wsman"
+ self._runner._username = "user@domain.tld"
+ self._runner._password = "abc123"
+ self._runner._transport = "ntlm"
+ self._runner._server_cert_validation = "validate"
self._runner._timeout = 60
self._runner._read_timeout = 61
mock_session.return_value = "session"
@@ -190,12 +203,14 @@ def test_get_session(self, mock_session):
result = self._runner._get_session()
self.assertEqual(result, "session")
self.assertEqual(result, self._runner._session)
- mock_session.assert_called_with('https://host@domain.tld:5986/wsman',
- auth=('user@domain.tld', 'abc123'),
- transport='ntlm',
- server_cert_validation='validate',
- operation_timeout_sec=60,
- read_timeout_sec=61)
+ mock_session.assert_called_with(
+ "https://host@domain.tld:5986/wsman",
+ auth=("user@domain.tld", "abc123"),
+ transport="ntlm",
+ server_cert_validation="validate",
+ operation_timeout_sec=60,
+ read_timeout_sec=61,
+ )
# ensure calling _get_session again doesn't create a new one, it reuses the existing
old_session = self._runner._session
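
# A sketch (not the actual implementation) of the caching behaviour this test
# verifies: _get_session builds a winrm Session from the attributes computed
# in pre_run on first use, then returns the cached object on every later call.
# The keyword arguments mirror the assert_called_with expectation above.
from winrm import Session

def get_session(runner):
    if runner._session is None:
        runner._session = Session(
            runner._winrm_url,
            auth=(runner._username, runner._password),
            transport=runner._transport,
            server_cert_validation=runner._server_cert_validation,
            operation_timeout_sec=runner._timeout,
            read_timeout_sec=runner._read_timeout,
        )
    return runner._session
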
@@ -206,18 +221,18 @@ def test_winrm_get_command_output(self):
self._runner._timeout = 0
mock_protocol = mock.MagicMock()
mock_protocol._raw_get_command_output.side_effect = [
- (b'output1', b'error1', 123, False),
- (b'output2', b'error2', 456, False),
- (b'output3', b'error3', 789, True)
+ (b"output1", b"error1", 123, False),
+ (b"output2", b"error2", 456, False),
+ (b"output3", b"error3", 789, True),
]
result = self._runner._winrm_get_command_output(mock_protocol, 567, 890)
- self.assertEqual(result, (b'output1output2output3', b'error1error2error3', 789))
+ self.assertEqual(result, (b"output1output2output3", b"error1error2error3", 789))
mock_protocol._raw_get_command_output.assert_has_calls = [
mock.call(567, 890),
mock.call(567, 890),
- mock.call(567, 890)
+ mock.call(567, 890),
]
def test_winrm_get_command_output_timeout(self):
@@ -227,7 +242,7 @@ def test_winrm_get_command_output_timeout(self):
def sleep_for_timeout(*args, **kwargs):
time.sleep(0.2)
- return (b'output1', b'error1', 123, False)
+ return (b"output1", b"error1", 123, False)
mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout
@@ -235,9 +250,11 @@ def sleep_for_timeout(*args, **kwargs):
self._runner._winrm_get_command_output(mock_protocol, 567, 890)
timeout_exception = cm.exception
- self.assertEqual(timeout_exception.response.std_out, b'output1')
- self.assertEqual(timeout_exception.response.std_err, b'error1')
- self.assertEqual(timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE)
+ self.assertEqual(timeout_exception.response.std_out, b"output1")
+ self.assertEqual(timeout_exception.response.std_err, b"error1")
+ self.assertEqual(
+ timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE
+ )
mock_protocol._raw_get_command_output.assert_called_with(567, 890)
def test_winrm_get_command_output_operation_timeout(self):
@@ -255,292 +272,354 @@ def sleep_for_timeout_then_raise(*args, **kwargs):
self._runner._winrm_get_command_output(mock_protocol, 567, 890)
timeout_exception = cm.exception
- self.assertEqual(timeout_exception.response.std_out, b'')
- self.assertEqual(timeout_exception.response.std_err, b'')
- self.assertEqual(timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE)
+ self.assertEqual(timeout_exception.response.std_out, b"")
+ self.assertEqual(timeout_exception.response.std_err, b"")
+ self.assertEqual(
+ timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE
+ )
mock_protocol._raw_get_command_output.assert_called_with(567, 890)
def test_winrm_run_cmd(self):
mock_protocol = mock.MagicMock()
mock_protocol.open_shell.return_value = 123
mock_protocol.run_command.return_value = 456
- mock_protocol._raw_get_command_output.return_value = (b'output', b'error', 9, True)
+ mock_protocol._raw_get_command_output.return_value = (
+ b"output",
+ b"error",
+ 9,
+ True,
+ )
mock_session = mock.MagicMock(protocol=mock_protocol)
self._init_runner()
- result = self._runner._winrm_run_cmd(mock_session, "fake-command",
- args=['arg1', 'arg2'],
- env={'PATH': 'C:\\st2\\bin'},
- cwd='C:\\st2')
- expected_response = Response((b'output', b'error', 9))
+ result = self._runner._winrm_run_cmd(
+ mock_session,
+ "fake-command",
+ args=["arg1", "arg2"],
+ env={"PATH": "C:\\st2\\bin"},
+ cwd="C:\\st2",
+ )
+ expected_response = Response((b"output", b"error", 9))
expected_response.timeout = False
self.assertEqual(result.__dict__, expected_response.__dict__)
- mock_protocol.open_shell.assert_called_with(env_vars={'PATH': 'C:\\st2\\bin'},
- working_directory='C:\\st2')
- mock_protocol.run_command.assert_called_with(123, 'fake-command', ['arg1', 'arg2'])
+ mock_protocol.open_shell.assert_called_with(
+ env_vars={"PATH": "C:\\st2\\bin"}, working_directory="C:\\st2"
+ )
+ mock_protocol.run_command.assert_called_with(
+ 123, "fake-command", ["arg1", "arg2"]
+ )
mock_protocol._raw_get_command_output.assert_called_with(123, 456)
mock_protocol.cleanup_command.assert_called_with(123, 456)
mock_protocol.close_shell.assert_called_with(123)
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_get_command_output')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_get_command_output")
def test_winrm_run_cmd_timeout(self, mock_get_command_output):
mock_protocol = mock.MagicMock()
mock_protocol.open_shell.return_value = 123
mock_protocol.run_command.return_value = 456
mock_session = mock.MagicMock(protocol=mock_protocol)
- mock_get_command_output.side_effect = WinRmRunnerTimoutError(Response(('', '', 5)))
+ mock_get_command_output.side_effect = WinRmRunnerTimoutError(
+ Response(("", "", 5))
+ )
self._init_runner()
- result = self._runner._winrm_run_cmd(mock_session, "fake-command",
- args=['arg1', 'arg2'],
- env={'PATH': 'C:\\st2\\bin'},
- cwd='C:\\st2')
- expected_response = Response(('', '', 5))
+ result = self._runner._winrm_run_cmd(
+ mock_session,
+ "fake-command",
+ args=["arg1", "arg2"],
+ env={"PATH": "C:\\st2\\bin"},
+ cwd="C:\\st2",
+ )
+ expected_response = Response(("", "", 5))
expected_response.timeout = True
self.assertEqual(result.__dict__, expected_response.__dict__)
- mock_protocol.open_shell.assert_called_with(env_vars={'PATH': 'C:\\st2\\bin'},
- working_directory='C:\\st2')
- mock_protocol.run_command.assert_called_with(123, 'fake-command', ['arg1', 'arg2'])
+ mock_protocol.open_shell.assert_called_with(
+ env_vars={"PATH": "C:\\st2\\bin"}, working_directory="C:\\st2"
+ )
+ mock_protocol.run_command.assert_called_with(
+ 123, "fake-command", ["arg1", "arg2"]
+ )
mock_protocol.cleanup_command.assert_called_with(123, 456)
mock_protocol.close_shell.assert_called_with(123)
def test_winrm_encode(self):
- result = self._runner._winrm_encode('hello world')
+ result = self._runner._winrm_encode("hello world")
# result translated into UTF-16 little-endian
- self.assertEqual(result, 'aABlAGwAbABvACAAdwBvAHIAbABkAA==')
+ self.assertEqual(result, "aABlAGwAbABvACAAdwBvAHIAbABkAA==")
def test_winrm_ps_cmd(self):
- result = self._runner._winrm_ps_cmd('abc123==')
- self.assertEqual(result, 'powershell -encodedcommand abc123==')
+ result = self._runner._winrm_ps_cmd("abc123==")
+ self.assertEqual(result, "powershell -encodedcommand abc123==")
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd")
def test_winrm_run_ps(self, mock_run_cmd):
- mock_run_cmd.return_value = Response(('output', '', 3))
+ mock_run_cmd.return_value = Response(("output", "", 3))
script = "Get-ADUser stanley"
- result = self._runner._winrm_run_ps("session", script,
- env={'PATH': 'C:\\st2\\bin'},
- cwd='C:\\st2')
+ result = self._runner._winrm_run_ps(
+ "session", script, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2"
+ )
- self.assertEqual(result.__dict__,
- Response(('output', '', 3)).__dict__)
- expected_ps = ('powershell -encodedcommand ' +
- b64encode("Get-ADUser stanley".encode('utf_16_le')).decode('ascii'))
- mock_run_cmd.assert_called_with("session",
- expected_ps,
- env={'PATH': 'C:\\st2\\bin'},
- cwd='C:\\st2')
+ self.assertEqual(result.__dict__, Response(("output", "", 3)).__dict__)
+ expected_ps = "powershell -encodedcommand " + b64encode(
+ "Get-ADUser stanley".encode("utf_16_le")
+ ).decode("ascii")
+ mock_run_cmd.assert_called_with(
+ "session", expected_ps, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2"
+ )
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd")
def test_winrm_run_ps_clean_stderr(self, mock_run_cmd):
- mock_run_cmd.return_value = Response(('output', 'error', 3))
+ mock_run_cmd.return_value = Response(("output", "error", 3))
mock_session = mock.MagicMock()
- mock_session._clean_error_msg.return_value = 'e'
+ mock_session._clean_error_msg.return_value = "e"
script = "Get-ADUser stanley"
- result = self._runner._winrm_run_ps(mock_session, script,
- env={'PATH': 'C:\\st2\\bin'},
- cwd='C:\\st2')
+ result = self._runner._winrm_run_ps(
+ mock_session, script, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2"
+ )
- self.assertEqual(result.__dict__,
- Response(('output', 'e', 3)).__dict__)
- expected_ps = ('powershell -encodedcommand ' +
- b64encode("Get-ADUser stanley".encode('utf_16_le')).decode('ascii'))
- mock_run_cmd.assert_called_with(mock_session,
- expected_ps,
- env={'PATH': 'C:\\st2\\bin'},
- cwd='C:\\st2')
- mock_session._clean_error_msg.assert_called_with('error')
+ self.assertEqual(result.__dict__, Response(("output", "e", 3)).__dict__)
+ expected_ps = "powershell -encodedcommand " + b64encode(
+ "Get-ADUser stanley".encode("utf_16_le")
+ ).decode("ascii")
+ mock_run_cmd.assert_called_with(
+ mock_session, expected_ps, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2"
+ )
+ mock_session._clean_error_msg.assert_called_with("error")
def test_translate_response_success(self):
- response = Response(('output1', 'error1', 0))
+ response = Response(("output1", "error1", 0))
response.timeout = False
result = self._runner._translate_response(response)
- self.assertEqual(result, ('succeeded',
- {'failed': False,
- 'succeeded': True,
- 'return_code': 0,
- 'stdout': 'output1',
- 'stderr': 'error1'},
- None))
+ self.assertEqual(
+ result,
+ (
+ "succeeded",
+ {
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
+ "stdout": "output1",
+ "stderr": "error1",
+ },
+ None,
+ ),
+ )
def test_translate_response_failure(self):
- response = Response(('output1', 'error1', 123))
+ response = Response(("output1", "error1", 123))
response.timeout = False
result = self._runner._translate_response(response)
- self.assertEqual(result, ('failed',
- {'failed': True,
- 'succeeded': False,
- 'return_code': 123,
- 'stdout': 'output1',
- 'stderr': 'error1'},
- None))
+ self.assertEqual(
+ result,
+ (
+ "failed",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": 123,
+ "stdout": "output1",
+ "stderr": "error1",
+ },
+ None,
+ ),
+ )
def test_translate_response_timeout(self):
- response = Response(('output1', 'error1', 123))
+ response = Response(("output1", "error1", 123))
response.timeout = True
result = self._runner._translate_response(response)
- self.assertEqual(result, ('timeout',
- {'failed': True,
- 'succeeded': False,
- 'return_code': -1,
- 'stdout': 'output1',
- 'stderr': 'error1'},
- None))
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise')
+ self.assertEqual(
+ result,
+ (
+ "timeout",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": -1,
+ "stdout": "output1",
+ "stderr": "error1",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise")
def test_make_tmp_dir(self, mock_run_ps_or_raise):
- mock_run_ps_or_raise.return_value = {'stdout': ' expected \n'}
+ mock_run_ps_or_raise.return_value = {"stdout": " expected \n"}
- result = self._runner._make_tmp_dir('C:\\Windows\\Temp')
- self.assertEqual(result, 'expected')
- mock_run_ps_or_raise.assert_called_with('''$parent = C:\\Windows\\Temp
+ result = self._runner._make_tmp_dir("C:\\Windows\\Temp")
+ self.assertEqual(result, "expected")
+ mock_run_ps_or_raise.assert_called_with(
+ """$parent = C:\\Windows\\Temp
$name = [System.IO.Path]::GetRandomFileName()
$path = Join-Path $parent $name
New-Item -ItemType Directory -Path $path | Out-Null
-$path''',
- ("Unable to make temporary directory for"
- " powershell script"))
+$path""",
+ ("Unable to make temporary directory for" " powershell script"),
+ )
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise")
def test_rm_dir(self, mock_run_ps_or_raise):
- self._runner._rm_dir('C:\\Windows\\Temp\\testtmpdir')
+ self._runner._rm_dir("C:\\Windows\\Temp\\testtmpdir")
mock_run_ps_or_raise.assert_called_with(
'Remove-Item -Force -Recurse -Path "C:\\Windows\\Temp\\testtmpdir"',
- "Unable to remove temporary directory for powershell script")
+ "Unable to remove temporary directory for powershell script",
+ )
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk')
- @mock.patch('winrm_runner.winrm_base.open')
- @mock.patch('os.path.exists')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk")
+ @mock.patch("winrm_runner.winrm_base.open")
+ @mock.patch("os.path.exists")
def test_upload_chunk_file(self, mock_os_path_exists, mock_open, mock_upload_chunk):
mock_os_path_exists.return_value = True
mock_src_file = mock.MagicMock()
mock_src_file.read.return_value = "test data"
mock_open.return_value.__enter__.return_value = mock_src_file
- self._runner._upload('/opt/data/test.ps1', 'C:\\Windows\\Temp\\test.ps1')
- mock_os_path_exists.assert_called_with('/opt/data/test.ps1')
- mock_open.assert_called_with('/opt/data/test.ps1', 'r')
+ self._runner._upload("/opt/data/test.ps1", "C:\\Windows\\Temp\\test.ps1")
+ mock_os_path_exists.assert_called_with("/opt/data/test.ps1")
+ mock_open.assert_called_with("/opt/data/test.ps1", "r")
mock_src_file.read.assert_called_with()
- mock_upload_chunk.assert_has_calls([
- mock.call('C:\\Windows\\Temp\\test.ps1', 'test data')
- ])
+ mock_upload_chunk.assert_has_calls(
+ [mock.call("C:\\Windows\\Temp\\test.ps1", "test data")]
+ )
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk')
- @mock.patch('os.path.exists')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk")
+ @mock.patch("os.path.exists")
def test_upload_chunk_data(self, mock_os_path_exists, mock_upload_chunk):
mock_os_path_exists.return_value = False
- self._runner._upload('test data', 'C:\\Windows\\Temp\\test.ps1')
- mock_os_path_exists.assert_called_with('test data')
- mock_upload_chunk.assert_has_calls([
- mock.call('C:\\Windows\\Temp\\test.ps1', 'test data')
- ])
+ self._runner._upload("test data", "C:\\Windows\\Temp\\test.ps1")
+ mock_os_path_exists.assert_called_with("test data")
+ mock_upload_chunk.assert_has_calls(
+ [mock.call("C:\\Windows\\Temp\\test.ps1", "test data")]
+ )
- @mock.patch('winrm_runner.winrm_base.WINRM_UPLOAD_CHUNK_SIZE_BYTES', 2)
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk')
- @mock.patch('os.path.exists')
+ @mock.patch("winrm_runner.winrm_base.WINRM_UPLOAD_CHUNK_SIZE_BYTES", 2)
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk")
+ @mock.patch("os.path.exists")
def test_upload_chunk_multiple_chunks(self, mock_os_path_exists, mock_upload_chunk):
mock_os_path_exists.return_value = False
- self._runner._upload('test data', 'C:\\Windows\\Temp\\test.ps1')
- mock_os_path_exists.assert_called_with('test data')
- mock_upload_chunk.assert_has_calls([
- mock.call('C:\\Windows\\Temp\\test.ps1', 'te'),
- mock.call('C:\\Windows\\Temp\\test.ps1', 'st'),
- mock.call('C:\\Windows\\Temp\\test.ps1', ' d'),
- mock.call('C:\\Windows\\Temp\\test.ps1', 'at'),
- mock.call('C:\\Windows\\Temp\\test.ps1', 'a'),
- ])
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise')
+ self._runner._upload("test data", "C:\\Windows\\Temp\\test.ps1")
+ mock_os_path_exists.assert_called_with("test data")
+ mock_upload_chunk.assert_has_calls(
+ [
+ mock.call("C:\\Windows\\Temp\\test.ps1", "te"),
+ mock.call("C:\\Windows\\Temp\\test.ps1", "st"),
+ mock.call("C:\\Windows\\Temp\\test.ps1", " d"),
+ mock.call("C:\\Windows\\Temp\\test.ps1", "at"),
+ mock.call("C:\\Windows\\Temp\\test.ps1", "a"),
+ ]
+ )
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise")
def test_upload_chunk(self, mock_run_ps_or_raise):
- self._runner._upload_chunk('C:\\Windows\\Temp\\testtmp.ps1', 'hello world')
+ self._runner._upload_chunk("C:\\Windows\\Temp\\testtmp.ps1", "hello world")
mock_run_ps_or_raise.assert_called_with(
- '''$filePath = "C:\\Windows\\Temp\\testtmp.ps1"
+ """$filePath = "C:\\Windows\\Temp\\testtmp.ps1"
$s = @"
aGVsbG8gd29ybGQ=
"@
$data = [System.Convert]::FromBase64String($s)
Add-Content -value $data -encoding byte -path $filePath
-''',
- "Failed to upload chunk of powershell script")
+""",
+ "Failed to upload chunk of powershell script",
+ )
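
# A minimal sketch of the chunked-upload scheme these tests describe: the
# source data is split into WINRM_UPLOAD_CHUNK_SIZE_BYTES pieces, and each
# piece is shipped as a base64 here-string that the remote PowerShell snippet
# decodes and appends to the target file. The chunk size below is illustrative
# only (the multi-chunk test patches it to 2).
from base64 import b64encode

CHUNK_SIZE = 2  # the runner reads this from WINRM_UPLOAD_CHUNK_SIZE_BYTES

def iter_chunks(data, size=CHUNK_SIZE):
    for i in range(0, len(data), size):
        yield data[i : i + size]

assert list(iter_chunks("test data")) == ["te", "st", " d", "at", "a"]
assert b64encode(b"hello world").decode("ascii") == "aGVsbG8gd29ybGQ="
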
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._rm_dir')
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload')
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._rm_dir")
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload")
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir")
def test_tmp_script(self, mock_make_tmp_dir, mock_upload, mock_rm_dir):
- mock_make_tmp_dir.return_value = 'C:\\Windows\\Temp\\abc123'
-
- with self._runner._tmp_script('C:\\Windows\\Temp', 'Get-ChildItem') as tmp:
- self.assertEqual(tmp, 'C:\\Windows\\Temp\\abc123\\script.ps1')
- mock_make_tmp_dir.assert_called_with('C:\\Windows\\Temp')
- mock_upload.assert_called_with('Get-ChildItem',
- 'C:\\Windows\\Temp\\abc123\\script.ps1')
- mock_rm_dir.assert_called_with('C:\\Windows\\Temp\\abc123')
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._rm_dir')
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload')
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir')
- def test_tmp_script_cleans_up_when_raises(self, mock_make_tmp_dir, mock_upload,
- mock_rm_dir):
- mock_make_tmp_dir.return_value = 'C:\\Windows\\Temp\\abc123'
+ mock_make_tmp_dir.return_value = "C:\\Windows\\Temp\\abc123"
+
+ with self._runner._tmp_script("C:\\Windows\\Temp", "Get-ChildItem") as tmp:
+ self.assertEqual(tmp, "C:\\Windows\\Temp\\abc123\\script.ps1")
+ mock_make_tmp_dir.assert_called_with("C:\\Windows\\Temp")
+ mock_upload.assert_called_with(
+ "Get-ChildItem", "C:\\Windows\\Temp\\abc123\\script.ps1"
+ )
+ mock_rm_dir.assert_called_with("C:\\Windows\\Temp\\abc123")
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._rm_dir")
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload")
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir")
+ def test_tmp_script_cleans_up_when_raises(
+ self, mock_make_tmp_dir, mock_upload, mock_rm_dir
+ ):
+ mock_make_tmp_dir.return_value = "C:\\Windows\\Temp\\abc123"
mock_upload.side_effect = RuntimeError
with self.assertRaises(RuntimeError):
- with self._runner._tmp_script('C:\\Windows\\Temp', 'Get-ChildItem') as tmp:
+ with self._runner._tmp_script("C:\\Windows\\Temp", "Get-ChildItem") as tmp:
self.assertEqual(tmp, "can never get here")
- mock_make_tmp_dir.assert_called_with('C:\\Windows\\Temp')
- mock_upload.assert_called_with('Get-ChildItem',
- 'C:\\Windows\\Temp\\abc123\\script.ps1')
- mock_rm_dir.assert_called_with('C:\\Windows\\Temp\\abc123')
+ mock_make_tmp_dir.assert_called_with("C:\\Windows\\Temp")
+ mock_upload.assert_called_with(
+ "Get-ChildItem", "C:\\Windows\\Temp\\abc123\\script.ps1"
+ )
+ mock_rm_dir.assert_called_with("C:\\Windows\\Temp\\abc123")
- @mock.patch('winrm.Protocol')
+ @mock.patch("winrm.Protocol")
def test_run_cmd(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
mock_protocol._raw_get_command_output.side_effect = [
- (b'output1', b'error1', 0, False),
- (b'output2', b'error2', 0, False),
- (b'output3', b'error3', 0, True)
+ (b"output1", b"error1", 0, False),
+ (b"output2", b"error2", 0, False),
+ (b"output3", b"error3", 0, True),
]
mock_protocol_init.return_value = mock_protocol
self._init_runner()
result = self._runner.run_cmd("ipconfig /all")
- self.assertEqual(result, ('succeeded',
- {'failed': False,
- 'succeeded': True,
- 'return_code': 0,
- 'stdout': 'output1output2output3',
- 'stderr': 'error1error2error3'},
- None))
-
- @mock.patch('winrm.Protocol')
+ self.assertEqual(
+ result,
+ (
+ "succeeded",
+ {
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
+ "stdout": "output1output2output3",
+ "stderr": "error1error2error3",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm.Protocol")
def test_run_cmd_failed(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
mock_protocol._raw_get_command_output.side_effect = [
- (b'output1', b'error1', 0, False),
- (b'output2', b'error2', 0, False),
- (b'output3', b'error3', 1, True)
+ (b"output1", b"error1", 0, False),
+ (b"output2", b"error2", 0, False),
+ (b"output3", b"error3", 1, True),
]
mock_protocol_init.return_value = mock_protocol
self._init_runner()
result = self._runner.run_cmd("ipconfig /all")
- self.assertEqual(result, ('failed',
- {'failed': True,
- 'succeeded': False,
- 'return_code': 1,
- 'stdout': 'output1output2output3',
- 'stderr': 'error1error2error3'},
- None))
-
- @mock.patch('winrm.Protocol')
+ self.assertEqual(
+ result,
+ (
+ "failed",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": 1,
+ "stdout": "output1output2output3",
+ "stderr": "error1error2error3",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm.Protocol")
def test_run_cmd_timeout(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
self._init_runner()
@@ -548,61 +627,82 @@ def test_run_cmd_timeout(self, mock_protocol_init):
def sleep_for_timeout_then_raise(*args, **kwargs):
time.sleep(0.2)
- return (b'output1', b'error1', 123, False)
+ return (b"output1", b"error1", 123, False)
mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise
mock_protocol_init.return_value = mock_protocol
result = self._runner.run_cmd("ipconfig /all")
- self.assertEqual(result, ('timeout',
- {'failed': True,
- 'succeeded': False,
- 'return_code': -1,
- 'stdout': 'output1',
- 'stderr': 'error1'},
- None))
-
- @mock.patch('winrm.Protocol')
+ self.assertEqual(
+ result,
+ (
+ "timeout",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": -1,
+ "stdout": "output1",
+ "stderr": "error1",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm.Protocol")
def test_run_ps(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
mock_protocol._raw_get_command_output.side_effect = [
- (b'output1', b'error1', 0, False),
- (b'output2', b'error2', 0, False),
- (b'output3', b'error3', 0, True)
+ (b"output1", b"error1", 0, False),
+ (b"output2", b"error2", 0, False),
+ (b"output3", b"error3", 0, True),
]
mock_protocol_init.return_value = mock_protocol
self._init_runner()
result = self._runner.run_ps("Get-Location")
- self.assertEqual(result, ('succeeded',
- {'failed': False,
- 'succeeded': True,
- 'return_code': 0,
- 'stdout': 'output1output2output3',
- 'stderr': 'error1error2error3'},
- None))
-
- @mock.patch('winrm.Protocol')
+ self.assertEqual(
+ result,
+ (
+ "succeeded",
+ {
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
+ "stdout": "output1output2output3",
+ "stderr": "error1error2error3",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm.Protocol")
def test_run_ps_failed(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
mock_protocol._raw_get_command_output.side_effect = [
- (b'output1', b'error1', 0, False),
- (b'output2', b'error2', 0, False),
- (b'output3', b'error3', 1, True)
+ (b"output1", b"error1", 0, False),
+ (b"output2", b"error2", 0, False),
+ (b"output3", b"error3", 1, True),
]
mock_protocol_init.return_value = mock_protocol
self._init_runner()
result = self._runner.run_ps("Get-Location")
- self.assertEqual(result, ('failed',
- {'failed': True,
- 'succeeded': False,
- 'return_code': 1,
- 'stdout': 'output1output2output3',
- 'stderr': 'error1error2error3'},
- None))
-
- @mock.patch('winrm.Protocol')
+ self.assertEqual(
+ result,
+ (
+ "failed",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": 1,
+ "stdout": "output1output2output3",
+ "stderr": "error1error2error3",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm.Protocol")
def test_run_ps_timeout(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
self._init_runner()
@@ -610,91 +710,113 @@ def test_run_ps_timeout(self, mock_protocol_init):
def sleep_for_timeout_then_raise(*args, **kwargs):
time.sleep(0.2)
- return (b'output1', b'error1', 123, False)
+ return (b"output1", b"error1", 123, False)
mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise
mock_protocol_init.return_value = mock_protocol
result = self._runner.run_ps("Get-Location")
- self.assertEqual(result, ('timeout',
- {'failed': True,
- 'succeeded': False,
- 'return_code': -1,
- 'stdout': 'output1',
- 'stderr': 'error1'},
- None))
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps')
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_encode')
+ self.assertEqual(
+ result,
+ (
+ "timeout",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": -1,
+ "stdout": "output1",
+ "stderr": "error1",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps")
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_encode")
def test_run_ps_params(self, mock_winrm_encode, mock_run_ps):
- mock_winrm_encode.return_value = 'xyz123=='
+ mock_winrm_encode.return_value = "xyz123=="
mock_run_ps.return_value = "expected"
self._init_runner()
- result = self._runner.run_ps("Get-Location", '-param1 value1 arg1')
+ result = self._runner.run_ps("Get-Location", "-param1 value1 arg1")
self.assertEqual(result, "expected")
- mock_winrm_encode.assert_called_with('& {Get-Location} -param1 value1 arg1')
- mock_run_ps.assert_called_with('xyz123==', is_b64=True)
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_ps_cmd')
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_script')
- def test_run_ps_large_command_convert_to_script(self, mock_run_ps_script,
- mock_winrm_ps_cmd):
+ mock_winrm_encode.assert_called_with("& {Get-Location} -param1 value1 arg1")
+ mock_run_ps.assert_called_with("xyz123==", is_b64=True)
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_ps_cmd")
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_script")
+ def test_run_ps_large_command_convert_to_script(
+ self, mock_run_ps_script, mock_winrm_ps_cmd
+ ):
mock_run_ps_script.return_value = "expected"
        # max length of a command in powershell
- script = 'powershell -encodedcommand '
- script += '#' * (WINRM_MAX_CMD_LENGTH + 1 - len(script))
+ script = "powershell -encodedcommand "
+ script += "#" * (WINRM_MAX_CMD_LENGTH + 1 - len(script))
mock_winrm_ps_cmd.return_value = script
self._init_runner()
- result = self._runner.run_ps('$PSVersionTable')
+ result = self._runner.run_ps("$PSVersionTable")
self.assertEqual(result, "expected")
- mock_run_ps_script.assert_called_with('$PSVersionTable', None)
+ mock_run_ps_script.assert_called_with("$PSVersionTable", None)
- @mock.patch('winrm.Protocol')
+ @mock.patch("winrm.Protocol")
def test__run_ps(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
mock_protocol._raw_get_command_output.side_effect = [
- (b'output1', b'error1', 0, False),
- (b'output2', b'error2', 0, False),
- (b'output3', b'error3', 0, True)
+ (b"output1", b"error1", 0, False),
+ (b"output2", b"error2", 0, False),
+ (b"output3", b"error3", 0, True),
]
mock_protocol_init.return_value = mock_protocol
self._init_runner()
result = self._runner._run_ps("Get-Location")
- self.assertEqual(result, ('succeeded',
- {'failed': False,
- 'succeeded': True,
- 'return_code': 0,
- 'stdout': 'output1output2output3',
- 'stderr': 'error1error2error3'},
- None))
-
- @mock.patch('winrm.Protocol')
+ self.assertEqual(
+ result,
+ (
+ "succeeded",
+ {
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
+ "stdout": "output1output2output3",
+ "stderr": "error1error2error3",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm.Protocol")
def test__run_ps_failed(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
mock_protocol._raw_get_command_output.side_effect = [
- (b'output1', b'error1', 0, False),
- (b'output2', b'error2', 0, False),
- (b'output3', b'error3', 1, True)
+ (b"output1", b"error1", 0, False),
+ (b"output2", b"error2", 0, False),
+ (b"output3", b"error3", 1, True),
]
mock_protocol_init.return_value = mock_protocol
self._init_runner()
result = self._runner._run_ps("Get-Location")
- self.assertEqual(result, ('failed',
- {'failed': True,
- 'succeeded': False,
- 'return_code': 1,
- 'stdout': 'output1output2output3',
- 'stderr': 'error1error2error3'},
- None))
-
- @mock.patch('winrm.Protocol')
+ self.assertEqual(
+ result,
+ (
+ "failed",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": 1,
+ "stdout": "output1output2output3",
+ "stderr": "error1error2error3",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm.Protocol")
def test__run_ps_timeout(self, mock_protocol_init):
mock_protocol = mock.MagicMock()
self._init_runner()
@@ -702,238 +824,236 @@ def test__run_ps_timeout(self, mock_protocol_init):
def sleep_for_timeout_then_raise(*args, **kwargs):
time.sleep(0.2)
- return (b'output1', b'error1', 123, False)
+ return (b"output1", b"error1", 123, False)
mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise
mock_protocol_init.return_value = mock_protocol
result = self._runner._run_ps("Get-Location")
- self.assertEqual(result, ('timeout',
- {'failed': True,
- 'succeeded': False,
- 'return_code': -1,
- 'stdout': 'output1',
- 'stderr': 'error1'},
- None))
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps')
+ self.assertEqual(
+ result,
+ (
+ "timeout",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": -1,
+ "stdout": "output1",
+ "stderr": "error1",
+ },
+ None,
+ ),
+ )
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps")
def test__run_ps_b64_default(self, mock_winrm_run_ps):
- mock_winrm_run_ps.return_value = mock.MagicMock(status_code=0,
- timeout=False,
- std_out='output1',
- std_err='error1')
+ mock_winrm_run_ps.return_value = mock.MagicMock(
+ status_code=0, timeout=False, std_out="output1", std_err="error1"
+ )
self._init_runner()
result = self._runner._run_ps("$PSVersionTable")
- self.assertEqual(result, ('succeeded',
- {'failed': False,
- 'succeeded': True,
- 'return_code': 0,
- 'stdout': 'output1',
- 'stderr': 'error1'},
- None))
- mock_winrm_run_ps.assert_called_with(self._runner._session,
- '$PSVersionTable',
- env={},
- cwd=None,
- is_b64=False)
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps')
+ self.assertEqual(
+ result,
+ (
+ "succeeded",
+ {
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
+ "stdout": "output1",
+ "stderr": "error1",
+ },
+ None,
+ ),
+ )
+ mock_winrm_run_ps.assert_called_with(
+ self._runner._session, "$PSVersionTable", env={}, cwd=None, is_b64=False
+ )
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps")
def test__run_ps_b64_true(self, mock_winrm_run_ps):
- mock_winrm_run_ps.return_value = mock.MagicMock(status_code=0,
- timeout=False,
- std_out='output1',
- std_err='error1')
+ mock_winrm_run_ps.return_value = mock.MagicMock(
+ status_code=0, timeout=False, std_out="output1", std_err="error1"
+ )
self._init_runner()
result = self._runner._run_ps("xyz123", is_b64=True)
- self.assertEqual(result, ('succeeded',
- {'failed': False,
- 'succeeded': True,
- 'return_code': 0,
- 'stdout': 'output1',
- 'stderr': 'error1'},
- None))
- mock_winrm_run_ps.assert_called_with(self._runner._session,
- 'xyz123',
- env={},
- cwd=None,
- is_b64=True)
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps')
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._tmp_script')
+ self.assertEqual(
+ result,
+ (
+ "succeeded",
+ {
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
+ "stdout": "output1",
+ "stderr": "error1",
+ },
+ None,
+ ),
+ )
+ mock_winrm_run_ps.assert_called_with(
+ self._runner._session, "xyz123", env={}, cwd=None, is_b64=True
+ )
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps")
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._tmp_script")
def test__run_ps_script(self, mock_tmp_script, mock_run_ps):
- mock_tmp_script.return_value.__enter__.return_value = 'C:\\tmpscript.ps1'
- mock_run_ps.return_value = 'expected'
+ mock_tmp_script.return_value.__enter__.return_value = "C:\\tmpscript.ps1"
+ mock_run_ps.return_value = "expected"
self._init_runner()
result = self._runner._run_ps_script("$PSVersionTable")
- self.assertEqual(result, 'expected')
- mock_tmp_script.assert_called_with('[System.IO.Path]::GetTempPath()',
- '$PSVersionTable')
- mock_run_ps.assert_called_with('& {C:\\tmpscript.ps1}')
+ self.assertEqual(result, "expected")
+ mock_tmp_script.assert_called_with(
+ "[System.IO.Path]::GetTempPath()", "$PSVersionTable"
+ )
+ mock_run_ps.assert_called_with("& {C:\\tmpscript.ps1}")
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps')
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._tmp_script')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps")
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._tmp_script")
def test__run_ps_script_with_params(self, mock_tmp_script, mock_run_ps):
- mock_tmp_script.return_value.__enter__.return_value = 'C:\\tmpscript.ps1'
- mock_run_ps.return_value = 'expected'
+ mock_tmp_script.return_value.__enter__.return_value = "C:\\tmpscript.ps1"
+ mock_run_ps.return_value = "expected"
self._init_runner()
- result = self._runner._run_ps_script("Get-ChildItem", '-param1 value1 arg1')
- self.assertEqual(result, 'expected')
- mock_tmp_script.assert_called_with('[System.IO.Path]::GetTempPath()',
- 'Get-ChildItem')
- mock_run_ps.assert_called_with('& {C:\\tmpscript.ps1} -param1 value1 arg1')
+ result = self._runner._run_ps_script("Get-ChildItem", "-param1 value1 arg1")
+ self.assertEqual(result, "expected")
+ mock_tmp_script.assert_called_with(
+ "[System.IO.Path]::GetTempPath()", "Get-ChildItem"
+ )
+ mock_run_ps.assert_called_with("& {C:\\tmpscript.ps1} -param1 value1 arg1")
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps')
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps")
def test__run_ps_or_raise(self, mock_run_ps):
- mock_run_ps.return_value = ('success',
- {
- 'failed': False,
- 'succeeded': True,
- 'return_code': 0,
- 'stdout': 'output',
- 'stderr': 'error',
- },
- None)
+ mock_run_ps.return_value = (
+ "success",
+ {
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
+ "stdout": "output",
+ "stderr": "error",
+ },
+ None,
+ )
self._init_runner()
- result = self._runner._run_ps_or_raise('Get-ChildItem', 'my error message')
- self.assertEqual(result, {
- 'failed': False,
- 'succeeded': True,
- 'return_code': 0,
- 'stdout': 'output',
- 'stderr': 'error',
- })
-
- @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps')
+ result = self._runner._run_ps_or_raise("Get-ChildItem", "my error message")
+ self.assertEqual(
+ result,
+ {
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
+ "stdout": "output",
+ "stderr": "error",
+ },
+ )
+
+ @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps")
def test__run_ps_or_raise_raises_on_failure(self, mock_run_ps):
- mock_run_ps.return_value = ('success',
- {
- 'failed': True,
- 'succeeded': False,
- 'return_code': 1,
- 'stdout': 'output',
- 'stderr': 'error',
- },
- None)
+ mock_run_ps.return_value = (
+ "success",
+ {
+ "failed": True,
+ "succeeded": False,
+ "return_code": 1,
+ "stdout": "output",
+ "stderr": "error",
+ },
+ None,
+ )
self._init_runner()
with self.assertRaises(RuntimeError):
- self._runner._run_ps_or_raise('Get-ChildItem', 'my error message')
+ self._runner._run_ps_or_raise("Get-ChildItem", "my error message")
def test_multireplace(self):
- multireplace_map = {'a': 'x',
- 'c': 'y',
- 'aaa': 'z'}
- result = self._runner._multireplace('aaaccaa', multireplace_map)
- self.assertEqual(result, 'zyyxx')
+ multireplace_map = {"a": "x", "c": "y", "aaa": "z"}
+ result = self._runner._multireplace("aaaccaa", multireplace_map)
+ self.assertEqual(result, "zyyxx")
def test_multireplace_powershell(self):
- param_str = (
- '\n'
- '\r'
- '\t'
- '\a'
- '\b'
- '\f'
- '\v'
- '"'
- '\''
- '`'
- '\0'
- '$'
- )
+ param_str = "\n" "\r" "\t" "\a" "\b" "\f" "\v" '"' "'" "`" "\0" "$"
result = self._runner._multireplace(param_str, PS_ESCAPE_SEQUENCES)
- self.assertEqual(result, (
- '`n'
- '`r'
- '`t'
- '`a'
- '`b'
- '`f'
- '`v'
- '`"'
- '`\''
- '``'
- '`0'
- '`$'
- ))
+ self.assertEqual(
+ result, ("`n" "`r" "`t" "`a" "`b" "`f" "`v" '`"' "`'" "``" "`0" "`$")
+ )
def test_param_to_ps_none(self):
# test None/null
param = None
result = self._runner._param_to_ps(param)
- self.assertEqual(result, '$null')
+ self.assertEqual(result, "$null")
def test_param_to_ps_string(self):
# test ascii
- param_str = 'StackStorm 1234'
+ param_str = "StackStorm 1234"
result = self._runner._param_to_ps(param_str)
self.assertEqual(result, '"StackStorm 1234"')
# test escaped
- param_str = '\n\r\t'
+ param_str = "\n\r\t"
result = self._runner._param_to_ps(param_str)
self.assertEqual(result, '"`n`r`t"')
def test_param_to_ps_bool(self):
# test True
result = self._runner._param_to_ps(True)
- self.assertEqual(result, '$true')
+ self.assertEqual(result, "$true")
# test False
result = self._runner._param_to_ps(False)
- self.assertEqual(result, '$false')
+ self.assertEqual(result, "$false")
def test_param_to_ps_integer(self):
result = self._runner._param_to_ps(9876)
- self.assertEqual(result, '9876')
+ self.assertEqual(result, "9876")
result = self._runner._param_to_ps(-765)
- self.assertEqual(result, '-765')
+ self.assertEqual(result, "-765")
def test_param_to_ps_float(self):
result = self._runner._param_to_ps(98.76)
- self.assertEqual(result, '98.76')
+ self.assertEqual(result, "98.76")
result = self._runner._param_to_ps(-76.5)
- self.assertEqual(result, '-76.5')
+ self.assertEqual(result, "-76.5")
def test_param_to_ps_list(self):
- input_list = ['StackStorm Test String',
- '`\0$',
- True,
- 99]
+ input_list = ["StackStorm Test String", "`\0$", True, 99]
result = self._runner._param_to_ps(input_list)
self.assertEqual(result, '@("StackStorm Test String", "```0`$", $true, 99)')
def test_param_to_ps_list_nested(self):
- input_list = [['a'], ['b'], [['c']]]
+ input_list = [["a"], ["b"], [["c"]]]
result = self._runner._param_to_ps(input_list)
self.assertEqual(result, '@(@("a"), @("b"), @(@("c")))')
def test_param_to_ps_dict(self):
input_list = collections.OrderedDict(
- [('str key', 'Value String'),
- ('esc str\n', '\b\f\v"'),
- (False, True),
- (11, 99),
- (18.3, 12.34)])
+ [
+ ("str key", "Value String"),
+ ("esc str\n", '\b\f\v"'),
+ (False, True),
+ (11, 99),
+ (18.3, 12.34),
+ ]
+ )
result = self._runner._param_to_ps(input_list)
expected_str = (
'@{"str key" = "Value String"; '
- '"esc str`n" = "`b`f`v`\""; '
- '$false = $true; '
- '11 = 99; '
- '18.3 = 12.34}'
+ '"esc str`n" = "`b`f`v`""; '
+ "$false = $true; "
+ "11 = 99; "
+ "18.3 = 12.34}"
)
self.assertEqual(result, expected_str)
    def test_param_to_ps_dict_nested(self):
input_list = collections.OrderedDict(
- [('a', {'deep_a': 'value'}),
- ('b', {'deep_b': {'deep_deep_b': 'value'}})])
+ [("a", {"deep_a": "value"}), ("b", {"deep_b": {"deep_deep_b": "value"}})]
+ )
result = self._runner._param_to_ps(input_list)
expected_str = (
'@{"a" = @{"deep_a" = "value"}; '
@@ -945,21 +1065,22 @@ def test_param_to_ps_deep_nested_dict_outer(self):
####
# dict as outer container
input_dict = collections.OrderedDict(
- [('a', [{'deep_a': 'value'},
- {'deep_b': ['a', 'b', 'c']}])])
+ [("a", [{"deep_a": "value"}, {"deep_b": ["a", "b", "c"]}])]
+ )
result = self._runner._param_to_ps(input_dict)
expected_str = (
- '@{"a" = @(@{"deep_a" = "value"}, '
- '@{"deep_b" = @("a", "b", "c")})}'
+ '@{"a" = @(@{"deep_a" = "value"}, ' '@{"deep_b" = @("a", "b", "c")})}'
)
self.assertEqual(result, expected_str)
def test_param_to_ps_deep_nested_list_outer(self):
####
# list as outer container
- input_list = [{'deep_a': 'value'},
- {'deep_b': ['a', 'b', 'c']},
- {'deep_c': [{'x': 'y'}]}]
+ input_list = [
+ {"deep_a": "value"},
+ {"deep_b": ["a", "b", "c"]},
+ {"deep_c": [{"x": "y"}]},
+ ]
result = self._runner._param_to_ps(input_list)
expected_str = (
'@(@{"deep_a" = "value"}, '
@@ -969,45 +1090,48 @@ def test_param_to_ps_deep_nested_list_outer(self):
self.assertEqual(result, expected_str)
def test_transform_params_to_ps(self):
- positional_args = [1, 'a', '\n']
+ positional_args = [1, "a", "\n"]
named_args = collections.OrderedDict(
- [('a', 'value1'),
- ('b', True),
- ('c', ['x', 'y']),
- ('d', {'z': 'w'})]
+ [("a", "value1"), ("b", True), ("c", ["x", "y"]), ("d", {"z": "w"})]
)
- result_pos, result_named = self._runner._transform_params_to_ps(positional_args,
- named_args)
- self.assertEqual(result_pos, ['1', '"a"', '"`n"'])
- self.assertEqual(result_named, collections.OrderedDict([
- ('a', '"value1"'),
- ('b', '$true'),
- ('c', '@("x", "y")'),
- ('d', '@{"z" = "w"}')]))
+ result_pos, result_named = self._runner._transform_params_to_ps(
+ positional_args, named_args
+ )
+ self.assertEqual(result_pos, ["1", '"a"', '"`n"'])
+ self.assertEqual(
+ result_named,
+ collections.OrderedDict(
+ [
+ ("a", '"value1"'),
+ ("b", "$true"),
+ ("c", '@("x", "y")'),
+ ("d", '@{"z" = "w"}'),
+ ]
+ ),
+ )
def test_transform_params_to_ps_none(self):
positional_args = None
named_args = None
- result_pos, result_named = self._runner._transform_params_to_ps(positional_args,
- named_args)
+ result_pos, result_named = self._runner._transform_params_to_ps(
+ positional_args, named_args
+ )
self.assertEqual(result_pos, None)
self.assertEqual(result_named, None)
def test_create_ps_params_string(self):
- positional_args = [1, 'a', '\n']
+ positional_args = [1, "a", "\n"]
named_args = collections.OrderedDict(
- [('-a', 'value1'),
- ('-b', True),
- ('-c', ['x', 'y']),
- ('-d', {'z': 'w'})]
+ [("-a", "value1"), ("-b", True), ("-c", ["x", "y"]), ("-d", {"z": "w"})]
)
result = self._runner.create_ps_params_string(positional_args, named_args)
- self.assertEqual(result,
- '-a "value1" -b $true -c @("x", "y") -d @{"z" = "w"} 1 "a" "`n"')
+ self.assertEqual(
+ result, '-a "value1" -b $true -c @("x", "y") -d @{"z" = "w"} 1 "a" "`n"'
+ )
def test_create_ps_params_string_none(self):
positional_args = None
diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py
index 9ff36a1b47..78365a333b 100644
--- a/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py
+++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py
@@ -23,23 +23,22 @@
class WinRmCommandRunnerTestCase(RunnerTestCase):
-
def setUp(self):
super(WinRmCommandRunnerTestCase, self).setUpClass()
self._runner = winrm_command_runner.get_runner()
def test_init(self):
- runner = winrm_command_runner.WinRmCommandRunner('abcdef')
+ runner = winrm_command_runner.WinRmCommandRunner("abcdef")
self.assertIsInstance(runner, WinRmBaseRunner)
self.assertIsInstance(runner, ActionRunner)
- self.assertEqual(runner.runner_id, 'abcdef')
+ self.assertEqual(runner.runner_id, "abcdef")
- @mock.patch('winrm_runner.winrm_command_runner.WinRmCommandRunner.run_cmd')
+ @mock.patch("winrm_runner.winrm_command_runner.WinRmCommandRunner.run_cmd")
def test_run(self, mock_run_cmd):
- mock_run_cmd.return_value = 'expected'
+ mock_run_cmd.return_value = "expected"
- self._runner.runner_parameters = {'cmd': 'ipconfig /all'}
+ self._runner.runner_parameters = {"cmd": "ipconfig /all"}
result = self._runner.run({})
- self.assertEqual(result, 'expected')
- mock_run_cmd.assert_called_with('ipconfig /all')
+ self.assertEqual(result, "expected")
+ mock_run_cmd.assert_called_with("ipconfig /all")
diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py
index d6bae23e2c..90d9e95abd 100644
--- a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py
+++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py
@@ -23,23 +23,22 @@
class WinRmPsCommandRunnerTestCase(RunnerTestCase):
-
def setUp(self):
super(WinRmPsCommandRunnerTestCase, self).setUpClass()
self._runner = winrm_ps_command_runner.get_runner()
def test_init(self):
- runner = winrm_ps_command_runner.WinRmPsCommandRunner('abcdef')
+ runner = winrm_ps_command_runner.WinRmPsCommandRunner("abcdef")
self.assertIsInstance(runner, WinRmBaseRunner)
self.assertIsInstance(runner, ActionRunner)
- self.assertEqual(runner.runner_id, 'abcdef')
+ self.assertEqual(runner.runner_id, "abcdef")
- @mock.patch('winrm_runner.winrm_ps_command_runner.WinRmPsCommandRunner.run_ps')
+ @mock.patch("winrm_runner.winrm_ps_command_runner.WinRmPsCommandRunner.run_ps")
def test_run(self, mock_run_ps):
- mock_run_ps.return_value = 'expected'
+ mock_run_ps.return_value = "expected"
- self._runner.runner_parameters = {'cmd': 'Get-ADUser stanley'}
+ self._runner.runner_parameters = {"cmd": "Get-ADUser stanley"}
result = self._runner.run({})
- self.assertEqual(result, 'expected')
- mock_run_ps.assert_called_with('Get-ADUser stanley')
+ self.assertEqual(result, "expected")
+ mock_run_ps.assert_called_with("Get-ADUser stanley")
diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py
index b3c1e14034..c1414c25e7 100644
--- a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py
+++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py
@@ -22,39 +22,41 @@
from winrm_runner import winrm_ps_script_runner
from winrm_runner.winrm_base import WinRmBaseRunner
-FIXTURES_PATH = os.path.join(os.path.dirname(__file__), 'fixtures')
+FIXTURES_PATH = os.path.join(os.path.dirname(__file__), "fixtures")
POWERSHELL_SCRIPT_PATH = os.path.join(FIXTURES_PATH, "TestScript.ps1")
class WinRmPsScriptRunnerTestCase(RunnerTestCase):
-
def setUp(self):
super(WinRmPsScriptRunnerTestCase, self).setUpClass()
self._runner = winrm_ps_script_runner.get_runner()
def test_init(self):
- runner = winrm_ps_script_runner.WinRmPsScriptRunner('abcdef')
+ runner = winrm_ps_script_runner.WinRmPsScriptRunner("abcdef")
self.assertIsInstance(runner, WinRmBaseRunner)
self.assertIsInstance(runner, ActionRunner)
- self.assertEqual(runner.runner_id, 'abcdef')
+ self.assertEqual(runner.runner_id, "abcdef")
- @mock.patch('winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner._get_script_args')
- @mock.patch('winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner.run_ps')
+ @mock.patch(
+ "winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner._get_script_args"
+ )
+ @mock.patch("winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner.run_ps")
def test_run(self, mock_run_ps, mock_get_script_args):
- mock_run_ps.return_value = 'expected'
- pos_args = [1, 'abc']
+ mock_run_ps.return_value = "expected"
+ pos_args = [1, "abc"]
named_args = {"d": {"test": ["\r", True, 3]}}
mock_get_script_args.return_value = (pos_args, named_args)
self._runner.entry_point = POWERSHELL_SCRIPT_PATH
self._runner.runner_parameters = {}
- self._runner._kwarg_op = '-'
+ self._runner._kwarg_op = "-"
result = self._runner.run({})
- self.assertEqual(result, 'expected')
- mock_run_ps.assert_called_with('''[CmdletBinding()]
+ self.assertEqual(result, "expected")
+ mock_run_ps.assert_called_with(
+ """[CmdletBinding()]
Param(
[bool]$p_bool,
[int]$p_integer,
@@ -77,5 +79,6 @@ def test_run(self, mock_run_ps, mock_get_script_args):
Write-Output "p_obj = $($p_obj | ConvertTo-Json -Compress)"
Write-Output "p_pos0 = $p_pos0"
Write-Output "p_pos1 = $p_pos1"
-''',
- '-d @{"test" = @("`r", $true, 3)} 1 "abc"')
+""",
+ '-d @{"test" = @("`r", $true, 3)} 1 "abc"',
+ )
diff --git a/contrib/runners/winrm_runner/winrm_runner/__init__.py b/contrib/runners/winrm_runner/winrm_runner/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/contrib/runners/winrm_runner/winrm_runner/__init__.py
+++ b/contrib/runners/winrm_runner/winrm_runner/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_base.py b/contrib/runners/winrm_runner/winrm_runner/winrm_base.py
index fb26e49db6..9bebbedc7b 100644
--- a/contrib/runners/winrm_runner/winrm_runner/winrm_base.py
+++ b/contrib/runners/winrm_runner/winrm_runner/winrm_base.py
@@ -32,7 +32,7 @@
from winrm.exceptions import WinRMOperationTimeoutError
__all__ = [
- 'WinRmBaseRunner',
+ "WinRmBaseRunner",
]
LOG = logging.getLogger(__name__)
@@ -49,7 +49,7 @@
RUNNER_USERNAME = "username"
RUNNER_VERIFY_SSL = "verify_ssl_cert"
-WINRM_DEFAULT_TMP_DIR_PS = '[System.IO.Path]::GetTempPath()'
+WINRM_DEFAULT_TMP_DIR_PS = "[System.IO.Path]::GetTempPath()"
# maximum cmdline length for systems >= Windows XP
# https://support.microsoft.com/en-us/help/830473/command-prompt-cmd-exe-command-line-string-limitation
WINRM_MAX_CMD_LENGTH = 8191
@@ -76,28 +76,28 @@
# Compiled list from the following sources:
# https://ss64.com/ps/syntax-esc.html
# https://www.techotopia.com/index.php/Windows_PowerShell_1.0_String_Quoting_and_Escape_Sequences#PowerShell_Special_Escape_Sequences
-PS_ESCAPE_SEQUENCES = {'\n': '`n',
- '\r': '`r',
- '\t': '`t',
- '\a': '`a',
- '\b': '`b',
- '\f': '`f',
- '\v': '`v',
- '"': '`"',
- '\'': '`\'',
- '`': '``',
- '\0': '`0',
- '$': '`$'}
+PS_ESCAPE_SEQUENCES = {
+ "\n": "`n",
+ "\r": "`r",
+ "\t": "`t",
+ "\a": "`a",
+ "\b": "`b",
+ "\f": "`f",
+ "\v": "`v",
+ '"': '`"',
+ "'": "`'",
+ "`": "``",
+ "\0": "`0",
+ "$": "`$",
+}
class WinRmRunnerTimoutError(Exception):
-
def __init__(self, response):
self.response = response
class WinRmBaseRunner(ActionRunner):
-
def pre_run(self):
super(WinRmBaseRunner, self).pre_run()
@@ -107,12 +107,16 @@ def pre_run(self):
self._username = self.runner_parameters[RUNNER_USERNAME]
self._password = self.runner_parameters[RUNNER_PASSWORD]
self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT, DEFAULT_TIMEOUT)
- self._read_timeout = self._timeout + 1 # read_timeout must be > operation_timeout
+ self._read_timeout = (
+ self._timeout + 1
+ ) # read_timeout must be > operation_timeout
# default to https port 5986 over ntlm
self._port = self.runner_parameters.get(RUNNER_PORT, DEFAULT_PORT)
self._scheme = self.runner_parameters.get(RUNNER_SCHEME, DEFAULT_SCHEME)
- self._transport = self.runner_parameters.get(RUNNER_TRANSPORT, DEFAULT_TRANSPORT)
+ self._transport = self.runner_parameters.get(
+ RUNNER_TRANSPORT, DEFAULT_TRANSPORT
+ )
# if connecting to the HTTP port then we must use "http" as the scheme
# in the URL
@@ -120,10 +124,14 @@ def pre_run(self):
self._scheme = "http"
# construct the URL for connecting to WinRM on the host
- self._winrm_url = "{}://{}:{}/wsman".format(self._scheme, self._host, self._port)
+ self._winrm_url = "{}://{}:{}/wsman".format(
+ self._scheme, self._host, self._port
+ )
# default to verifying SSL certs
- self._verify_ssl = self.runner_parameters.get(RUNNER_VERIFY_SSL, DEFAULT_VERIFY_SSL)
+ self._verify_ssl = self.runner_parameters.get(
+ RUNNER_VERIFY_SSL, DEFAULT_VERIFY_SSL
+ )
self._server_cert_validation = "validate" if self._verify_ssl else "ignore"
# additional parameters
@@ -136,12 +144,14 @@ def _get_session(self):
# cache session (only create if it doesn't exist yet)
if not self._session:
LOG.debug("Connecting via WinRM to url: {}".format(self._winrm_url))
- self._session = Session(self._winrm_url,
- auth=(self._username, self._password),
- transport=self._transport,
- server_cert_validation=self._server_cert_validation,
- operation_timeout_sec=self._timeout,
- read_timeout_sec=self._read_timeout)
+ self._session = Session(
+ self._winrm_url,
+ auth=(self._username, self._password),
+ transport=self._transport,
+ server_cert_validation=self._server_cert_validation,
+ operation_timeout_sec=self._timeout,
+ read_timeout_sec=self._read_timeout,
+ )
return self._session
def _winrm_get_command_output(self, protocol, shell_id, command_id):
@@ -154,37 +164,46 @@ def _winrm_get_command_output(self, protocol, shell_id, command_id):
while not command_done:
# check if we need to timeout (StackStorm custom)
current_time = time.time()
- elapsed_time = (current_time - start_time)
+ elapsed_time = current_time - start_time
if self._timeout and (elapsed_time > self._timeout):
- raise WinRmRunnerTimoutError(Response((b''.join(stdout_buffer),
- b''.join(stderr_buffer),
- WINRM_TIMEOUT_EXIT_CODE)))
+ raise WinRmRunnerTimoutError(
+ Response(
+ (
+ b"".join(stdout_buffer),
+ b"".join(stderr_buffer),
+ WINRM_TIMEOUT_EXIT_CODE,
+ )
+ )
+ )
# end stackstorm custom
try:
- stdout, stderr, return_code, command_done = \
- protocol._raw_get_command_output(shell_id, command_id)
+ (
+ stdout,
+ stderr,
+ return_code,
+ command_done,
+ ) = protocol._raw_get_command_output(shell_id, command_id)
stdout_buffer.append(stdout)
stderr_buffer.append(stderr)
except WinRMOperationTimeoutError:
# this is an expected error when waiting for a long-running process,
# just silently retry
pass
- return b''.join(stdout_buffer), b''.join(stderr_buffer), return_code
+ return b"".join(stdout_buffer), b"".join(stderr_buffer), return_code
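The custom timeout handling above follows a simple polling pattern: keep calling pywinrm's `_raw_get_command_output`, accumulate the stdout/stderr chunks, and raise once a wall-clock deadline passes. A minimal standalone sketch of that pattern (the `poll` callable and `PollTimeout` exception are hypothetical stand-ins for the pywinrm internals; the real loop additionally swallows `WinRMOperationTimeoutError` and keeps waiting):

.. code:: python

    import time


    class PollTimeout(Exception):
        """Hypothetical stand-in for WinRmRunnerTimoutError; carries partial output."""

        def __init__(self, stdout, stderr):
            super(PollTimeout, self).__init__("command timed out")
            self.stdout = stdout
            self.stderr = stderr


    def poll_until_done(poll, timeout):
        """Call poll() -> (stdout, stderr, rc, done) until done, enforcing a deadline."""
        start_time = time.time()
        stdout_chunks, stderr_chunks = [], []
        return_code, command_done = None, False
        while not command_done:
            if timeout and (time.time() - start_time) > timeout:
                # surface whatever partial output was collected before the deadline
                raise PollTimeout(b"".join(stdout_chunks), b"".join(stderr_chunks))
            stdout, stderr, return_code, command_done = poll()
            stdout_chunks.append(stdout)
            stderr_chunks.append(stderr)
        return b"".join(stdout_chunks), b"".join(stderr_chunks), return_code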
def _winrm_run_cmd(self, session, command, args=(), env=None, cwd=None):
# NOTE: this is copied from pywinrm because it doesn't support
# passing env and working_directory from the Session.run_cmd.
# It also doesn't support timeouts. All of these things have been
# added
- shell_id = session.protocol.open_shell(env_vars=env,
- working_directory=cwd)
+ shell_id = session.protocol.open_shell(env_vars=env, working_directory=cwd)
command_id = session.protocol.run_command(shell_id, command, args)
# try/catch is for custom timeout handing (StackStorm custom)
try:
- rs = Response(self._winrm_get_command_output(session.protocol,
- shell_id,
- command_id))
+ rs = Response(
+ self._winrm_get_command_output(session.protocol, shell_id, command_id)
+ )
rs.timeout = False
except WinRmRunnerTimoutError as e:
rs = e.response
@@ -195,37 +214,34 @@ def _winrm_run_cmd(self, session, command, args=(), env=None, cwd=None):
return rs
def _winrm_encode(self, script):
- return b64encode(script.encode('utf_16_le')).decode('ascii')
+ return b64encode(script.encode("utf_16_le")).decode("ascii")
def _winrm_ps_cmd(self, encoded_ps):
- return 'powershell -encodedcommand {0}'.format(encoded_ps)
+ return "powershell -encodedcommand {0}".format(encoded_ps)
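`powershell -encodedcommand` expects the script base64-encoded as UTF-16LE, which is exactly what `_winrm_encode` and `_winrm_ps_cmd` produce. A quick self-contained round trip (the printed value is shown for illustration):

.. code:: python

    from base64 import b64decode, b64encode

    script = "Get-Location"
    encoded = b64encode(script.encode("utf_16_le")).decode("ascii")
    command = "powershell -encodedcommand {0}".format(encoded)

    # decoding recovers the original script text
    assert b64decode(encoded).decode("utf_16_le") == script
    print(command)  # powershell -encodedcommand RwBlAHQALQBMAG8AYwBhAHQAaQBvAG4A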
def _winrm_run_ps(self, session, script, env=None, cwd=None, is_b64=False):
# NOTE: this is copied from pywinrm because it doesn't support
# passing env and working_directory from the Session.run_ps
        # encode the script as UTF-16LE base64 only if it isn't already encoded
- LOG.debug('_winrm_run_ps() - script size = {}'.format(len(script)))
+ LOG.debug("_winrm_run_ps() - script size = {}".format(len(script)))
encoded_ps = script if is_b64 else self._winrm_encode(script)
ps_cmd = self._winrm_ps_cmd(encoded_ps)
- LOG.debug('_winrm_run_ps() - ps cmd size = {}'.format(len(ps_cmd)))
- rs = self._winrm_run_cmd(session,
- ps_cmd,
- env=env,
- cwd=cwd)
+ LOG.debug("_winrm_run_ps() - ps cmd size = {}".format(len(ps_cmd)))
+ rs = self._winrm_run_cmd(session, ps_cmd, env=env, cwd=cwd)
if len(rs.std_err):
            # if there was an error message, clean it up and make it human
# readable
if isinstance(rs.std_err, bytes):
# decode bytes into utf-8 because of a bug in pywinrm
# real fix is here: https://github.com/diyan/pywinrm/pull/222/files
- rs.std_err = rs.std_err.decode('utf-8')
+ rs.std_err = rs.std_err.decode("utf-8")
rs.std_err = session._clean_error_msg(rs.std_err)
return rs
def _translate_response(self, response):
# check exit status for errors
- succeeded = (response.status_code == exit_code_constants.SUCCESS_EXIT_CODE)
+ succeeded = response.status_code == exit_code_constants.SUCCESS_EXIT_CODE
status = action_constants.LIVEACTION_STATUS_SUCCEEDED
status_code = response.status_code
if response.timeout:
@@ -236,39 +252,46 @@ def _translate_response(self, response):
# create result
result = {
- 'failed': not succeeded,
- 'succeeded': succeeded,
- 'return_code': status_code,
- 'stdout': response.std_out,
- 'stderr': response.std_err
+ "failed": not succeeded,
+ "succeeded": succeeded,
+ "return_code": status_code,
+ "stdout": response.std_out,
+ "stderr": response.std_err,
}
        # Ensure stdout and stderr are always strings
- if isinstance(result['stdout'], six.binary_type):
- result['stdout'] = result['stdout'].decode('utf-8')
+ if isinstance(result["stdout"], six.binary_type):
+ result["stdout"] = result["stdout"].decode("utf-8")
- if isinstance(result['stderr'], six.binary_type):
- result['stderr'] = result['stderr'].decode('utf-8')
+ if isinstance(result["stderr"], six.binary_type):
+ result["stderr"] = result["stderr"].decode("utf-8")
# automatically convert result stdout/stderr from JSON strings to
# objects so they can be used natively
return (status, jsonify.json_loads(result, RESULT_KEYS_TO_TRANSFORM), None)
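`_translate_response` reduces the WinRM response to the `(status, result, None)` triplet the action runner expects: the exit code decides success, and a timeout overrides both. A rough sketch under the assumption that plain strings stand in for the `action_constants` status values (the real method also decodes bytes and routes the result through `jsonify`):

.. code:: python

    def translate(status_code, timed_out, stdout, stderr):
        """Reduce a WinRM response to the runner's (status, result, None) triplet."""
        succeeded = status_code == 0 and not timed_out
        status = "timeout" if timed_out else ("succeeded" if succeeded else "failed")
        result = {
            "failed": not succeeded,
            "succeeded": succeeded,
            "return_code": -1 if timed_out else status_code,
            "stdout": stdout,
            "stderr": stderr,
        }
        return (status, result, None)


    assert translate(0, False, "out", "")[0] == "succeeded"
    assert translate(1, False, "", "err")[0] == "failed"
    assert translate(0, True, "", "")[1]["return_code"] == -1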
def _make_tmp_dir(self, parent):
- LOG.debug("Creating temporary directory for WinRM script in parent: {}".format(parent))
+ LOG.debug(
+ "Creating temporary directory for WinRM script in parent: {}".format(parent)
+ )
ps = """$parent = {parent}
$name = [System.IO.Path]::GetRandomFileName()
$path = Join-Path $parent $name
New-Item -ItemType Directory -Path $path | Out-Null
-$path""".format(parent=parent)
- result = self._run_ps_or_raise(ps, ("Unable to make temporary directory for"
- " powershell script"))
+$path""".format(
+ parent=parent
+ )
+ result = self._run_ps_or_raise(
+ ps, ("Unable to make temporary directory for" " powershell script")
+ )
# strip to remove trailing newline and whitespace (if any)
- return result['stdout'].strip()
+ return result["stdout"].strip()
def _rm_dir(self, directory):
ps = 'Remove-Item -Force -Recurse -Path "{}"'.format(directory)
- self._run_ps_or_raise(ps, "Unable to remove temporary directory for powershell script")
+ self._run_ps_or_raise(
+ ps, "Unable to remove temporary directory for powershell script"
+ )
def _upload(self, src_path_or_data, dst_path):
src_data = None
@@ -276,7 +299,7 @@ def _upload(self, src_path_or_data, dst_path):
# if this is a path, then read the data from the path
if os.path.exists(src_path_or_data):
LOG.debug("WinRM uploading local file: {}".format(src_path_or_data))
- with open(src_path_or_data, 'r') as src_file:
+ with open(src_path_or_data, "r") as src_file:
src_data = src_file.read()
else:
LOG.debug("WinRM uploading data from a string")
@@ -285,14 +308,19 @@ def _upload(self, src_path_or_data, dst_path):
# upload the data in chunks such that each chunk doesn't exceed the
# max command size of the windows command line
for i in range(0, len(src_data), WINRM_UPLOAD_CHUNK_SIZE_BYTES):
- LOG.debug("WinRM uploading data bytes: {}-{}".
- format(i, (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES)))
- self._upload_chunk(dst_path, src_data[i:(i + WINRM_UPLOAD_CHUNK_SIZE_BYTES)])
+ LOG.debug(
+ "WinRM uploading data bytes: {}-{}".format(
+ i, (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES)
+ )
+ )
+ self._upload_chunk(
+ dst_path, src_data[i : (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES)]
+ )
def _upload_chunk(self, dst_path, src_data):
# adapted from https://github.com/diyan/pywinrm/issues/18
if not isinstance(src_data, six.binary_type):
- src_data = src_data.encode('utf-8')
+ src_data = src_data.encode("utf-8")
ps = """$filePath = "{dst_path}"
$s = @"
@@ -300,10 +328,11 @@ def _upload_chunk(self, dst_path, src_data):
"@
$data = [System.Convert]::FromBase64String($s)
Add-Content -value $data -encoding byte -path $filePath
-""".format(dst_path=dst_path,
- b64_data=base64.b64encode(src_data).decode('utf-8'))
+""".format(
+ dst_path=dst_path, b64_data=base64.b64encode(src_data).decode("utf-8")
+ )
- LOG.debug('WinRM uploading chunk, size = {}'.format(len(ps)))
+ LOG.debug("WinRM uploading chunk, size = {}".format(len(ps)))
self._run_ps_or_raise(ps, "Failed to upload chunk of powershell script")
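The chunked upload works around the same command-length ceiling seen elsewhere in this runner: the file is shipped as base64 chunks, each small enough to fit in a single remote command, and appended on the Windows side with `Add-Content`. A rough sketch of the chunking loop (`run_ps` is a hypothetical executor standing in for `self._run_ps_or_raise`, and the inline format string replaces the real code's here-string):

.. code:: python

    import base64

    CHUNK_SIZE_BYTES = 400  # illustrative; the runner uses WINRM_UPLOAD_CHUNK_SIZE_BYTES


    def upload_in_chunks(run_ps, dst_path, data):
        """Append base64 chunks remotely so no single command exceeds the limit."""
        if not isinstance(data, bytes):
            data = data.encode("utf-8")
        for i in range(0, len(data), CHUNK_SIZE_BYTES):
            b64_chunk = base64.b64encode(data[i : i + CHUNK_SIZE_BYTES]).decode("utf-8")
            ps = (
                '$data = [System.Convert]::FromBase64String("{0}")\n'
                'Add-Content -value $data -encoding byte -path "{1}"'
            ).format(b64_chunk, dst_path)
            run_ps(ps)  # hypothetical executor, e.g. self._run_ps_or_raise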
@contextmanager
@@ -335,7 +364,7 @@ def run_cmd(self, cmd):
def run_ps(self, script, params=None):
# temporary directory for the powershell script
if params:
- powershell = '& {%s} %s' % (script, params)
+ powershell = "& {%s} %s" % (script, params)
else:
powershell = script
encoded_ps = self._winrm_encode(powershell)
@@ -346,9 +375,12 @@ def run_ps(self, script, params=None):
# else we need to upload the script to a temporary file and execute it,
# then remove the temporary file
if len(ps_cmd) <= WINRM_MAX_CMD_LENGTH:
- LOG.info(("WinRM powershell command size {} is > {}, the max size of a"
- " powershell command. Converting to a script execution.")
- .format(WINRM_MAX_CMD_LENGTH, len(ps_cmd)))
            return self._run_ps(encoded_ps, is_b64=True)
        else:
+            LOG.info(
+                (
+                    "WinRM powershell command size {} is > {}, the max size of a"
+                    " powershell command. Converting to a script execution."
+                ).format(len(ps_cmd), WINRM_MAX_CMD_LENGTH)
+            )
            return self._run_ps_script(script, params)
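So `run_ps` picks its execution strategy purely from the final command length: anything at or under `WINRM_MAX_CMD_LENGTH` (8191 characters) runs inline as an `-encodedcommand`, anything longer is uploaded as a temporary script and invoked instead. Reduced to its essentials:

.. code:: python

    WINRM_MAX_CMD_LENGTH = 8191  # cmd.exe limit on Windows XP and newer


    def choose_strategy(ps_cmd_len):
        # inline execution fits in a single command line; otherwise fall back
        # to uploading the script to a temp file and executing that
        return "inline" if ps_cmd_len <= WINRM_MAX_CMD_LENGTH else "script-upload"


    assert choose_strategy(100) == "inline"
    assert choose_strategy(WINRM_MAX_CMD_LENGTH + 1) == "script-upload"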
@@ -360,8 +392,9 @@ def _run_ps(self, powershell, is_b64=False):
# connect
session = self._get_session()
# execute
- response = self._winrm_run_ps(session, powershell, env=self._env, cwd=self._cwd,
- is_b64=is_b64)
+ response = self._winrm_run_ps(
+ session, powershell, env=self._env, cwd=self._cwd, is_b64=is_b64
+ )
# create triplet from WinRM response
return self._translate_response(response)
@@ -383,12 +416,12 @@ def _run_ps_or_raise(self, ps, error_msg):
response = self._run_ps(ps)
# response is a tuple: (status, result, None)
result = response[1]
- if result['failed']:
- raise RuntimeError(("{}:\n"
- "stdout = {}\n\n"
- "stderr = {}").format(error_msg,
- result['stdout'],
- result['stderr']))
+ if result["failed"]:
+ raise RuntimeError(
+ ("{}:\n" "stdout = {}\n\n" "stderr = {}").format(
+ error_msg, result["stdout"], result["stderr"]
+ )
+ )
return result
def _multireplace(self, string, replacements):
@@ -407,7 +440,7 @@ def _multireplace(self, string, replacements):
substrs = sorted(replacements, key=len, reverse=True)
# Create a big OR regex that matches any of the substrings to replace
- regexp = re.compile('|'.join([re.escape(s) for s in substrs]))
+ regexp = re.compile("|".join([re.escape(s) for s in substrs]))
# For each match, look up the new string in the replacements
return regexp.sub(lambda match: replacements[match.group(0)], string)
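`_multireplace` is a single-pass replacer: sorting the keys longest-first before building the alternation regex ensures an overlapping longer key (such as `aaa` in the tests above) wins over its shorter prefixes. A minimal standalone version, exercised with a few of the PowerShell escape sequences:

.. code:: python

    import re


    def multireplace(string, replacements):
        # longest keys first so overlapping substrings resolve correctly
        substrs = sorted(replacements, key=len, reverse=True)
        regexp = re.compile("|".join(re.escape(s) for s in substrs))
        return regexp.sub(lambda match: replacements[match.group(0)], string)


    ps_escapes = {"\n": "`n", "\t": "`t", '"': '`"', "`": "``", "$": "`$"}
    assert multireplace('say "hi"\n$x', ps_escapes) == 'say `"hi`"`n`$x'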
@@ -426,8 +459,12 @@ def _param_to_ps(self, param):
ps_str += ")"
elif isinstance(param, dict):
ps_str = "@{"
- ps_str += "; ".join([(self._param_to_ps(k) + ' = ' + self._param_to_ps(v))
- for k, v in six.iteritems(param)])
+ ps_str += "; ".join(
+ [
+ (self._param_to_ps(k) + " = " + self._param_to_ps(v))
+ for k, v in six.iteritems(param)
+ ]
+ )
ps_str += "}"
else:
ps_str = str(param)
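`_param_to_ps` recurses through Python values and emits the matching PowerShell literals: `None` becomes `$null`, booleans `$true`/`$false`, lists `@(...)`, dicts `@{...}`, and strings are quoted. A condensed sketch of the same idea (simplified: the real method escapes strings through `_multireplace` with `PS_ESCAPE_SEQUENCES` first):

.. code:: python

    def param_to_ps(param):
        """Render a Python value as a PowerShell literal (simplified sketch)."""
        if param is None:
            return "$null"
        if isinstance(param, bool):  # must precede the int fallback
            return "$true" if param else "$false"
        if isinstance(param, str):
            return '"{}"'.format(param)  # real code escapes special chars first
        if isinstance(param, list):
            return "@(" + ", ".join(param_to_ps(p) for p in param) + ")"
        if isinstance(param, dict):
            return "@{" + "; ".join(
                param_to_ps(k) + " = " + param_to_ps(v) for k, v in param.items()
            ) + "}"
        return str(param)  # ints and floats pass through unchanged


    assert param_to_ps([True, 99]) == "@($true, 99)"
    assert param_to_ps({"z": "w"}) == '@{"z" = "w"}'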
@@ -446,12 +483,15 @@ def _transform_params_to_ps(self, positional_args, named_args):
def create_ps_params_string(self, positional_args, named_args):
# convert the script parameters into powershell strings
- positional_args, named_args = self._transform_params_to_ps(positional_args,
- named_args)
+ positional_args, named_args = self._transform_params_to_ps(
+ positional_args, named_args
+ )
# concatenate them into a long string
ps_params_str = ""
if named_args:
- ps_params_str += " " .join([(k + " " + v) for k, v in six.iteritems(named_args)])
+ ps_params_str += " ".join(
+ [(k + " " + v) for k, v in six.iteritems(named_args)]
+ )
ps_params_str += " "
if positional_args:
ps_params_str += " ".join(positional_args)
diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py
index d09e5ce7d6..1239f3efd5 100644
--- a/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py
+++ b/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py
@@ -20,19 +20,14 @@
from st2common.runners.base import get_metadata as get_runner_metadata
from winrm_runner.winrm_base import WinRmBaseRunner
-__all__ = [
- 'WinRmCommandRunner',
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["WinRmCommandRunner", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
-RUNNER_COMMAND = 'cmd'
+RUNNER_COMMAND = "cmd"
class WinRmCommandRunner(WinRmBaseRunner):
-
def run(self, action_parameters):
cmd_command = self.runner_parameters[RUNNER_COMMAND]
@@ -45,7 +40,10 @@ def get_runner():
def get_metadata():
- metadata = get_runner_metadata('winrm_runner')
- metadata = [runner for runner in metadata if
- runner['runner_module'] == __name__.split('.')[-1]][0]
+ metadata = get_runner_metadata("winrm_runner")
+ metadata = [
+ runner
+ for runner in metadata
+ if runner["runner_module"] == __name__.split(".")[-1]
+ ][0]
return metadata
diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py
index f49db2b09e..e6d0a37e2f 100644
--- a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py
+++ b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py
@@ -20,19 +20,14 @@
from st2common.runners.base import get_metadata as get_runner_metadata
from winrm_runner.winrm_base import WinRmBaseRunner
-__all__ = [
- 'WinRmPsCommandRunner',
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["WinRmPsCommandRunner", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
-RUNNER_COMMAND = 'cmd'
+RUNNER_COMMAND = "cmd"
class WinRmPsCommandRunner(WinRmBaseRunner):
-
def run(self, action_parameters):
powershell_command = self.runner_parameters[RUNNER_COMMAND]
@@ -45,7 +40,10 @@ def get_runner():
def get_metadata():
- metadata = get_runner_metadata('winrm_runner')
- metadata = [runner for runner in metadata if
- runner['runner_module'] == __name__.split('.')[-1]][0]
+ metadata = get_runner_metadata("winrm_runner")
+ metadata = [
+ runner
+ for runner in metadata
+ if runner["runner_module"] == __name__.split(".")[-1]
+ ][0]
return metadata
diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py
index 9f156bd8c9..ff162b7aee 100644
--- a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py
+++ b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py
@@ -21,23 +21,18 @@
from st2common.runners.base import get_metadata as get_runner_metadata
from winrm_runner.winrm_base import WinRmBaseRunner
-__all__ = [
- 'WinRmPsScriptRunner',
- 'get_runner',
- 'get_metadata'
-]
+__all__ = ["WinRmPsScriptRunner", "get_runner", "get_metadata"]
LOG = logging.getLogger(__name__)
class WinRmPsScriptRunner(WinRmBaseRunner, ShellRunnerMixin):
-
def run(self, action_parameters):
if not self.entry_point:
- raise ValueError('Missing entry_point action metadata attribute')
+ raise ValueError("Missing entry_point action metadata attribute")
# read in the script contents from the local file
- with open(self.entry_point, 'r') as script_file:
+ with open(self.entry_point, "r") as script_file:
ps_script = script_file.read()
# extract script parameters specified in the action metadata file
@@ -57,7 +52,10 @@ def get_runner():
def get_metadata():
- metadata = get_runner_metadata('winrm_runner')
- metadata = [runner for runner in metadata if
- runner['runner_module'] == __name__.split('.')[-1]][0]
+ metadata = get_runner_metadata("winrm_runner")
+ metadata = [
+ runner
+ for runner in metadata
+ if runner["runner_module"] == __name__.split(".")[-1]
+ ][0]
return metadata
diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst
index 4e1c1f22d2..f61cedcba4 100644
--- a/dev_docs/Troubleshooting_Guide.rst
+++ b/dev_docs/Troubleshooting_Guide.rst
@@ -28,7 +28,7 @@ Troubleshooting Guide
$ sudo netstat -tupln | grep 910
tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python
tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python
-
+
As we can see from the above output, port ``9101`` is not even up. To verify this, let us try another command:
.. code:: bash
@@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let
$ ps auxww | grep st2 | grep 910
vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python
./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1
- vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403
+ vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403
vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python
./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1
-
+
 - This suggests that the API process crashed; we can verify that by running ``screen -ls``::
.. code:: bash
@@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let
15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached)
15762.st2-stream (04/26/2016 06:39:10 PM) (Detached)
3 Sockets in /var/run/screen/S-vagrant.
-
-- Now let us check the logs for any errors:
+
+- Now let us check the logs for any errors:
.. code:: bash
tail logs/st2api.log
- 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d
- (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit',
- 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit',
+ 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d
+ (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit',
+ 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit',
'id': '570e9704909a5030cf758e6d', 'pack': u'core'})
- 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType.
- Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0,
- 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d',
+ 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType.
+ Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0,
+ 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d',
'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'})
2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB.
Traceback (most recent call last):
@@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let
NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" })
2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt
2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt
-
+
 - To figure out what's wrong, let us dig down further. Activate the virtualenv in st2 and run the following command:
.. code:: bash
@@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en
File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in
from keyczar.keys import AesKey
ImportError: No module named keyczar.keys
-
+
So the problem is: the keyczar module is missing. It can be installed using the following command:
*Solution:*
@@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi
.. code:: bash
(virtualenv) $ pip install python-keyczar
-
+
This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart``
diff --git a/lint-configs/python/.flake8 b/lint-configs/python/.flake8
index f3cc01b319..271a9a21e6 100644
--- a/lint-configs/python/.flake8
+++ b/lint-configs/python/.flake8
@@ -2,7 +2,14 @@
max-line-length = 100
# L102 - apache license header
enable-extensions = L101,L102
-ignore = E128,E402,E722,W504
+# We ignore some rules which conflict with black:
+# E203 - whitespace before ':' - in direct conflict with black
+# W503 - line break before binary operator - in direct conflict with black
+# E501 - line too long
+# https://black.readthedocs.io/en/stable/the_black_code_style.html#line-length
+# We don't need a separate line length rule since black formatting takes
+# care of that.
+ignore = E128,E402,E722,W504,E501,E203,W503
exclude=*.egg/*,build,dist
# Configuration for flake8-copyright extension
diff --git a/pylint_plugins/api_models.py b/pylint_plugins/api_models.py
index 398a664d40..4e14095f71 100644
--- a/pylint_plugins/api_models.py
+++ b/pylint_plugins/api_models.py
@@ -29,9 +29,7 @@
from astroid import scoped_nodes
# A list of class names for which we want to skip the checks
-CLASS_NAME_BLACKLIST = [
- 'ExecutionSpecificationAPI'
-]
+CLASS_NAME_BLACKLIST = ["ExecutionSpecificationAPI"]
def register(linter):
@@ -42,11 +40,11 @@ def transform(cls):
if cls.name in CLASS_NAME_BLACKLIST:
return
- if cls.name.endswith('API') or 'schema' in cls.locals:
+ if cls.name.endswith("API") or "schema" in cls.locals:
        # This is a class which defines attributes in the "schema" variable using JSON schema.
        # Those attributes are then assigned at run time inside the constructor
fqdn = cls.qname()
- module_name, class_name = fqdn.rsplit('.', 1)
+ module_name, class_name = fqdn.rsplit(".", 1)
module = __import__(module_name, fromlist=[class_name])
actual_cls = getattr(module, class_name)
@@ -57,29 +55,31 @@ def transform(cls):
# Not a class we are interested in
return
- properties = schema.get('properties', {})
+ properties = schema.get("properties", {})
for property_name, property_data in six.iteritems(properties):
- property_name = property_name.replace('-', '_') # Note: We do the same in Python code
- property_type = property_data.get('type', None)
+ property_name = property_name.replace(
+ "-", "_"
+ ) # Note: We do the same in Python code
+ property_type = property_data.get("type", None)
if isinstance(property_type, (list, tuple)):
# Hack for attributes with multiple types (e.g. string, null)
property_type = property_type[0]
- if property_type == 'object':
+ if property_type == "object":
node = nodes.Dict()
- elif property_type == 'array':
+ elif property_type == "array":
node = nodes.List()
- elif property_type == 'integer':
- node = scoped_nodes.builtin_lookup('int')[1][0]
- elif property_type == 'number':
- node = scoped_nodes.builtin_lookup('float')[1][0]
- elif property_type == 'string':
- node = scoped_nodes.builtin_lookup('str')[1][0]
- elif property_type == 'boolean':
- node = scoped_nodes.builtin_lookup('bool')[1][0]
- elif property_type == 'null':
- node = scoped_nodes.builtin_lookup('None')[1][0]
+ elif property_type == "integer":
+ node = scoped_nodes.builtin_lookup("int")[1][0]
+ elif property_type == "number":
+ node = scoped_nodes.builtin_lookup("float")[1][0]
+ elif property_type == "string":
+ node = scoped_nodes.builtin_lookup("str")[1][0]
+ elif property_type == "boolean":
+ node = scoped_nodes.builtin_lookup("bool")[1][0]
+ elif property_type == "null":
+ node = scoped_nodes.builtin_lookup("None")[1][0]
else:
# Unknown type
node = astroid.ClassDef(property_name, None)
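The plugin's job is to stop pylint from flagging attributes that only exist at run time: for each property in a model's JSON schema it registers a placeholder of the matching Python type. The mapping reduces to a small lookup, roughly like this (illustrative only, with plain Python values in place of the astroid nodes used above):

.. code:: python

    # json-schema "type" -> representative Python value, mirroring the
    # builtin_lookup() branches above (object/array become Dict/List nodes)
    SCHEMA_TYPE_TO_PLACEHOLDER = {
        "object": {},
        "array": [],
        "integer": 0,
        "number": 0.0,
        "string": "",
        "boolean": False,
        "null": None,
    }

    schema = {"properties": {"name": {"type": "string"}, "count": {"type": "integer"}}}
    attrs = {
        prop.replace("-", "_"): SCHEMA_TYPE_TO_PLACEHOLDER.get(spec.get("type"))
        for prop, spec in schema["properties"].items()
    }
    assert attrs == {"name": "", "count": 0}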
diff --git a/pylint_plugins/db_models.py b/pylint_plugins/db_models.py
index 241e9ea582..da9251462e 100644
--- a/pylint_plugins/db_models.py
+++ b/pylint_plugins/db_models.py
@@ -23,8 +23,7 @@
from astroid import nodes
# A list of class names for which we want to skip the checks
-CLASS_NAME_BLACKLIST = [
-]
+CLASS_NAME_BLACKLIST = []
def register(linter):
@@ -35,14 +34,14 @@ def transform(cls):
if cls.name in CLASS_NAME_BLACKLIST:
return
- if cls.name == 'StormFoundationDB':
+ if cls.name == "StormFoundationDB":
# _fields get added automagically by mongoengine
- if '_fields' not in cls.locals:
- cls.locals['_fields'] = [nodes.Dict()]
+ if "_fields" not in cls.locals:
+ cls.locals["_fields"] = [nodes.Dict()]
- if cls.name.endswith('DB'):
+ if cls.name.endswith("DB"):
        # mongoengine explicitly declares an "id" field on each class, so we teach pylint about that
- property_name = 'id'
+ property_name = "id"
node = astroid.ClassDef(property_name, None)
cls.locals[property_name] = [node]
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..1889c6a5da
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,14 @@
+[tool.black]
+line-length = 100
+target_version = ['py36']
+include = '\.pyi?$'
+exclude = '''
+(
+ /(
+ | \.git
+ | \.virtualenv
+ | __pycache__
+ | test_content_version
+ )/
+)
+'''
diff --git a/scripts/dist_utils.py b/scripts/dist_utils.py
index ba73f554c6..c0af527b6b 100644
--- a/scripts/dist_utils.py
+++ b/scripts/dist_utils.py
@@ -47,17 +47,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -68,15 +68,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -85,10 +85,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -102,30 +104,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -135,8 +139,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
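`_get_link` only has to handle VCS-style requirement lines, pulling the distribution name out of the `#egg=` fragment; the first regex covers names followed by `&` or `@` qualifiers, the second covers a bare name at the end of the line. For example (illustrative inputs, reusing the same regexes):

.. code:: python

    import re


    def egg_name(line):
        """Extract the #egg= name from a VCS requirement line (sketch)."""
        match = re.findall(".*#egg=(.+)([&|@]).*$", line)
        if match:
            return match[0][0]
        match = re.findall(".*#egg=(.+?)$", line)
        return match[0] if match else None


    assert egg_name("git+https://github.com/org/repo.git#egg=mypkg") == "mypkg"
    assert egg_name("-e git+https://github.com/org/repo.git#egg=mypkg&x=1") == "mypkg"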
@@ -150,7 +154,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -159,14 +163,13 @@ def get_version_string(init_file):
    Read __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
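`get_version_string` simply regex-scans an `__init__.py` for the `__version__` assignment; the same pattern can be exercised in memory:

.. code:: python

    import re

    content = '__version__ = "3.5dev"\n'
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
    assert match and match.group(1) == "3.5dev"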
diff --git a/scripts/dist_utils_old.py b/scripts/dist_utils_old.py
index 5dfadb1bef..da38f6edbf 100644
--- a/scripts/dist_utils_old.py
+++ b/scripts/dist_utils_old.py
@@ -35,17 +35,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
try:
import pip
from pip import __version__ as pip_version
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
try:
@@ -57,28 +57,30 @@
try:
from pip._internal.req.req_file import parse_requirements
except ImportError as e:
- print('Failed to import parse_requirements from pip: %s' % (text_type(e)))
- print('Using pip: %s' % (str(pip_version)))
+ print("Failed to import parse_requirements from pip: %s" % (text_type(e)))
+ print("Using pip: %s" % (str(pip_version)))
sys.exit(1)
__all__ = [
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
@@ -90,7 +92,7 @@ def fetch_requirements(requirements_file_path):
reqs = []
for req in parse_requirements(requirements_file_path, session=False):
# Note: req.url was used before 9.0.0 and req.link is used in all the recent versions
- link = getattr(req, 'link', getattr(req, 'url', None))
+ link = getattr(req, "link", getattr(req, "url", None))
if link:
links.append(str(link))
reqs.append(str(req.req))
@@ -104,7 +106,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -113,14 +115,13 @@ def get_version_string(init_file):
    Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
diff --git a/scripts/fixate-requirements.py b/scripts/fixate-requirements.py
index dd5c8d2505..4277c986f8 100755
--- a/scripts/fixate-requirements.py
+++ b/scripts/fixate-requirements.py
@@ -43,18 +43,18 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
OSCWD = os.path.abspath(os.curdir)
-GET_PIP = ' curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = " curl https://bootstrap.pypa.io/get-pip.py | python"
try:
import pip
from pip import __version__ as pip_version
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
try:
@@ -66,24 +66,43 @@
try:
from pip._internal.req.req_file import parse_requirements
except ImportError as e:
- print('Failed to import parse_requirements from pip: %s' % (text_type(e)))
- print('Using pip: %s' % (str(pip_version)))
+ print("Failed to import parse_requirements from pip: %s" % (text_type(e)))
+ print("Using pip: %s" % (str(pip_version)))
sys.exit(1)
def parse_args():
- parser = argparse.ArgumentParser(description='Tool for requirements.txt generation.')
- parser.add_argument('-s', '--source-requirements', nargs='+',
- required=True,
- help='Specify paths to requirements file(s). '
- 'In case several requirements files are given their content is merged.')
- parser.add_argument('-f', '--fixed-requirements', required=True,
- help='Specify path to fixed-requirements.txt file.')
- parser.add_argument('-o', '--output-file', default='requirements.txt',
- help='Specify path to the resulting requirements file.')
- parser.add_argument('--skip', default=None,
- help=('Comma delimited list of requirements to not '
- 'include in the generated file.'))
+ parser = argparse.ArgumentParser(
+ description="Tool for requirements.txt generation."
+ )
+ parser.add_argument(
+ "-s",
+ "--source-requirements",
+ nargs="+",
+ required=True,
+ help="Specify paths to requirements file(s). "
+ "In case several requirements files are given their content is merged.",
+ )
+ parser.add_argument(
+ "-f",
+ "--fixed-requirements",
+ required=True,
+ help="Specify path to fixed-requirements.txt file.",
+ )
+ parser.add_argument(
+ "-o",
+ "--output-file",
+ default="requirements.txt",
+ help="Specify path to the resulting requirements file.",
+ )
+ parser.add_argument(
+ "--skip",
+ default=None,
+ help=(
+ "Comma delimited list of requirements to not "
+ "include in the generated file."
+ ),
+ )
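+    # Example invocation (illustrative):
+    #   python fixate-requirements.py -s in-requirements.txt \
+    #       -f fixed-requirements.txt -o requirements.txt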
if len(sys.argv) < 2:
parser.print_help()
sys.exit(1)
@@ -91,9 +110,11 @@ def parse_args():
def check_pip_version():
- if StrictVersion(pip.__version__) < StrictVersion('6.1.0'):
- print("Upgrade pip, your version `{0}' "
- "is outdated:\n".format(pip.__version__), GET_PIP)
+ if StrictVersion(pip.__version__) < StrictVersion("6.1.0"):
+ print(
+ "Upgrade pip, your version `{0}' " "is outdated:\n".format(pip.__version__),
+ GET_PIP,
+ )
sys.exit(1)
@@ -129,13 +150,14 @@ def merge_source_requirements(sources):
elif req.link:
merged_requirements.append(req)
else:
- raise RuntimeError('Unexpected requirement {0}'.format(req))
+ raise RuntimeError("Unexpected requirement {0}".format(req))
return merged_requirements
-def write_requirements(sources=None, fixed_requirements=None, output_file=None,
- skip=None):
+def write_requirements(
+ sources=None, fixed_requirements=None, output_file=None, skip=None
+):
"""
Write resulting requirements taking versions from the fixed_requirements.
"""
@@ -153,7 +175,9 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None,
continue
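+        # Each dependency may only be pinned once in fixed-requirements.txt.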
if project_name in fixedreq_hash:
- raise ValueError('Duplicate definition for dependency "%s"' % (project_name))
+ raise ValueError(
+ 'Duplicate definition for dependency "%s"' % (project_name)
+ )
fixedreq_hash[project_name] = req
@@ -169,7 +193,7 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None,
rline = str(req.link)
if req.editable:
- rline = '-e %s' % (rline)
+ rline = "-e %s" % (rline)
elif req.req:
project = req.name
req_obj = fixedreq_hash.get(project, req)
@@ -184,30 +208,40 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None,
# Sort the lines to guarantee a stable order
lines_to_write = sorted(lines_to_write)
- data = '\n'.join(lines_to_write) + '\n'
- with open(output_file, 'w') as fp:
- fp.write('# Don\'t edit this file. It\'s generated automatically!\n')
- fp.write('# If you want to update global dependencies, modify fixed-requirements.txt\n')
- fp.write('# and then run \'make requirements\' to update requirements.txt for all\n')
- fp.write('# components.\n')
- fp.write('# If you want to update depdencies for a single component, modify the\n')
- fp.write('# in-requirements.txt for that component and then run \'make requirements\' to\n')
- fp.write('# update the component requirements.txt\n')
+ data = "\n".join(lines_to_write) + "\n"
+ with open(output_file, "w") as fp:
+ fp.write("# Don't edit this file. It's generated automatically!\n")
+ fp.write(
+ "# If you want to update global dependencies, modify fixed-requirements.txt\n"
+ )
+ fp.write(
+ "# and then run 'make requirements' to update requirements.txt for all\n"
+ )
+ fp.write("# components.\n")
+ fp.write(
+        "# If you want to update dependencies for a single component, modify the\n"
+ )
+ fp.write(
+ "# in-requirements.txt for that component and then run 'make requirements' to\n"
+ )
+ fp.write("# update the component requirements.txt\n")
fp.write(data)
- print('Requirements written to: {0}'.format(output_file))
+ print("Requirements written to: {0}".format(output_file))
-if __name__ == '__main__':
+if __name__ == "__main__":
check_pip_version()
args = parse_args()
- if args['skip']:
- skip = args['skip'].split(',')
+ if args["skip"]:
+ skip = args["skip"].split(",")
else:
skip = None
- write_requirements(sources=args['source_requirements'],
- fixed_requirements=args['fixed_requirements'],
- output_file=args['output_file'],
- skip=skip)
+ write_requirements(
+ sources=args["source_requirements"],
+ fixed_requirements=args["fixed_requirements"],
+ output_file=args["output_file"],
+ skip=skip,
+ )
diff --git a/st2actions/dist_utils.py b/st2actions/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2actions/dist_utils.py
+++ b/st2actions/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
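+            # VCS requirements such as "git+https://.../repo.git#egg=name" carry the
+            # requirement name in the "#egg=" URL fragment, optionally followed by
+            # "&" or "@" separated parts.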
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
    Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
diff --git a/st2actions/setup.py b/st2actions/setup.py
index 6fcb2cde92..a4e8c12790 100644
--- a/st2actions/setup.py
+++ b/st2actions/setup.py
@@ -23,9 +23,9 @@
from dist_utils import apply_vagrant_workaround
from st2actions import __version__
-ST2_COMPONENT = 'st2actions'
+ST2_COMPONENT = "st2actions"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
@@ -33,21 +33,23 @@
setup(
name=ST2_COMPONENT,
version=__version__,
- description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="{} StackStorm event-driven automation platform component".format(
+ ST2_COMPONENT
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
+ packages=find_packages(exclude=["setuptools", "tests"]),
scripts=[
- 'bin/st2actionrunner',
- 'bin/st2notifier',
- 'bin/st2workflowengine',
- 'bin/st2scheduler',
- ]
+ "bin/st2actionrunner",
+ "bin/st2notifier",
+ "bin/st2workflowengine",
+ "bin/st2scheduler",
+ ],
)
diff --git a/st2actions/st2actions/__init__.py b/st2actions/st2actions/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/st2actions/st2actions/__init__.py
+++ b/st2actions/st2actions/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/st2actions/st2actions/cmd/actionrunner.py b/st2actions/st2actions/cmd/actionrunner.py
index 457bf45e03..6aa339115a 100644
--- a/st2actions/st2actions/cmd/actionrunner.py
+++ b/st2actions/st2actions/cmd/actionrunner.py
@@ -18,6 +18,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
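+# Note: monkey patching needs to happen as early as possible, before the rest of
+# the imports below.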
monkey_patch()
import os
@@ -30,15 +31,12 @@
from st2common.service_setup import setup as common_setup
from st2common.service_setup import teardown as common_teardown
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _setup_sigterm_handler():
-
def sigterm_handler(signum=None, frame=None):
        # This will cause SystemExit to be thrown and allow for component cleanup.
sys.exit(0)
@@ -49,18 +47,22 @@ def sigterm_handler(signum=None, frame=None):
def _setup():
- capabilities = {
- 'name': 'actionrunner',
- 'type': 'passive'
- }
- common_setup(service='actionrunner', config=config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True, service_registry=True, capabilities=capabilities)
+ capabilities = {"name": "actionrunner", "type": "passive"}
+ common_setup(
+ service="actionrunner",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ service_registry=True,
+ capabilities=capabilities,
+ )
_setup_sigterm_handler()
def _run_worker():
- LOG.info('(PID=%s) Worker started.', os.getpid())
+ LOG.info("(PID=%s) Worker started.", os.getpid())
action_worker = worker.get_worker()
@@ -68,20 +70,20 @@ def _run_worker():
action_worker.start()
action_worker.wait()
except (KeyboardInterrupt, SystemExit):
- LOG.info('(PID=%s) Worker stopped.', os.getpid())
+ LOG.info("(PID=%s) Worker stopped.", os.getpid())
errors = False
try:
action_worker.shutdown()
except:
- LOG.exception('Unable to shutdown worker.')
+ LOG.exception("Unable to shutdown worker.")
errors = True
if errors:
return 1
except:
- LOG.exception('(PID=%s) Worker unexpectedly stopped.', os.getpid())
+ LOG.exception("(PID=%s) Worker unexpectedly stopped.", os.getpid())
return 1
return 0
@@ -98,7 +100,7 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except:
- LOG.exception('(PID=%s) Worker quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) Worker quit due to exception.", os.getpid())
return 1
finally:
_teardown()
diff --git a/st2actions/st2actions/cmd/scheduler.py b/st2actions/st2actions/cmd/scheduler.py
index b3c972b654..df6dd768db 100644
--- a/st2actions/st2actions/cmd/scheduler.py
+++ b/st2actions/st2actions/cmd/scheduler.py
@@ -17,6 +17,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -28,9 +29,7 @@
from st2common.service_setup import teardown as common_teardown
from st2common.service_setup import setup as common_setup
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
@@ -46,23 +45,27 @@ def sigterm_handler(signum=None, frame=None):
def _setup():
- capabilities = {
- 'name': 'scheduler',
- 'type': 'passive'
- }
- common_setup(service='scheduler', config=config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True, service_registry=True, capabilities=capabilities)
+ capabilities = {"name": "scheduler", "type": "passive"}
+ common_setup(
+ service="scheduler",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ service_registry=True,
+ capabilities=capabilities,
+ )
_setup_sigterm_handler()
def _run_scheduler():
- LOG.info('(PID=%s) Scheduler started.', os.getpid())
+ LOG.info("(PID=%s) Scheduler started.", os.getpid())
# Lazy load these so that decorator metrics are in place
from st2actions.scheduler import (
handler as scheduler_handler,
- entrypoint as scheduler_entrypoint
+ entrypoint as scheduler_entrypoint,
)
handler = scheduler_handler.get_handler()
@@ -73,14 +76,18 @@ def _run_scheduler():
try:
handler._cleanup_policy_delayed()
except Exception:
- LOG.exception('(PID=%s) Scheduler unable to perform migration cleanup.', os.getpid())
+ LOG.exception(
+ "(PID=%s) Scheduler unable to perform migration cleanup.", os.getpid()
+ )
# TODO: Remove this try block for _fix_missing_action_execution_id in v3.2.
# This is a temporary fix to auto-populate action_execution_id.
try:
handler._fix_missing_action_execution_id()
except Exception:
- LOG.exception('(PID=%s) Scheduler unable to populate action_execution_id.', os.getpid())
+ LOG.exception(
+ "(PID=%s) Scheduler unable to populate action_execution_id.", os.getpid()
+ )
try:
handler.start()
@@ -89,7 +96,7 @@ def _run_scheduler():
# Wait on handler first since entrypoint is more durable.
handler.wait() or entrypoint.wait()
except (KeyboardInterrupt, SystemExit):
- LOG.info('(PID=%s) Scheduler stopped.', os.getpid())
+ LOG.info("(PID=%s) Scheduler stopped.", os.getpid())
errors = False
@@ -97,13 +104,13 @@ def _run_scheduler():
handler.shutdown()
entrypoint.shutdown()
except:
- LOG.exception('Unable to shutdown scheduler.')
+ LOG.exception("Unable to shutdown scheduler.")
errors = True
if errors:
return 1
except:
- LOG.exception('(PID=%s) Scheduler unexpectedly stopped.', os.getpid())
+ LOG.exception("(PID=%s) Scheduler unexpectedly stopped.", os.getpid())
try:
handler.shutdown()
@@ -127,7 +134,7 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except:
- LOG.exception('(PID=%s) Scheduler quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) Scheduler quit due to exception.", os.getpid())
return 1
finally:
_teardown()
diff --git a/st2actions/st2actions/cmd/st2notifier.py b/st2actions/st2actions/cmd/st2notifier.py
index fdf74f5bf1..7f1ccc7222 100644
--- a/st2actions/st2actions/cmd/st2notifier.py
+++ b/st2actions/st2actions/cmd/st2notifier.py
@@ -16,6 +16,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -27,29 +28,31 @@
from st2actions.notifier import config
from st2actions.notifier import notifier
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _setup():
- capabilities = {
- 'name': 'notifier',
- 'type': 'passive'
- }
- common_setup(service='notifier', config=config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True, service_registry=True, capabilities=capabilities)
+ capabilities = {"name": "notifier", "type": "passive"}
+ common_setup(
+ service="notifier",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ service_registry=True,
+ capabilities=capabilities,
+ )
def _run_worker():
- LOG.info('(PID=%s) Actions notifier started.', os.getpid())
+ LOG.info("(PID=%s) Actions notifier started.", os.getpid())
actions_notifier = notifier.get_notifier()
try:
actions_notifier.start(wait=True)
except (KeyboardInterrupt, SystemExit):
- LOG.info('(PID=%s) Actions notifier stopped.', os.getpid())
+ LOG.info("(PID=%s) Actions notifier stopped.", os.getpid())
actions_notifier.shutdown()
return 0
@@ -65,7 +68,7 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except:
- LOG.exception('(PID=%s) Results tracker quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) Results tracker quit due to exception.", os.getpid())
return 1
finally:
_teardown()
diff --git a/st2actions/st2actions/cmd/workflow_engine.py b/st2actions/st2actions/cmd/workflow_engine.py
index 361d6ce9e1..f51296b4b0 100644
--- a/st2actions/st2actions/cmd/workflow_engine.py
+++ b/st2actions/st2actions/cmd/workflow_engine.py
@@ -19,6 +19,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -32,15 +33,12 @@
from st2common.service_setup import setup as common_setup
from st2common.service_setup import teardown as common_teardown
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def setup_sigterm_handler():
-
def sigterm_handler(signum=None, frame=None):
        # This will cause SystemExit to be thrown and allow for component cleanup.
sys.exit(0)
@@ -51,35 +49,32 @@ def sigterm_handler(signum=None, frame=None):
def setup():
- capabilities = {
- 'name': 'workflowengine',
- 'type': 'passive'
- }
+ capabilities = {"name": "workflowengine", "type": "passive"}
common_setup(
- service='workflow_engine',
+ service="workflow_engine",
config=config,
setup_db=True,
register_mq_exchanges=True,
register_signal_handlers=True,
service_registry=True,
- capabilities=capabilities
+ capabilities=capabilities,
)
setup_sigterm_handler()
def run_server():
- LOG.info('(PID=%s) Workflow engine started.', os.getpid())
+ LOG.info("(PID=%s) Workflow engine started.", os.getpid())
engine = workflows.get_engine()
try:
engine.start(wait=True)
except (KeyboardInterrupt, SystemExit):
- LOG.info('(PID=%s) Workflow engine stopped.', os.getpid())
+ LOG.info("(PID=%s) Workflow engine stopped.", os.getpid())
engine.shutdown()
except:
- LOG.exception('(PID=%s) Workflow engine unexpectedly stopped.', os.getpid())
+ LOG.exception("(PID=%s) Workflow engine unexpectedly stopped.", os.getpid())
return 1
return 0
@@ -97,7 +92,7 @@ def main():
sys.exit(exit_code)
except Exception:
traceback.print_exc()
- LOG.exception('(PID=%s) Workflow engine quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) Workflow engine quit due to exception.", os.getpid())
return 1
finally:
teardown()
diff --git a/st2actions/st2actions/config.py b/st2actions/st2actions/config.py
index b4e83a5306..14dc2c4f58 100644
--- a/st2actions/st2actions/config.py
+++ b/st2actions/st2actions/config.py
@@ -28,8 +28,11 @@
def parse_args(args=None):
- CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
diff --git a/st2actions/st2actions/container/base.py b/st2actions/st2actions/container/base.py
index a350a3dd69..7f2b50f0c7 100644
--- a/st2actions/st2actions/container/base.py
+++ b/st2actions/st2actions/container/base.py
@@ -30,8 +30,8 @@
from st2common.models.system.action import ResolvedActionParameters
from st2common.persistence.execution import ActionExecution
from st2common.services import access, executions, queries
-from st2common.util.action_db import (get_action_by_ref, get_runnertype_by_name)
-from st2common.util.action_db import (update_liveaction_status, get_liveaction_by_id)
+from st2common.util.action_db import get_action_by_ref, get_runnertype_by_name
+from st2common.util.action_db import update_liveaction_status, get_liveaction_by_id
from st2common.util import param as param_utils
from st2common.util.config_loader import ContentPackConfigLoader
from st2common.metrics.base import CounterWithTimer
@@ -42,30 +42,28 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'RunnerContainer',
- 'get_runner_container'
-]
+__all__ = ["RunnerContainer", "get_runner_container"]
class RunnerContainer(object):
-
def dispatch(self, liveaction_db):
action_db = get_action_by_ref(liveaction_db.action)
if not action_db:
- raise Exception('Action %s not found in DB.' % (liveaction_db.action))
+ raise Exception("Action %s not found in DB." % (liveaction_db.action))
- liveaction_db.context['pack'] = action_db.pack
+ liveaction_db.context["pack"] = action_db.pack
- runner_type_db = get_runnertype_by_name(action_db.runner_type['name'])
+ runner_type_db = get_runnertype_by_name(action_db.runner_type["name"])
- extra = {'liveaction_db': liveaction_db, 'runner_type_db': runner_type_db}
- LOG.info('Dispatching Action to a runner', extra=extra)
+ extra = {"liveaction_db": liveaction_db, "runner_type_db": runner_type_db}
+ LOG.info("Dispatching Action to a runner", extra=extra)
# Get runner instance.
runner = self._get_runner(runner_type_db, action_db, liveaction_db)
- LOG.debug('Runner instance for RunnerType "%s" is: %s', runner_type_db.name, runner)
+ LOG.debug(
+ 'Runner instance for RunnerType "%s" is: %s', runner_type_db.name, runner
+ )
# Process the request.
funcs = {
@@ -74,12 +72,12 @@ def dispatch(self, liveaction_db):
action_constants.LIVEACTION_STATUS_RUNNING: self._do_run,
action_constants.LIVEACTION_STATUS_CANCELING: self._do_cancel,
action_constants.LIVEACTION_STATUS_PAUSING: self._do_pause,
- action_constants.LIVEACTION_STATUS_RESUMING: self._do_resume
+ action_constants.LIVEACTION_STATUS_RESUMING: self._do_resume,
}
if liveaction_db.status not in funcs:
raise actionrunner.ActionRunnerDispatchError(
- 'Action runner is unable to dispatch the liveaction because it is '
+ "Action runner is unable to dispatch the liveaction because it is "
'in an unsupported status of "%s".' % liveaction_db.status
)
@@ -94,7 +92,8 @@ def _do_run(self, runner):
runner.auth_token = self._create_auth_token(
context=runner.context,
action_db=runner.action,
- liveaction_db=runner.liveaction)
+ liveaction_db=runner.liveaction,
+ )
try:
# Finalized parameters are resolved and then rendered. This process could
@@ -104,13 +103,14 @@ def _do_run(self, runner):
runner.runner_type.runner_parameters,
runner.action.parameters,
runner.liveaction.parameters,
- runner.liveaction.context)
+ runner.liveaction.context,
+ )
runner.runner_parameters = runner_params
except ParamException as e:
raise actionrunner.ActionRunnerException(six.text_type(e))
- LOG.debug('Performing pre-run for runner: %s', runner.runner_id)
+ LOG.debug("Performing pre-run for runner: %s", runner.runner_id)
runner.pre_run()
# Mask secret parameters in the log context
@@ -118,90 +118,117 @@ def _do_run(self, runner):
action_db=runner.action,
runner_type_db=runner.runner_type,
runner_parameters=runner_params,
- action_parameters=action_params)
+ action_parameters=action_params,
+ )
- extra = {'runner': runner, 'parameters': resolved_action_params}
- LOG.debug('Performing run for runner: %s' % (runner.runner_id), extra=extra)
+ extra = {"runner": runner, "parameters": resolved_action_params}
+ LOG.debug("Performing run for runner: %s" % (runner.runner_id), extra=extra)
- with CounterWithTimer(key='action.executions'):
- with CounterWithTimer(key='action.%s.executions' % (runner.action.ref)):
+ with CounterWithTimer(key="action.executions"):
+ with CounterWithTimer(key="action.%s.executions" % (runner.action.ref)):
(status, result, context) = runner.run(action_params)
result = jsonify.try_loads(result)
action_completed = status in action_constants.LIVEACTION_COMPLETED_STATES
- if (isinstance(runner, PollingAsyncActionRunner) and
- runner.is_polling_enabled() and not action_completed):
+ if (
+ isinstance(runner, PollingAsyncActionRunner)
+ and runner.is_polling_enabled()
+ and not action_completed
+ ):
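+                # For async runners which poll for results, register a query so the
+                # results tracker keeps checking until the execution completes.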
queries.setup_query(runner.liveaction.id, runner.runner_type, context)
except:
- LOG.exception('Failed to run action.')
+ LOG.exception("Failed to run action.")
_, ex, tb = sys.exc_info()
# mark execution as failed.
status = action_constants.LIVEACTION_STATUS_FAILED
# include the error message and traceback to try and provide some hints.
- result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))}
+ result = {
+ "error": str(ex),
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ }
context = None
finally:
# Log action completion
- extra = {'result': result, 'status': status}
+ extra = {"result": result, "status": status}
LOG.debug('Action "%s" completed.' % (runner.action.name), extra=extra)
# Update the final status of liveaction and corresponding action execution.
- runner.liveaction = self._update_status(runner.liveaction.id, status, result, context)
+ runner.liveaction = self._update_status(
+ runner.liveaction.id, status, result, context
+ )
# Always clean-up the auth_token
# This method should be called in the finally block to ensure post_run is not impacted.
self._clean_up_auth_token(runner=runner, status=status)
- LOG.debug('Performing post_run for runner: %s', runner.runner_id)
+ LOG.debug("Performing post_run for runner: %s", runner.runner_id)
runner.post_run(status=status, result=result)
- LOG.debug('Runner do_run result', extra={'result': runner.liveaction.result})
- LOG.audit('Liveaction completed', extra={'liveaction_db': runner.liveaction})
+ LOG.debug("Runner do_run result", extra={"result": runner.liveaction.result})
+ LOG.audit("Liveaction completed", extra={"liveaction_db": runner.liveaction})
return runner.liveaction
def _do_cancel(self, runner):
try:
- extra = {'runner': runner}
- LOG.debug('Performing cancel for runner: %s', (runner.runner_id), extra=extra)
+ extra = {"runner": runner}
+ LOG.debug(
+ "Performing cancel for runner: %s", (runner.runner_id), extra=extra
+ )
(status, result, context) = runner.cancel()
# Update the final status of liveaction and corresponding action execution.
# The status is updated here because we want to keep the workflow running
# as is if the cancel operation failed.
- runner.liveaction = self._update_status(runner.liveaction.id, status, result, context)
+ runner.liveaction = self._update_status(
+ runner.liveaction.id, status, result, context
+ )
except:
_, ex, tb = sys.exc_info()
# include the error message and traceback to try and provide some hints.
- result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))}
- LOG.exception('Failed to cancel action %s.' % (runner.liveaction.id), extra=result)
+ result = {
+ "error": str(ex),
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ }
+ LOG.exception(
+ "Failed to cancel action %s." % (runner.liveaction.id), extra=result
+ )
finally:
# Always clean-up the auth_token
# This method should be called in the finally block to ensure post_run is not impacted.
self._clean_up_auth_token(runner=runner, status=runner.liveaction.status)
- LOG.debug('Performing post_run for runner: %s', runner.runner_id)
- result = {'error': 'Execution canceled by user.'}
+ LOG.debug("Performing post_run for runner: %s", runner.runner_id)
+ result = {"error": "Execution canceled by user."}
runner.post_run(status=runner.liveaction.status, result=result)
return runner.liveaction
def _do_pause(self, runner):
try:
- extra = {'runner': runner}
- LOG.debug('Performing pause for runner: %s', (runner.runner_id), extra=extra)
+ extra = {"runner": runner}
+ LOG.debug(
+ "Performing pause for runner: %s", (runner.runner_id), extra=extra
+ )
(status, result, context) = runner.pause()
except:
_, ex, tb = sys.exc_info()
# include the error message and traceback to try and provide some hints.
status = action_constants.LIVEACTION_STATUS_FAILED
- result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))}
+ result = {
+ "error": str(ex),
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ }
context = runner.liveaction.context
- LOG.exception('Failed to pause action %s.' % (runner.liveaction.id), extra=result)
+ LOG.exception(
+ "Failed to pause action %s." % (runner.liveaction.id), extra=result
+ )
finally:
# Update the final status of liveaction and corresponding action execution.
- runner.liveaction = self._update_status(runner.liveaction.id, status, result, context)
+ runner.liveaction = self._update_status(
+ runner.liveaction.id, status, result, context
+ )
# Always clean-up the auth_token
self._clean_up_auth_token(runner=runner, status=runner.liveaction.status)
@@ -210,35 +237,47 @@ def _do_pause(self, runner):
def _do_resume(self, runner):
try:
- extra = {'runner': runner}
- LOG.debug('Performing resume for runner: %s', (runner.runner_id), extra=extra)
+ extra = {"runner": runner}
+ LOG.debug(
+ "Performing resume for runner: %s", (runner.runner_id), extra=extra
+ )
(status, result, context) = runner.resume()
result = jsonify.try_loads(result)
action_completed = status in action_constants.LIVEACTION_COMPLETED_STATES
- if (isinstance(runner, PollingAsyncActionRunner) and
- runner.is_polling_enabled() and not action_completed):
+ if (
+ isinstance(runner, PollingAsyncActionRunner)
+ and runner.is_polling_enabled()
+ and not action_completed
+ ):
queries.setup_query(runner.liveaction.id, runner.runner_type, context)
except:
_, ex, tb = sys.exc_info()
# include the error message and traceback to try and provide some hints.
status = action_constants.LIVEACTION_STATUS_FAILED
- result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))}
+ result = {
+ "error": str(ex),
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ }
context = runner.liveaction.context
- LOG.exception('Failed to resume action %s.' % (runner.liveaction.id), extra=result)
+ LOG.exception(
+ "Failed to resume action %s." % (runner.liveaction.id), extra=result
+ )
finally:
# Update the final status of liveaction and corresponding action execution.
- runner.liveaction = self._update_status(runner.liveaction.id, status, result, context)
+ runner.liveaction = self._update_status(
+ runner.liveaction.id, status, result, context
+ )
# Always clean-up the auth_token
# This method should be called in the finally block to ensure post_run is not impacted.
self._clean_up_auth_token(runner=runner, status=runner.liveaction.status)
- LOG.debug('Performing post_run for runner: %s', runner.runner_id)
+ LOG.debug("Performing post_run for runner: %s", runner.runner_id)
runner.post_run(status=status, result=result)
- LOG.debug('Runner do_run result', extra={'result': runner.liveaction.result})
- LOG.audit('Liveaction completed', extra={'liveaction_db': runner.liveaction})
+ LOG.debug("Runner do_run result", extra={"result": runner.liveaction.result})
+ LOG.audit("Liveaction completed", extra={"liveaction_db": runner.liveaction})
return runner.liveaction
@@ -260,7 +299,7 @@ def _clean_up_auth_token(self, runner, status):
try:
self._delete_auth_token(runner.auth_token)
except:
- LOG.exception('Unable to clean-up auth_token.')
+ LOG.exception("Unable to clean-up auth_token.")
return True
@@ -273,8 +312,8 @@ def _update_live_action_db(self, liveaction_id, status, result, context):
liveaction_db = get_liveaction_by_id(liveaction_id)
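+        # Only treat this as a state change if the status actually differs and the
+        # liveaction has not already reached a completed (terminal) state.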
state_changed = (
- liveaction_db.status != status and
- liveaction_db.status not in action_constants.LIVEACTION_COMPLETED_STATES
+ liveaction_db.status != status
+ and liveaction_db.status not in action_constants.LIVEACTION_COMPLETED_STATES
)
if status in action_constants.LIVEACTION_COMPLETED_STATES:
@@ -287,64 +326,69 @@ def _update_live_action_db(self, liveaction_id, status, result, context):
result=result,
context=context,
end_timestamp=end_timestamp,
- liveaction_db=liveaction_db
+ liveaction_db=liveaction_db,
)
return (liveaction_db, state_changed)
def _update_status(self, liveaction_id, status, result, context):
try:
- LOG.debug('Setting status: %s for liveaction: %s', status, liveaction_id)
+ LOG.debug("Setting status: %s for liveaction: %s", status, liveaction_id)
liveaction_db, state_changed = self._update_live_action_db(
- liveaction_id, status, result, context)
+ liveaction_id, status, result, context
+ )
except Exception as e:
LOG.exception(
- 'Cannot update liveaction '
- '(id: %s, status: %s, result: %s).' % (
- liveaction_id, status, result)
+ "Cannot update liveaction "
+ "(id: %s, status: %s, result: %s)." % (liveaction_id, status, result)
)
raise e
try:
executions.update_execution(liveaction_db, publish=state_changed)
- extra = {'liveaction_db': liveaction_db}
- LOG.debug('Updated liveaction after run', extra=extra)
+ extra = {"liveaction_db": liveaction_db}
+ LOG.debug("Updated liveaction after run", extra=extra)
except Exception as e:
LOG.exception(
- 'Cannot update action execution for liveaction '
- '(id: %s, status: %s, result: %s).' % (
- liveaction_id, status, result)
+ "Cannot update action execution for liveaction "
+ "(id: %s, status: %s, result: %s)." % (liveaction_id, status, result)
)
raise e
return liveaction_db
def _get_entry_point_abs_path(self, pack, entry_point):
- return content_utils.get_entry_point_abs_path(pack=pack, entry_point=entry_point)
+ return content_utils.get_entry_point_abs_path(
+ pack=pack, entry_point=entry_point
+ )
def _get_action_libs_abs_path(self, pack, entry_point):
- return content_utils.get_action_libs_abs_path(pack=pack, entry_point=entry_point)
+ return content_utils.get_action_libs_abs_path(
+ pack=pack, entry_point=entry_point
+ )
def _get_rerun_reference(self, context):
- execution_id = context.get('re-run', {}).get('ref')
+ execution_id = context.get("re-run", {}).get("ref")
return ActionExecution.get_by_id(execution_id) if execution_id else None
def _get_runner(self, runner_type_db, action_db, liveaction_db):
- resolved_entry_point = self._get_entry_point_abs_path(action_db.pack, action_db.entry_point)
- context = getattr(liveaction_db, 'context', dict())
- user = context.get('user', cfg.CONF.system_user.user)
+ resolved_entry_point = self._get_entry_point_abs_path(
+ action_db.pack, action_db.entry_point
+ )
+ context = getattr(liveaction_db, "context", dict())
+ user = context.get("user", cfg.CONF.system_user.user)
config = None
# Note: Right now configs are only supported by the Python runner actions
- if (runner_type_db.name == 'python-script' or
- runner_type_db.runner_module == 'python_runner'):
- LOG.debug('Loading config from pack for python runner.')
+ if (
+ runner_type_db.name == "python-script"
+ or runner_type_db.runner_module == "python_runner"
+ ):
+ LOG.debug("Loading config from pack for python runner.")
config_loader = ContentPackConfigLoader(pack_name=action_db.pack, user=user)
config = config_loader.get_config()
- runner = get_runner(
- name=runner_type_db.name,
- config=config)
+ runner = get_runner(name=runner_type_db.name, config=config)
# TODO: Pass those arguments to the constructor instead of late
# assignment, late assignment is awful
@@ -357,13 +401,16 @@ def _get_runner(self, runner_type_db, action_db, liveaction_db):
runner.execution_id = str(runner.execution.id)
runner.entry_point = resolved_entry_point
runner.context = context
- runner.callback = getattr(liveaction_db, 'callback', dict())
- runner.libs_dir_path = self._get_action_libs_abs_path(action_db.pack,
- action_db.entry_point)
+ runner.callback = getattr(liveaction_db, "callback", dict())
+ runner.libs_dir_path = self._get_action_libs_abs_path(
+ action_db.pack, action_db.entry_point
+ )
# For re-run, get the ActionExecutionDB in which the re-run is based on.
- rerun_ref_id = runner.context.get('re-run', {}).get('ref')
- runner.rerun_ex_ref = ActionExecution.get(id=rerun_ref_id) if rerun_ref_id else None
+ rerun_ref_id = runner.context.get("re-run", {}).get("ref")
+ runner.rerun_ex_ref = (
+ ActionExecution.get(id=rerun_ref_id) if rerun_ref_id else None
+ )
return runner
@@ -371,19 +418,20 @@ def _create_auth_token(self, context, action_db, liveaction_db):
if not context:
return None
- user = context.get('user', None)
+ user = context.get("user", None)
if not user:
return None
metadata = {
- 'service': 'actions_container',
- 'action_name': action_db.name,
- 'live_action_id': str(liveaction_db.id)
-
+ "service": "actions_container",
+ "action_name": action_db.name,
+ "live_action_id": str(liveaction_db.id),
}
ttl = cfg.CONF.auth.service_token_ttl
- token_db = access.create_token(username=user, ttl=ttl, metadata=metadata, service=True)
+ token_db = access.create_token(
+ username=user, ttl=ttl, metadata=metadata, service=True
+ )
return token_db
def _delete_auth_token(self, auth_token):
diff --git a/st2actions/st2actions/notifier/config.py b/st2actions/st2actions/notifier/config.py
index 6c0162f310..0322179bbc 100644
--- a/st2actions/st2actions/notifier/config.py
+++ b/st2actions/st2actions/notifier/config.py
@@ -27,8 +27,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
@@ -47,11 +50,13 @@ def _register_common_opts():
def _register_notifier_opts():
notifier_opts = [
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.notifier.conf',
- help='Location of the logging configuration file.')
+ "logging",
+ default="/etc/st2/logging.notifier.conf",
+ help="Location of the logging configuration file.",
+ )
]
- CONF.register_opts(notifier_opts, group='notifier')
+ CONF.register_opts(notifier_opts, group="notifier")
register_opts()
diff --git a/st2actions/st2actions/notifier/notifier.py b/st2actions/st2actions/notifier/notifier.py
index 37db830e52..ea1a537733 100644
--- a/st2actions/st2actions/notifier/notifier.py
+++ b/st2actions/st2actions/notifier/notifier.py
@@ -42,22 +42,23 @@
from st2common.constants.action import ACTION_CONTEXT_KV_PREFIX
from st2common.constants.action import ACTION_PARAMETERS_KV_PREFIX
from st2common.constants.action import ACTION_RESULTS_KV_PREFIX
-from st2common.constants.keyvalue import FULL_SYSTEM_SCOPE, SYSTEM_SCOPE, DATASTORE_PARENT_SCOPE
+from st2common.constants.keyvalue import (
+ FULL_SYSTEM_SCOPE,
+ SYSTEM_SCOPE,
+ DATASTORE_PARENT_SCOPE,
+)
from st2common.services.keyvalues import KeyValueLookup
from st2common.transport.queues import NOTIFIER_ACTIONUPDATE_WORK_QUEUE
from st2common.metrics.base import CounterWithTimer
from st2common.metrics.base import Timer
-__all__ = [
- 'Notifier',
- 'get_notifier'
-]
+__all__ = ["Notifier", "get_notifier"]
LOG = logging.getLogger(__name__)
# XXX: Fix this nasty positional dependency.
-ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][0]
-NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][1]
+ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][0]
+NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][1]
class Notifier(consumers.MessageHandler):
@@ -69,35 +70,40 @@ def __init__(self, connection, queues, trigger_dispatcher=None):
trigger_dispatcher = TriggerDispatcher(LOG)
self._trigger_dispatcher = trigger_dispatcher
self._notify_trigger = ResourceReference.to_string_reference(
- pack=NOTIFY_TRIGGER_TYPE['pack'],
- name=NOTIFY_TRIGGER_TYPE['name'])
+ pack=NOTIFY_TRIGGER_TYPE["pack"], name=NOTIFY_TRIGGER_TYPE["name"]
+ )
self._action_trigger = ResourceReference.to_string_reference(
- pack=ACTION_TRIGGER_TYPE['pack'],
- name=ACTION_TRIGGER_TYPE['name'])
+ pack=ACTION_TRIGGER_TYPE["pack"], name=ACTION_TRIGGER_TYPE["name"]
+ )
- @CounterWithTimer(key='notifier.action.executions')
+ @CounterWithTimer(key="notifier.action.executions")
def process(self, execution_db):
execution_id = str(execution_db.id)
- extra = {'execution': execution_db}
+ extra = {"execution": execution_db}
LOG.debug('Processing action execution "%s".', execution_id, extra=extra)
# Get the corresponding liveaction record.
- liveaction_db = LiveAction.get_by_id(execution_db.liveaction['id'])
+ liveaction_db = LiveAction.get_by_id(execution_db.liveaction["id"])
if execution_db.status in LIVEACTION_COMPLETED_STATES:
# If the action execution is executed under an orquesta workflow, policies for the
# action execution will be applied by the workflow engine. A policy may affect the
# final state of the action execution thereby impacting the state of the workflow.
- if not workflow_service.is_action_execution_under_workflow_context(execution_db):
- with CounterWithTimer(key='notifier.apply_post_run_policies'):
+ if not workflow_service.is_action_execution_under_workflow_context(
+ execution_db
+ ):
+ with CounterWithTimer(key="notifier.apply_post_run_policies"):
policy_service.apply_post_run_policies(liveaction_db)
if liveaction_db.notify:
- with CounterWithTimer(key='notifier.notify_trigger.post'):
- self._post_notify_triggers(liveaction_db=liveaction_db,
- execution_db=execution_db)
+ with CounterWithTimer(key="notifier.notify_trigger.post"):
+ self._post_notify_triggers(
+ liveaction_db=liveaction_db, execution_db=execution_db
+ )
- self._post_generic_trigger(liveaction_db=liveaction_db, execution_db=execution_db)
+ self._post_generic_trigger(
+ liveaction_db=liveaction_db, execution_db=execution_db
+ )
def _get_execution_for_liveaction(self, liveaction):
execution = ActionExecution.get(liveaction__id=str(liveaction.id))
@@ -108,39 +114,52 @@ def _get_execution_for_liveaction(self, liveaction):
return execution
def _post_notify_triggers(self, liveaction_db=None, execution_db=None):
- notify = getattr(liveaction_db, 'notify', None)
+ notify = getattr(liveaction_db, "notify", None)
if not notify:
return
if notify.on_complete:
self._post_notify_subsection_triggers(
- liveaction_db=liveaction_db, execution_db=execution_db,
+ liveaction_db=liveaction_db,
+ execution_db=execution_db,
notify_subsection=notify.on_complete,
- default_message_suffix='completed.')
+ default_message_suffix="completed.",
+ )
if liveaction_db.status == LIVEACTION_STATUS_SUCCEEDED and notify.on_success:
self._post_notify_subsection_triggers(
- liveaction_db=liveaction_db, execution_db=execution_db,
+ liveaction_db=liveaction_db,
+ execution_db=execution_db,
notify_subsection=notify.on_success,
- default_message_suffix='succeeded.')
+ default_message_suffix="succeeded.",
+ )
if liveaction_db.status in LIVEACTION_FAILED_STATES and notify.on_failure:
self._post_notify_subsection_triggers(
- liveaction_db=liveaction_db, execution_db=execution_db,
+ liveaction_db=liveaction_db,
+ execution_db=execution_db,
notify_subsection=notify.on_failure,
- default_message_suffix='failed.')
+ default_message_suffix="failed.",
+ )
- def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None,
- notify_subsection=None,
- default_message_suffix=None):
- routes = (getattr(notify_subsection, 'routes') or
- getattr(notify_subsection, 'channels', [])) or []
+ def _post_notify_subsection_triggers(
+ self,
+ liveaction_db=None,
+ execution_db=None,
+ notify_subsection=None,
+ default_message_suffix=None,
+ ):
+ routes = (
+ getattr(notify_subsection, "routes")
+ or getattr(notify_subsection, "channels", [])
+ ) or []
execution_id = str(execution_db.id)
if routes and len(routes) >= 1:
payload = {}
message = notify_subsection.message or (
- 'Action ' + liveaction_db.action + ' ' + default_message_suffix)
+ "Action " + liveaction_db.action + " " + default_message_suffix
+ )
data = notify_subsection.data or {}
jinja_context = self._build_jinja_context(
@@ -148,17 +167,18 @@ def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None
)
try:
- with Timer(key='notifier.transform_message'):
- message = self._transform_message(message=message,
- context=jinja_context)
+ with Timer(key="notifier.transform_message"):
+ message = self._transform_message(
+ message=message, context=jinja_context
+ )
except:
- LOG.exception('Failed (Jinja) transforming `message`.')
+ LOG.exception("Failed (Jinja) transforming `message`.")
try:
- with Timer(key='notifier.transform_data'):
+ with Timer(key="notifier.transform_data"):
data = self._transform_data(data=data, context=jinja_context)
except:
- LOG.exception('Failed (Jinja) transforming `data`.')
+ LOG.exception("Failed (Jinja) transforming `data`.")
            # At this point convert result to a string. This restricts the rules engine's
            # ability to introspect the result. On the other hand at least a json usable
@@ -166,69 +186,82 @@ def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None
            # to a string representation it uses str(...) which makes it impossible to
# parse the result as json any longer.
# TODO: Use to_serializable_dict
- data['result'] = json.dumps(liveaction_db.result)
+ data["result"] = json.dumps(liveaction_db.result)
- payload['message'] = message
- payload['data'] = data
- payload['execution_id'] = execution_id
- payload['status'] = liveaction_db.status
- payload['start_timestamp'] = isotime.format(liveaction_db.start_timestamp)
+ payload["message"] = message
+ payload["data"] = data
+ payload["execution_id"] = execution_id
+ payload["status"] = liveaction_db.status
+ payload["start_timestamp"] = isotime.format(liveaction_db.start_timestamp)
try:
- payload['end_timestamp'] = isotime.format(liveaction_db.end_timestamp)
+ payload["end_timestamp"] = isotime.format(liveaction_db.end_timestamp)
except AttributeError:
                # This can be raised if liveaction.end_timestamp is None, which happens
                # when a policy cancels a request due to concurrency.
                # In this case, use datetime.utcnow() instead.
- payload['end_timestamp'] = isotime.format(datetime.utcnow())
+ payload["end_timestamp"] = isotime.format(datetime.utcnow())
- payload['action_ref'] = liveaction_db.action
- payload['runner_ref'] = self._get_runner_ref(liveaction_db.action)
+ payload["action_ref"] = liveaction_db.action
+ payload["runner_ref"] = self._get_runner_ref(liveaction_db.action)
trace_context = self._get_trace_context(execution_id=execution_id)
failed_routes = []
for route in routes:
try:
- payload['route'] = route
+ payload["route"] = route
# Deprecated. Only for backward compatibility reasons.
- payload['channel'] = route
- LOG.debug('POSTing %s for %s. Payload - %s.', NOTIFY_TRIGGER_TYPE['name'],
- liveaction_db.id, payload)
-
- with CounterWithTimer(key='notifier.notify_trigger.dispatch'):
- self._trigger_dispatcher.dispatch(self._notify_trigger, payload=payload,
- trace_context=trace_context)
+ payload["channel"] = route
+ LOG.debug(
+ "POSTing %s for %s. Payload - %s.",
+ NOTIFY_TRIGGER_TYPE["name"],
+ liveaction_db.id,
+ payload,
+ )
+
+ with CounterWithTimer(key="notifier.notify_trigger.dispatch"):
+ self._trigger_dispatcher.dispatch(
+ self._notify_trigger,
+ payload=payload,
+ trace_context=trace_context,
+ )
except:
failed_routes.append(route)
if len(failed_routes) > 0:
- raise Exception('Failed notifications to routes: %s' % ', '.join(failed_routes))
+ raise Exception(
+ "Failed notifications to routes: %s" % ", ".join(failed_routes)
+ )
def _build_jinja_context(self, liveaction_db, execution_db):
context = {}
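+        # Expose the system scoped datastore plus the action's parameters, context
+        # and result to the Jinja templates used for notification message and data.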
- context.update({
- DATASTORE_PARENT_SCOPE: {
- SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
+ context.update(
+ {
+ DATASTORE_PARENT_SCOPE: {
+ SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
+ }
}
- })
+ )
context.update({ACTION_PARAMETERS_KV_PREFIX: liveaction_db.parameters})
context.update({ACTION_CONTEXT_KV_PREFIX: liveaction_db.context})
context.update({ACTION_RESULTS_KV_PREFIX: execution_db.result})
return context
def _transform_message(self, message, context=None):
- mapping = {'message': message}
+ mapping = {"message": message}
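+        # render_values() operates on a mapping, so wrap the message and unwrap
+        # the rendered value afterwards, falling back to the original message.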
context = context or {}
- return (jinja_utils.render_values(mapping=mapping, context=context)).get('message',
- message)
+ return (jinja_utils.render_values(mapping=mapping, context=context)).get(
+ "message", message
+ )
def _transform_data(self, data, context=None):
return jinja_utils.render_values(mapping=data, context=context)
def _get_trace_context(self, execution_id):
trace_db = trace_service.get_trace_db_by_action_execution(
- action_execution_id=execution_id)
+ action_execution_id=execution_id
+ )
if trace_db:
return TraceContext(id_=str(trace_db.id), trace_tag=trace_db.trace_tag)
# If no trace_context is found then do not create a new one here. If necessary
@@ -237,38 +270,48 @@ def _get_trace_context(self, execution_id):
def _post_generic_trigger(self, liveaction_db=None, execution_db=None):
if not cfg.CONF.action_sensor.enable:
- LOG.debug('Action trigger is disabled, skipping trigger dispatch...')
+ LOG.debug("Action trigger is disabled, skipping trigger dispatch...")
return
execution_id = str(execution_db.id)
- extra = {'execution': execution_db}
+ extra = {"execution": execution_db}
target_statuses = cfg.CONF.action_sensor.emit_when
if execution_db.status not in target_statuses:
msg = 'Skip action execution "%s" because state "%s" is not in %s'
- LOG.debug(msg % (execution_id, execution_db.status, target_statuses), extra=extra)
+ LOG.debug(
+ msg % (execution_id, execution_db.status, target_statuses), extra=extra
+ )
return
- with CounterWithTimer(key='notifier.generic_trigger.post'):
- payload = {'execution_id': execution_id,
- 'status': liveaction_db.status,
- 'start_timestamp': str(liveaction_db.start_timestamp),
- # deprecate 'action_name' at some point and switch to 'action_ref'
- 'action_name': liveaction_db.action,
- 'action_ref': liveaction_db.action,
- 'runner_ref': self._get_runner_ref(liveaction_db.action),
- 'parameters': liveaction_db.get_masked_parameters(),
- 'result': liveaction_db.result}
+ with CounterWithTimer(key="notifier.generic_trigger.post"):
+ payload = {
+ "execution_id": execution_id,
+ "status": liveaction_db.status,
+ "start_timestamp": str(liveaction_db.start_timestamp),
+ # deprecate 'action_name' at some point and switch to 'action_ref'
+ "action_name": liveaction_db.action,
+ "action_ref": liveaction_db.action,
+ "runner_ref": self._get_runner_ref(liveaction_db.action),
+ "parameters": liveaction_db.get_masked_parameters(),
+ "result": liveaction_db.result,
+ }
# Use execution_id to extract trace rather than liveaction. execution_id
            # will look up an exact TraceDB, while a liveaction, depending on context,
            # may not end up going to the DB.
trace_context = self._get_trace_context(execution_id=execution_id)
- LOG.debug('POSTing %s for %s. Payload - %s. TraceContext - %s',
- ACTION_TRIGGER_TYPE['name'], liveaction_db.id, payload, trace_context)
+ LOG.debug(
+ "POSTing %s for %s. Payload - %s. TraceContext - %s",
+ ACTION_TRIGGER_TYPE["name"],
+ liveaction_db.id,
+ payload,
+ trace_context,
+ )
- with CounterWithTimer(key='notifier.generic_trigger.dispatch'):
- self._trigger_dispatcher.dispatch(self._action_trigger, payload=payload,
- trace_context=trace_context)
+ with CounterWithTimer(key="notifier.generic_trigger.dispatch"):
+ self._trigger_dispatcher.dispatch(
+ self._action_trigger, payload=payload, trace_context=trace_context
+ )
def _get_runner_ref(self, action_ref):
"""
@@ -277,10 +320,13 @@ def _get_runner_ref(self, action_ref):
:rtype: ``str``
"""
action = Action.get_by_ref(action_ref)
- return action['runner_type']['name']
+ return action["runner_type"]["name"]
def get_notifier():
with transport_utils.get_connection() as conn:
- return Notifier(conn, [NOTIFIER_ACTIONUPDATE_WORK_QUEUE],
- trigger_dispatcher=TriggerDispatcher(LOG))
+ return Notifier(
+ conn,
+ [NOTIFIER_ACTIONUPDATE_WORK_QUEUE],
+ trigger_dispatcher=TriggerDispatcher(LOG),
+ )
diff --git a/st2actions/st2actions/policies/concurrency.py b/st2actions/st2actions/policies/concurrency.py
index 4f98b093c7..cf47ed0b69 100644
--- a/st2actions/st2actions/policies/concurrency.py
+++ b/st2actions/st2actions/policies/concurrency.py
@@ -22,53 +22,64 @@
from st2common.services import action as action_service
-__all__ = [
- 'ConcurrencyApplicator'
-]
+__all__ = ["ConcurrencyApplicator"]
LOG = logging.getLogger(__name__)
class ConcurrencyApplicator(BaseConcurrencyApplicator):
-
- def __init__(self, policy_ref, policy_type, threshold=0, action='delay'):
- super(ConcurrencyApplicator, self).__init__(policy_ref=policy_ref, policy_type=policy_type,
- threshold=threshold,
- action=action)
+ def __init__(self, policy_ref, policy_type, threshold=0, action="delay"):
+ super(ConcurrencyApplicator, self).__init__(
+ policy_ref=policy_ref,
+ policy_type=policy_type,
+ threshold=threshold,
+ action=action,
+ )
def _get_lock_uid(self, target):
- values = {'policy_type': self._policy_type, 'action': target.action}
+ values = {"policy_type": self._policy_type, "action": target.action}
return self._get_lock_name(values=values)
def _apply_before(self, target):
# Get the count of scheduled instances of the action.
scheduled = action_access.LiveAction.count(
- action=target.action, status=action_constants.LIVEACTION_STATUS_SCHEDULED)
+ action=target.action, status=action_constants.LIVEACTION_STATUS_SCHEDULED
+ )
# Get the count of running instances of the action.
running = action_access.LiveAction.count(
- action=target.action, status=action_constants.LIVEACTION_STATUS_RUNNING)
+ action=target.action, status=action_constants.LIVEACTION_STATUS_RUNNING
+ )
count = scheduled + running
# Request the execution if the threshold is not reached; otherwise delay or cancel it per the policy action.
if count < self.threshold:
- LOG.debug('There are %s instances of %s in scheduled or running status. '
- 'Threshold of %s is not reached. Action execution will be scheduled.',
- count, target.action, self._policy_ref)
+ LOG.debug(
+ "There are %s instances of %s in scheduled or running status. "
+ "Threshold of %s is not reached. Action execution will be scheduled.",
+ count,
+ target.action,
+ self._policy_ref,
+ )
status = action_constants.LIVEACTION_STATUS_REQUESTED
else:
- action = 'delayed' if self.policy_action == 'delay' else 'canceled'
- LOG.debug('There are %s instances of %s in scheduled or running status. '
- 'Threshold of %s is reached. Action execution will be %s.',
- count, target.action, self._policy_ref, action)
+ action = "delayed" if self.policy_action == "delay" else "canceled"
+ LOG.debug(
+ "There are %s instances of %s in scheduled or running status. "
+ "Threshold of %s is reached. Action execution will be %s.",
+ count,
+ target.action,
+ self._policy_ref,
+ action,
+ )
status = self._get_status_for_policy_action(action=self.policy_action)
# Update the status in the database. Publish status for cancellation so the
# appropriate runner can cancel the execution. Other statuses are not published
# because they will be picked up by the worker(s) to be processed again,
# leading to duplicate action executions.
- publish = (status == action_constants.LIVEACTION_STATUS_CANCELING)
+ publish = status == action_constants.LIVEACTION_STATUS_CANCELING
target = action_service.update_status(target, status, publish=publish)
return target
@@ -78,13 +89,17 @@ def apply_before(self, target):
valid_states = [
action_constants.LIVEACTION_STATUS_REQUESTED,
- action_constants.LIVEACTION_STATUS_DELAYED
+ action_constants.LIVEACTION_STATUS_DELAYED,
]
# Exit if target not in valid state.
if target.status not in valid_states:
- LOG.debug('The live action is not in a valid state therefore the policy '
- '"%s" cannot be applied. %s', self._policy_ref, target)
+ LOG.debug(
+ "The live action is not in a valid state therefore the policy "
+ '"%s" cannot be applied. %s',
+ self._policy_ref,
+ target,
+ )
return target
target = self._apply_before(target)
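The policy logic above reduces to a small decision: count scheduled plus running executions, compare against the threshold, and either let the execution proceed or delay/cancel it. A self-contained sketch of that decision, with illustrative status strings standing in for `action_constants`:

```python
# Statuses mirroring st2common.constants.action; the values here are illustrative.
LIVEACTION_STATUS_REQUESTED = "requested"
LIVEACTION_STATUS_DELAYED = "delayed"
LIVEACTION_STATUS_CANCELING = "canceling"


def next_status(scheduled, running, threshold, policy_action="delay"):
    """Decide the next liveaction status under the concurrency policy.

    Below the threshold the execution proceeds ("requested"); at or above
    it, the execution is delayed or canceled depending on the policy action.
    """
    if scheduled + running < threshold:
        return LIVEACTION_STATUS_REQUESTED
    return (
        LIVEACTION_STATUS_DELAYED
        if policy_action == "delay"
        else LIVEACTION_STATUS_CANCELING
    )


# Only cancellations are published so the runner can act on them immediately;
# delayed executions are picked up again by the scheduler instead.
status = next_status(scheduled=3, running=7, threshold=10)
publish = status == LIVEACTION_STATUS_CANCELING
print(status, publish)  # -> delayed False
```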
diff --git a/st2actions/st2actions/policies/concurrency_by_attr.py b/st2actions/st2actions/policies/concurrency_by_attr.py
index 7c9ee1dabc..ea3f9cd421 100644
--- a/st2actions/st2actions/policies/concurrency_by_attr.py
+++ b/st2actions/st2actions/policies/concurrency_by_attr.py
@@ -25,38 +25,41 @@
from st2common.policies.concurrency import BaseConcurrencyApplicator
from st2common.services import coordination
-__all__ = [
- 'ConcurrencyByAttributeApplicator'
-]
+__all__ = ["ConcurrencyByAttributeApplicator"]
LOG = logging.getLogger(__name__)
class ConcurrencyByAttributeApplicator(BaseConcurrencyApplicator):
-
- def __init__(self, policy_ref, policy_type, threshold=0, action='delay', attributes=None):
- super(ConcurrencyByAttributeApplicator, self).__init__(policy_ref=policy_ref,
- policy_type=policy_type,
- threshold=threshold,
- action=action)
+ def __init__(
+ self, policy_ref, policy_type, threshold=0, action="delay", attributes=None
+ ):
+ super(ConcurrencyByAttributeApplicator, self).__init__(
+ policy_ref=policy_ref,
+ policy_type=policy_type,
+ threshold=threshold,
+ action=action,
+ )
self.attributes = attributes or []
def _get_lock_uid(self, target):
meta = {
- 'policy_type': self._policy_type,
- 'action': target.action,
- 'attributes': self.attributes
+ "policy_type": self._policy_type,
+ "action": target.action,
+ "attributes": self.attributes,
}
return json.dumps(meta)
def _get_filters(self, target):
- filters = {('parameters__%s' % k): v
- for k, v in six.iteritems(target.parameters)
- if k in self.attributes}
+ filters = {
+ ("parameters__%s" % k): v
+ for k, v in six.iteritems(target.parameters)
+ if k in self.attributes
+ }
- filters['action'] = target.action
- filters['status'] = None
+ filters["action"] = target.action
+ filters["status"] = None
return filters
@@ -65,54 +68,71 @@ def _apply_before(self, target):
filters = self._get_filters(target)
# Get the count of scheduled instances of the action.
- filters['status'] = action_constants.LIVEACTION_STATUS_SCHEDULED
+ filters["status"] = action_constants.LIVEACTION_STATUS_SCHEDULED
scheduled = action_access.LiveAction.count(**filters)
# Get the count of running instances of the action.
- filters['status'] = action_constants.LIVEACTION_STATUS_RUNNING
+ filters["status"] = action_constants.LIVEACTION_STATUS_RUNNING
running = action_access.LiveAction.count(**filters)
count = scheduled + running
# Request the execution if the threshold is not reached; otherwise delay or cancel it per the policy action.
if count < self.threshold:
- LOG.debug('There are %s instances of %s in scheduled or running status. '
- 'Threshold of %s is not reached. Action execution will be scheduled.',
- count, target.action, self._policy_ref)
+ LOG.debug(
+ "There are %s instances of %s in scheduled or running status. "
+ "Threshold of %s is not reached. Action execution will be scheduled.",
+ count,
+ target.action,
+ self._policy_ref,
+ )
status = action_constants.LIVEACTION_STATUS_REQUESTED
else:
- action = 'delayed' if self.policy_action == 'delay' else 'canceled'
- LOG.debug('There are %s instances of %s in scheduled or running status. '
- 'Threshold of %s is reached. Action execution will be %s.',
- count, target.action, self._policy_ref, action)
+ action = "delayed" if self.policy_action == "delay" else "canceled"
+ LOG.debug(
+ "There are %s instances of %s in scheduled or running status. "
+ "Threshold of %s is reached. Action execution will be %s.",
+ count,
+ target.action,
+ self._policy_ref,
+ action,
+ )
status = self._get_status_for_policy_action(action=self.policy_action)
# Update the status in the database. Publish status for cancellation so the
# appropriate runner can cancel the execution. Other statuses are not published
# because they will be picked up by the worker(s) to be processed again,
# leading to duplicate action executions.
- publish = (status == action_constants.LIVEACTION_STATUS_CANCELING)
+ publish = status == action_constants.LIVEACTION_STATUS_CANCELING
target = action_service.update_status(target, status, publish=publish)
return target
def apply_before(self, target):
- target = super(ConcurrencyByAttributeApplicator, self).apply_before(target=target)
+ target = super(ConcurrencyByAttributeApplicator, self).apply_before(
+ target=target
+ )
valid_states = [
action_constants.LIVEACTION_STATUS_REQUESTED,
- action_constants.LIVEACTION_STATUS_DELAYED
+ action_constants.LIVEACTION_STATUS_DELAYED,
]
# Exit if target not in valid state.
if target.status not in valid_states:
- LOG.debug('The live action is not schedulable therefore the policy '
- '"%s" cannot be applied. %s', self._policy_ref, target)
+ LOG.debug(
+ "The live action is not schedulable therefore the policy "
+ '"%s" cannot be applied. %s',
+ self._policy_ref,
+ target,
+ )
return target
# Warn users that the coordination service is not configured.
if not coordination.configured():
- LOG.warn('Coordination service is not configured. Policy enforcement is best effort.')
+ LOG.warn(
+ "Coordination service is not configured. Policy enforcement is best effort."
+ )
target = self._apply_before(target)
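`_get_filters` above builds MongoEngine-style keyword filters (`parameters__<attr>`) so that only executions whose selected parameters match the incoming execution are counted toward the threshold. The same comprehension, extracted into a standalone sketch (the sample action ref and parameters are made up):

```python
import six  # the code above still targets Python 2/3 via six


def get_filters(action, parameters, attributes):
    """Sketch of ConcurrencyByAttributeApplicator._get_filters: build
    MongoEngine-style keyword filters keyed on the policy's attributes."""
    filters = {
        ("parameters__%s" % k): v
        for k, v in six.iteritems(parameters)
        if k in attributes
    }
    filters["action"] = action
    filters["status"] = None  # filled in per query (scheduled, then running)
    return filters


print(get_filters("examples.ping", {"host": "10.0.0.1", "count": 3}, ["host"]))
# -> {'parameters__host': '10.0.0.1', 'action': 'examples.ping', 'status': None}
```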
diff --git a/st2actions/st2actions/policies/retry.py b/st2actions/st2actions/policies/retry.py
index 85775d4f13..abbbd70453 100644
--- a/st2actions/st2actions/policies/retry.py
+++ b/st2actions/st2actions/policies/retry.py
@@ -27,22 +27,16 @@
from st2common.util.enum import Enum
from st2common.policies.base import ResourcePolicyApplicator
-__all__ = [
- 'RetryOnPolicy',
- 'ExecutionRetryPolicyApplicator'
-]
+__all__ = ["RetryOnPolicy", "ExecutionRetryPolicyApplicator"]
LOG = logging.getLogger(__name__)
-VALID_RETRY_STATUSES = [
- LIVEACTION_STATUS_FAILED,
- LIVEACTION_STATUS_TIMED_OUT
-]
+VALID_RETRY_STATUSES = [LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT]
class RetryOnPolicy(Enum):
- FAILURE = 'failure' # Retry on execution failure
- TIMEOUT = 'timeout' # Retry on execution timeout
+ FAILURE = "failure" # Retry on execution failure
+ TIMEOUT = "timeout" # Retry on execution timeout
class ExecutionRetryPolicyApplicator(ResourcePolicyApplicator):
@@ -57,8 +51,9 @@ def __init__(self, policy_ref, policy_type, retry_on, max_retry_count=2, delay=0
:param delay: How long to wait before retrying an execution.
:type delay: ``float``
"""
- super(ExecutionRetryPolicyApplicator, self).__init__(policy_ref=policy_ref,
- policy_type=policy_type)
+ super(ExecutionRetryPolicyApplicator, self).__init__(
+ policy_ref=policy_ref, policy_type=policy_type
+ )
self.retry_on = retry_on
self.max_retry_count = max_retry_count
@@ -71,27 +66,33 @@ def apply_after(self, target):
if self._is_live_action_part_of_workflow_action(live_action_db):
LOG.warning(
- 'Retry cannot be applied to this liveaction because it is executed under a '
- 'workflow. Use workflow specific retry functionality where applicable. %s',
- live_action_db
+ "Retry cannot be applied to this liveaction because it is executed under a "
+ "workflow. Use workflow specific retry functionality where applicable. %s",
+ live_action_db,
)
return target
retry_count = self._get_live_action_retry_count(live_action_db=live_action_db)
- extra = {'live_action_db': live_action_db, 'policy_ref': self._policy_ref,
- 'retry_on': self.retry_on, 'max_retry_count': self.max_retry_count,
- 'current_retry_count': retry_count}
+ extra = {
+ "live_action_db": live_action_db,
+ "policy_ref": self._policy_ref,
+ "retry_on": self.retry_on,
+ "max_retry_count": self.max_retry_count,
+ "current_retry_count": retry_count,
+ }
if live_action_db.status not in VALID_RETRY_STATUSES:
# Retries are only supported for executions in a failed or timed-out state
- LOG.debug('Liveaction not in a valid retry state, not checking retry policy',
- extra=extra)
+ LOG.debug(
+ "Liveaction not in a valid retry state, not checking retry policy",
+ extra=extra,
+ )
return target
if (retry_count + 1) > self.max_retry_count:
- LOG.info('Maximum retry count has been reached, not retrying', extra=extra)
+ LOG.info("Maximum retry count has been reached, not retrying", extra=extra)
return target
has_failed = live_action_db.status == LIVEACTION_STATUS_FAILED
@@ -100,34 +101,50 @@ def apply_after(self, target):
# TODO: This is not crash and restart safe, switch to using "DELAYED"
# status
if self.delay > 0:
- re_run_live_action = functools.partial(eventlet.spawn_after, self.delay,
- self._re_run_live_action,
- live_action_db=live_action_db)
+ re_run_live_action = functools.partial(
+ eventlet.spawn_after,
+ self.delay,
+ self._re_run_live_action,
+ live_action_db=live_action_db,
+ )
else:
# Even if delay is 0, use a small delay (0.1 seconds) to prevent busy wait
- re_run_live_action = functools.partial(eventlet.spawn_after, 0.1,
- self._re_run_live_action,
- live_action_db=live_action_db)
+ re_run_live_action = functools.partial(
+ eventlet.spawn_after,
+ 0.1,
+ self._re_run_live_action,
+ live_action_db=live_action_db,
+ )
- re_run_live_action = functools.partial(self._re_run_live_action,
- live_action_db=live_action_db)
+ re_run_live_action = functools.partial(
+ self._re_run_live_action, live_action_db=live_action_db
+ )
if has_failed and self.retry_on == RetryOnPolicy.FAILURE:
- extra['failure'] = True
- LOG.info('Policy matched (failure), retrying action execution in %s seconds...' %
- (self.delay), extra=extra)
+ extra["failure"] = True
+ LOG.info(
+ "Policy matched (failure), retrying action execution in %s seconds..."
+ % (self.delay),
+ extra=extra,
+ )
re_run_live_action()
return target
if has_timed_out and self.retry_on == RetryOnPolicy.TIMEOUT:
- extra['timeout'] = True
- LOG.info('Policy matched (timeout), retrying action execution in %s seconds...' %
- (self.delay), extra=extra)
+ extra["timeout"] = True
+ LOG.info(
+ "Policy matched (timeout), retrying action execution in %s seconds..."
+ % (self.delay),
+ extra=extra,
+ )
re_run_live_action()
return target
- LOG.info('Invalid status "%s" for live action "%s", wont retry' %
- (live_action_db.status, str(live_action_db.id)), extra=extra)
+ LOG.info(
+ 'Invalid status "%s" for live action "%s", won\'t retry'
+ % (live_action_db.status, str(live_action_db.id)),
+ extra=extra,
+ )
return target
@@ -137,9 +154,9 @@ def _is_live_action_part_of_workflow_action(self, live_action_db):
:rtype: ``dict``
"""
- context = getattr(live_action_db, 'context', {})
- parent = context.get('parent', {})
- is_wf_action = (parent is not None and parent != {})
+ context = getattr(live_action_db, "context", {})
+ parent = context.get("parent", {})
+ is_wf_action = parent is not None and parent != {}
return is_wf_action
@@ -151,8 +168,8 @@ def _get_live_action_retry_count(self, live_action_db):
"""
# TODO: Ideally we would store retry_count in zookeeper or similar and use locking so we
# can run multiple instances of st2notifier
- context = getattr(live_action_db, 'context', {})
- retry_count = context.get('policies', {}).get('retry', {}).get('retry_count', 0)
+ context = getattr(live_action_db, "context", {})
+ retry_count = context.get("policies", {}).get("retry", {}).get("retry_count", 0)
return retry_count
@@ -160,17 +177,18 @@ def _re_run_live_action(self, live_action_db):
retry_count = self._get_live_action_retry_count(live_action_db=live_action_db)
# Add additional policy specific info to the context
- context = getattr(live_action_db, 'context', {})
+ context = getattr(live_action_db, "context", {})
new_context = copy.deepcopy(context)
- new_context['policies'] = {}
- new_context['policies']['retry'] = {
- 'applied_policy': self._policy_ref,
- 'retry_count': (retry_count + 1),
- 'retried_liveaction_id': str(live_action_db.id)
+ new_context["policies"] = {}
+ new_context["policies"]["retry"] = {
+ "applied_policy": self._policy_ref,
+ "retry_count": (retry_count + 1),
+ "retried_liveaction_id": str(live_action_db.id),
}
action_ref = live_action_db.action
parameters = live_action_db.parameters
- new_live_action_db = LiveActionDB(action=action_ref, parameters=parameters,
- context=new_context)
+ new_live_action_db = LiveActionDB(
+ action=action_ref, parameters=parameters, context=new_context
+ )
_, action_execution_db = action_services.request(new_live_action_db)
return action_execution_db
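The retry applicator above defers the re-run through `functools.partial` over `eventlet.spawn_after`, and turns a zero policy delay into a short 0.1 second one to avoid a busy wait. A runnable sketch of that pattern (the liveaction id is a made-up placeholder):

```python
import functools

import eventlet


def re_run(liveaction_id):
    # Stand-in for ExecutionRetryPolicyApplicator._re_run_live_action.
    print("re-running liveaction %s" % liveaction_id)


delay = 0.0
# A zero policy delay is still scheduled with a small delay to avoid busy wait.
schedule_retry = functools.partial(
    eventlet.spawn_after, delay if delay > 0 else 0.1, re_run, "5f0c0000aaaa"
)

green_thread = schedule_retry()  # fires re_run after the delay elapses
green_thread.wait()
```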
diff --git a/st2actions/st2actions/runners/pythonrunner.py b/st2actions/st2actions/runners/pythonrunner.py
index 215edd83c8..33a3f3ec39 100644
--- a/st2actions/st2actions/runners/pythonrunner.py
+++ b/st2actions/st2actions/runners/pythonrunner.py
@@ -16,6 +16,4 @@
from __future__ import absolute_import
from st2common.runners.base_action import Action
-__all__ = [
- 'Action'
-]
+__all__ = ["Action"]
diff --git a/st2actions/st2actions/scheduler/config.py b/st2actions/st2actions/scheduler/config.py
index a991403a9b..8df6c3ff3e 100644
--- a/st2actions/st2actions/scheduler/config.py
+++ b/st2actions/st2actions/scheduler/config.py
@@ -27,8 +27,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=sys_constants.VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=sys_constants.VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
@@ -47,36 +50,48 @@ def _register_common_opts():
def _register_service_opts():
scheduler_opts = [
cfg.StrOpt(
- 'logging',
- default='/etc/st2/logging.scheduler.conf',
- help='Location of the logging configuration file.'
+ "logging",
+ default="/etc/st2/logging.scheduler.conf",
+ help="Location of the logging configuration file.",
),
cfg.FloatOpt(
- 'execution_scheduling_timeout_threshold_min', default=1,
- help='How long GC to search back in minutes for orphaned scheduled actions'),
+ "execution_scheduling_timeout_threshold_min",
+ default=1,
+ help="How long GC to search back in minutes for orphaned scheduled actions",
+ ),
cfg.IntOpt(
- 'pool_size', default=10,
- help='The size of the pool used by the scheduler for scheduling executions.'),
+ "pool_size",
+ default=10,
+ help="The size of the pool used by the scheduler for scheduling executions.",
+ ),
cfg.FloatOpt(
- 'sleep_interval', default=0.10,
- help='How long (in seconds) to sleep between each action scheduler main loop run '
- 'interval.'),
+ "sleep_interval",
+ default=0.10,
+ help="How long (in seconds) to sleep between each action scheduler main loop run "
+ "interval.",
+ ),
cfg.FloatOpt(
- 'gc_interval', default=10,
- help='How often (in seconds) to look for zombie execution requests before rescheduling '
- 'them.'),
+ "gc_interval",
+ default=10,
+ help="How often (in seconds) to look for zombie execution requests before rescheduling "
+ "them.",
+ ),
cfg.IntOpt(
- 'retry_max_attempt', default=10,
- help='The maximum number of attempts that the scheduler retries on error.'),
+ "retry_max_attempt",
+ default=10,
+ help="The maximum number of attempts that the scheduler retries on error.",
+ ),
cfg.IntOpt(
- 'retry_wait_msec', default=3000,
- help='The number of milliseconds to wait in between retries.')
+ "retry_wait_msec",
+ default=3000,
+ help="The number of milliseconds to wait in between retries.",
+ ),
]
- cfg.CONF.register_opts(scheduler_opts, group='scheduler')
+ cfg.CONF.register_opts(scheduler_opts, group="scheduler")
try:
register_opts()
except cfg.DuplicateOptError:
- LOG.exception('The scheduler configuration options are already parsed and loaded.')
+ LOG.exception("The scheduler configuration options are already parsed and loaded.")
diff --git a/st2actions/st2actions/scheduler/entrypoint.py b/st2actions/st2actions/scheduler/entrypoint.py
index ee8a76f2d1..14d816ded3 100644
--- a/st2actions/st2actions/scheduler/entrypoint.py
+++ b/st2actions/st2actions/scheduler/entrypoint.py
@@ -29,10 +29,7 @@
from st2common.persistence.execution_queue import ActionExecutionSchedulingQueue
from st2common.models.db.execution_queue import ActionExecutionSchedulingQueueItemDB
-__all__ = [
- 'SchedulerEntrypoint',
- 'get_scheduler_entrypoint'
-]
+__all__ = ["SchedulerEntrypoint", "get_scheduler_entrypoint"]
LOG = logging.getLogger(__name__)
@@ -43,6 +40,7 @@ class SchedulerEntrypoint(consumers.MessageHandler):
SchedulerEntrypoint subscribes to the Action scheduler request queue and places new Live
Actions into the scheduling queue collection for scheduling on action runners.
"""
+
message_type = LiveActionDB
def process(self, request):
@@ -53,18 +51,25 @@ def process(self, request):
:type request: ``st2common.models.db.liveaction.LiveActionDB``
"""
if request.status != action_constants.LIVEACTION_STATUS_REQUESTED:
- LOG.info('%s is ignoring %s (id=%s) with "%s" status.',
- self.__class__.__name__, type(request), request.id, request.status)
+ LOG.info(
+ '%s is ignoring %s (id=%s) with "%s" status.',
+ self.__class__.__name__,
+ type(request),
+ request.id,
+ request.status,
+ )
return
try:
liveaction_db = action_utils.get_liveaction_by_id(str(request.id))
except StackStormDBObjectNotFoundError:
- LOG.exception('Failed to find liveaction %s in the database.', str(request.id))
+ LOG.exception(
+ "Failed to find liveaction %s in the database.", str(request.id)
+ )
raise
query = {
- 'liveaction_id': str(liveaction_db.id),
+ "liveaction_id": str(liveaction_db.id),
}
queued_requests = ActionExecutionSchedulingQueue.query(**query)
@@ -75,17 +80,16 @@ def process(self, request):
if liveaction_db.delay and liveaction_db.delay > 0:
liveaction_db = action_service.update_status(
- liveaction_db,
- action_constants.LIVEACTION_STATUS_DELAYED,
- publish=False
+ liveaction_db, action_constants.LIVEACTION_STATUS_DELAYED, publish=False
)
execution_queue_item_db = self._create_execution_queue_item_db_from_liveaction(
- liveaction_db,
- delay=liveaction_db.delay
+ liveaction_db, delay=liveaction_db.delay
)
- ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False)
+ ActionExecutionSchedulingQueue.add_or_update(
+ execution_queue_item_db, publish=False
+ )
return execution_queue_item_db
@@ -99,9 +103,8 @@ def _create_execution_queue_item_db_from_liveaction(self, liveaction, delay=None
execution_queue_item_db.action_execution_id = str(execution.id)
execution_queue_item_db.liveaction_id = str(liveaction.id)
execution_queue_item_db.original_start_timestamp = liveaction.start_timestamp
- execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time(
- liveaction.start_timestamp,
- delay or 0
+ execution_queue_item_db.scheduled_start_timestamp = (
+ date.append_milliseconds_to_time(liveaction.start_timestamp, delay or 0)
)
execution_queue_item_db.delay = delay
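The entrypoint above computes `scheduled_start_timestamp` by shifting the liveaction's start time forward by the requested delay in milliseconds. A sketch of what `date.append_milliseconds_to_time` plausibly does; the implementation here is an assumption inferred from the call sites, not st2common's actual helper:

```python
import datetime


def append_milliseconds_to_time(timestamp, milliseconds):
    """Assumed behavior: shift a datetime by a (possibly negative)
    number of milliseconds, as the scheduler call sites suggest."""
    return timestamp + datetime.timedelta(milliseconds=milliseconds)


start = datetime.datetime(2021, 3, 2, 12, 0, 0)
delay_ms = 5000  # liveaction.delay
# The queue item only becomes eligible at original start + delay.
print(append_milliseconds_to_time(start, delay_ms or 0))  # 2021-03-02 12:00:05
```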
diff --git a/st2actions/st2actions/scheduler/handler.py b/st2actions/st2actions/scheduler/handler.py
index 76d54066a9..e39871db3e 100644
--- a/st2actions/st2actions/scheduler/handler.py
+++ b/st2actions/st2actions/scheduler/handler.py
@@ -37,10 +37,7 @@
from st2common.metrics import base as metrics
from st2common.exceptions import db as db_exc
-__all__ = [
- 'ActionExecutionSchedulingQueueHandler',
- 'get_handler'
-]
+__all__ = ["ActionExecutionSchedulingQueueHandler", "get_handler"]
LOG = logging.getLogger(__name__)
@@ -61,14 +58,15 @@ def __init__(self):
# fast (< 5 seconds). If an item is still being marked as processing it likely indicates
# that the scheduler process which was processing that item crashed or similar so we need
# to mark it as "handling=False" so some other scheduler process can pick it up.
- self._execution_scheduling_timeout_threshold_ms = \
+ self._execution_scheduling_timeout_threshold_ms = (
cfg.CONF.scheduler.execution_scheduling_timeout_threshold_min * 60 * 1000
+ )
self._coordinator = coordination_service.get_coordinator(start_heart=True)
self._main_thread = None
self._cleanup_thread = None
def run(self):
- LOG.debug('Starting scheduler handler...')
+ LOG.debug("Starting scheduler handler...")
while not self._shutdown:
eventlet.greenthread.sleep(cfg.CONF.scheduler.sleep_interval)
@@ -77,7 +75,8 @@ def run(self):
@retrying.retry(
retry_on_exception=service_utils.retry_on_exceptions,
stop_max_attempt_number=cfg.CONF.scheduler.retry_max_attempt,
- wait_fixed=cfg.CONF.scheduler.retry_wait_msec)
+ wait_fixed=cfg.CONF.scheduler.retry_wait_msec,
+ )
def process(self):
execution_queue_item_db = self._get_next_execution()
@@ -85,7 +84,7 @@ def process(self):
self._pool.spawn(self._handle_execution, execution_queue_item_db)
def cleanup(self):
- LOG.debug('Starting scheduler garbage collection...')
+ LOG.debug("Starting scheduler garbage collection...")
while not self._shutdown:
eventlet.greenthread.sleep(cfg.CONF.scheduler.gc_interval)
@@ -99,11 +98,11 @@ def _reset_handling_flag(self):
False so other scheduler can pick it up.
"""
query = {
- 'scheduled_start_timestamp__lte': date.append_milliseconds_to_time(
+ "scheduled_start_timestamp__lte": date.append_milliseconds_to_time(
date.get_datetime_utc_now(),
- -self._execution_scheduling_timeout_threshold_ms
+ -self._execution_scheduling_timeout_threshold_ms,
),
- 'handling': True
+ "handling": True,
}
execution_queue_item_dbs = ActionExecutionSchedulingQueue.query(**query) or []
@@ -112,17 +111,19 @@ def _reset_handling_flag(self):
execution_queue_item_db.handling = False
try:
- ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False)
+ ActionExecutionSchedulingQueue.add_or_update(
+ execution_queue_item_db, publish=False
+ )
LOG.info(
'[%s] Removing lock for orphaned execution queue item "%s".',
execution_queue_item_db.action_execution_id,
- str(execution_queue_item_db.id)
+ str(execution_queue_item_db.id),
)
except db_exc.StackStormDBObjectWriteConflictError:
LOG.info(
'[%s] Execution queue item "%s" updated during garbage collection.',
execution_queue_item_db.action_execution_id,
- str(execution_queue_item_db.id)
+ str(execution_queue_item_db.id),
)
# TODO: Remove this function for fixing missing action_execution_id in v3.2.
@@ -132,7 +133,9 @@ def _fix_missing_action_execution_id(self):
"""
Auto-populate the action_execution_id in ActionExecutionSchedulingQueue if empty.
"""
- for entry in ActionExecutionSchedulingQueue.query(action_execution_id__in=['', None]):
+ for entry in ActionExecutionSchedulingQueue.query(
+ action_execution_id__in=["", None]
+ ):
execution_db = ActionExecution.get(liveaction__id=entry.liveaction_id)
if not execution_db:
@@ -152,23 +155,27 @@ def _cleanup_policy_delayed(self):
moved back into requested status.
"""
- policy_delayed_liveaction_dbs = LiveAction.query(status='policy-delayed') or []
+ policy_delayed_liveaction_dbs = LiveAction.query(status="policy-delayed") or []
for liveaction_db in policy_delayed_liveaction_dbs:
- ex_que_qry = {'liveaction_id': str(liveaction_db.id), 'handling': False}
- execution_queue_item_dbs = ActionExecutionSchedulingQueue.query(**ex_que_qry) or []
+ ex_que_qry = {"liveaction_id": str(liveaction_db.id), "handling": False}
+ execution_queue_item_dbs = (
+ ActionExecutionSchedulingQueue.query(**ex_que_qry) or []
+ )
for execution_queue_item_db in execution_queue_item_dbs:
# Mark the entry in the scheduling queue for handling.
try:
execution_queue_item_db.handling = True
- execution_queue_item_db = ActionExecutionSchedulingQueue.add_or_update(
- execution_queue_item_db, publish=False)
+ execution_queue_item_db = (
+ ActionExecutionSchedulingQueue.add_or_update(
+ execution_queue_item_db, publish=False
+ )
+ )
except db_exc.StackStormDBObjectWriteConflictError:
- msg = (
- '[%s] Item "%s" is currently being processed by another scheduler.' %
- (execution_queue_item_db.action_execution_id,
- str(execution_queue_item_db.id))
+ msg = '[%s] Item "%s" is currently being processed by another scheduler.' % (
+ execution_queue_item_db.action_execution_id,
+ str(execution_queue_item_db.id),
)
LOG.error(msg)
raise Exception(msg)
@@ -177,7 +184,7 @@ def _cleanup_policy_delayed(self):
LOG.info(
'[%s] Removing policy-delayed entry "%s" from the scheduling queue.',
execution_queue_item_db.action_execution_id,
- str(execution_queue_item_db.id)
+ str(execution_queue_item_db.id),
)
ActionExecutionSchedulingQueue.delete(execution_queue_item_db)
@@ -186,18 +193,20 @@ def _cleanup_policy_delayed(self):
LOG.info(
'[%s] Removing policy-delayed entry "%s" from the scheduling queue.',
execution_queue_item_db.action_execution_id,
- str(execution_queue_item_db.id)
+ str(execution_queue_item_db.id),
)
liveaction_db = action_service.update_status(
- liveaction_db, action_constants.LIVEACTION_STATUS_REQUESTED)
+ liveaction_db, action_constants.LIVEACTION_STATUS_REQUESTED
+ )
execution_service.update_execution(liveaction_db)
@retrying.retry(
retry_on_exception=service_utils.retry_on_exceptions,
stop_max_attempt_number=cfg.CONF.scheduler.retry_max_attempt,
- wait_fixed=cfg.CONF.scheduler.retry_wait_msec)
+ wait_fixed=cfg.CONF.scheduler.retry_wait_msec,
+ )
def _handle_garbage_collection(self):
self._reset_handling_flag()
@@ -212,13 +221,10 @@ def _get_next_execution(self):
due to a policy.
"""
query = {
- 'scheduled_start_timestamp__lte': date.get_datetime_utc_now(),
- 'handling': False,
- 'limit': 1,
- 'order_by': [
- '+scheduled_start_timestamp',
- '+original_start_timestamp'
- ]
+ "scheduled_start_timestamp__lte": date.get_datetime_utc_now(),
+ "handling": False,
+ "limit": 1,
+ "order_by": ["+scheduled_start_timestamp", "+original_start_timestamp"],
}
execution_queue_item_db = ActionExecutionSchedulingQueue.query(**query).first()
@@ -229,45 +235,52 @@ def _get_next_execution(self):
# Mark that this scheduler process is currently handling (processing) that request
# NOTE: This operation is atomic (CAS)
msg = '[%s] Retrieved item "%s" from scheduling queue.'
- LOG.info(msg, execution_queue_item_db.action_execution_id, execution_queue_item_db.id)
+ LOG.info(
+ msg, execution_queue_item_db.action_execution_id, execution_queue_item_db.id
+ )
execution_queue_item_db.handling = True
try:
- ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False)
+ ActionExecutionSchedulingQueue.add_or_update(
+ execution_queue_item_db, publish=False
+ )
return execution_queue_item_db
except db_exc.StackStormDBObjectWriteConflictError:
LOG.info(
'[%s] Item "%s" is already handled by another scheduler.',
execution_queue_item_db.action_execution_id,
- str(execution_queue_item_db.id)
+ str(execution_queue_item_db.id),
)
return None
- @metrics.CounterWithTimer(key='scheduler.handle_execution')
+ @metrics.CounterWithTimer(key="scheduler.handle_execution")
def _handle_execution(self, execution_queue_item_db):
action_execution_id = str(execution_queue_item_db.action_execution_id)
liveaction_id = str(execution_queue_item_db.liveaction_id)
queue_item_id = str(execution_queue_item_db.id)
- extra = {'queue_item_id': queue_item_id}
+ extra = {"queue_item_id": queue_item_id}
LOG.info(
'[%s] Scheduling Liveaction "%s".',
- action_execution_id, liveaction_id, extra=extra
+ action_execution_id,
+ liveaction_id,
+ extra=extra,
)
try:
liveaction_db = action_utils.get_liveaction_by_id(liveaction_id)
except StackStormDBObjectNotFoundError:
msg = '[%s] Failed to find liveaction "%s" in the database (queue_item_id=%s).'
- LOG.exception(msg, action_execution_id, liveaction_id, queue_item_id, extra=extra)
+ LOG.exception(
+ msg, action_execution_id, liveaction_id, queue_item_id, extra=extra
+ )
ActionExecutionSchedulingQueue.delete(execution_queue_item_db)
raise
# Identify if the action has policies that require locking.
action_has_policies_require_lock = policy_service.has_policies(
- liveaction_db,
- policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK
+ liveaction_db, policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK
)
# Acquire a distributed lock if the referenced action has specific policies attached.
@@ -275,9 +288,9 @@ def _handle_execution(self, execution_queue_item_db):
# Warn users that the coordination service is not configured.
if not coordination_service.configured():
LOG.warn(
- '[%s] Coordination backend is not configured. '
- 'Policy enforcement is best effort.',
- action_execution_id
+ "[%s] Coordination backend is not configured. "
+ "Policy enforcement is best effort.",
+ action_execution_id,
)
# Acquire a distributed lock before querying the database to make sure that only one
@@ -304,11 +317,14 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db):
action_execution_id = str(execution_queue_item_db.action_execution_id)
liveaction_id = str(execution_queue_item_db.liveaction_id)
queue_item_id = str(execution_queue_item_db.id)
- extra = {'queue_item_id': queue_item_id}
+ extra = {"queue_item_id": queue_item_id}
LOG.info(
'[%s] Liveaction "%s" has status "%s" before applying policies.',
- action_execution_id, liveaction_id, liveaction_db.status, extra=extra
+ action_execution_id,
+ liveaction_id,
+ liveaction_db.status,
+ extra=extra,
)
# Apply policies defined for the action.
@@ -316,13 +332,18 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db):
LOG.info(
'[%s] Liveaction "%s" has status "%s" after applying policies.',
- action_execution_id, liveaction_id, liveaction_db.status, extra=extra
+ action_execution_id,
+ liveaction_id,
+ liveaction_db.status,
+ extra=extra,
)
if liveaction_db.status == action_constants.LIVEACTION_STATUS_DELAYED:
LOG.info(
'[%s] Liveaction "%s" is delayed and scheduling queue is updated.',
- action_execution_id, liveaction_id, extra=extra
+ action_execution_id,
+ liveaction_id,
+ extra=extra,
)
liveaction_db = action_service.update_status(
@@ -330,23 +351,30 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db):
)
execution_queue_item_db.handling = False
- execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time(
- date.get_datetime_utc_now(),
- POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS
+ execution_queue_item_db.scheduled_start_timestamp = (
+ date.append_milliseconds_to_time(
+ date.get_datetime_utc_now(),
+ POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS,
+ )
)
try:
- ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False)
+ ActionExecutionSchedulingQueue.add_or_update(
+ execution_queue_item_db, publish=False
+ )
except db_exc.StackStormDBObjectWriteConflictError:
LOG.warning(
- '[%s] Database write conflict on updating scheduling queue.',
- action_execution_id, extra=extra
+ "[%s] Database write conflict on updating scheduling queue.",
+ action_execution_id,
+ extra=extra,
)
return
- if (liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES or
- liveaction_db.status in action_constants.LIVEACTION_CANCEL_STATES):
+ if (
+ liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES
+ or liveaction_db.status in action_constants.LIVEACTION_CANCEL_STATES
+ ):
ActionExecutionSchedulingQueue.delete(execution_queue_item_db)
return
@@ -356,33 +384,41 @@ def _delay(self, liveaction_db, execution_queue_item_db):
action_execution_id = str(execution_queue_item_db.action_execution_id)
liveaction_id = str(execution_queue_item_db.liveaction_id)
queue_item_id = str(execution_queue_item_db.id)
- extra = {'queue_item_id': queue_item_id}
+ extra = {"queue_item_id": queue_item_id}
LOG.info(
'[%s] Liveaction "%s" is delayed and scheduling queue is updated.',
- action_execution_id, liveaction_id, extra=extra
+ action_execution_id,
+ liveaction_id,
+ extra=extra,
)
liveaction_db = action_service.update_status(
liveaction_db, action_constants.LIVEACTION_STATUS_DELAYED, publish=False
)
- execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time(
- date.get_datetime_utc_now(),
- POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS
+ execution_queue_item_db.scheduled_start_timestamp = (
+ date.append_milliseconds_to_time(
+ date.get_datetime_utc_now(), POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS
+ )
)
try:
execution_queue_item_db.handling = False
- ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False)
+ ActionExecutionSchedulingQueue.add_or_update(
+ execution_queue_item_db, publish=False
+ )
except db_exc.StackStormDBObjectWriteConflictError:
LOG.warning(
- '[%s] Database write conflict on updating scheduling queue.',
- action_execution_id, extra=extra
+ "[%s] Database write conflict on updating scheduling queue.",
+ action_execution_id,
+ extra=extra,
)
def _schedule(self, liveaction_db, execution_queue_item_db):
- if self._is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db):
+ if self._is_execution_queue_item_runnable(
+ liveaction_db, execution_queue_item_db
+ ):
self._update_to_scheduled(liveaction_db, execution_queue_item_db)
@staticmethod
@@ -396,7 +432,7 @@ def _is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db):
valid_status = [
action_constants.LIVEACTION_STATUS_REQUESTED,
action_constants.LIVEACTION_STATUS_SCHEDULED,
- action_constants.LIVEACTION_STATUS_DELAYED
+ action_constants.LIVEACTION_STATUS_DELAYED,
]
if liveaction_db.status in valid_status:
@@ -405,11 +441,14 @@ def _is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db):
action_execution_id = str(execution_queue_item_db.action_execution_id)
liveaction_id = str(execution_queue_item_db.liveaction_id)
queue_item_id = str(execution_queue_item_db.id)
- extra = {'queue_item_id': queue_item_id}
+ extra = {"queue_item_id": queue_item_id}
LOG.info(
'[%s] Ignoring Liveaction "%s" with status "%s" after policies are applied.',
- action_execution_id, liveaction_id, liveaction_db.status, extra=extra
+ action_execution_id,
+ liveaction_id,
+ liveaction_db.status,
+ extra=extra,
)
ActionExecutionSchedulingQueue.delete(execution_queue_item_db)
@@ -421,18 +460,26 @@ def _update_to_scheduled(liveaction_db, execution_queue_item_db):
action_execution_id = str(execution_queue_item_db.action_execution_id)
liveaction_id = str(execution_queue_item_db.liveaction_id)
queue_item_id = str(execution_queue_item_db.id)
- extra = {'queue_item_id': queue_item_id}
+ extra = {"queue_item_id": queue_item_id}
# Update liveaction status to "scheduled".
LOG.info(
'[%s] Liveaction "%s" with status "%s" is updated to status "scheduled."',
- action_execution_id, liveaction_id, liveaction_db.status, extra=extra
+ action_execution_id,
+ liveaction_id,
+ liveaction_db.status,
+ extra=extra,
)
- if liveaction_db.status in [action_constants.LIVEACTION_STATUS_REQUESTED,
- action_constants.LIVEACTION_STATUS_DELAYED]:
+ if liveaction_db.status in [
+ action_constants.LIVEACTION_STATUS_REQUESTED,
+ action_constants.LIVEACTION_STATUS_DELAYED,
+ ]:
liveaction_db = action_service.update_status(
- liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED, publish=False)
+ liveaction_db,
+ action_constants.LIVEACTION_STATUS_SCHEDULED,
+ publish=False,
+ )
# Publish the "scheduled" status here manually. Otherwise, there could be a
# race condition with the update of the action_execution_db if the execution
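Several hunks above share one idiom: set `handling = True` on a queue item and persist it with `add_or_update`, relying on the DB layer to raise `StackStormDBObjectWriteConflictError` when another scheduler claimed the item first. A toy compare-and-set illustration; the `FakeQueue` store and its revision counter are illustrative assumptions, not the real persistence layer:

```python
class WriteConflictError(Exception):
    """Stand-in for db_exc.StackStormDBObjectWriteConflictError."""


class FakeQueue:
    """Toy store showing the compare-and-set claim used by
    _get_next_execution: only one scheduler wins the `handling` flag."""

    def __init__(self):
        self._items = {"item-1": {"handling": False, "rev": 1}}

    def claim(self, item_id, expected_rev):
        item = self._items[item_id]
        if item["rev"] != expected_rev:  # someone updated it first
            raise WriteConflictError(item_id)
        item["handling"] = True  # atomic in the real DB layer
        item["rev"] += 1
        return item


queue = FakeQueue()
queue.claim("item-1", expected_rev=1)  # first scheduler wins
try:
    queue.claim("item-1", expected_rev=1)  # second scheduler loses
except WriteConflictError:
    print("already handled by another scheduler")
```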
diff --git a/st2actions/st2actions/worker.py b/st2actions/st2actions/worker.py
index 3147ce1aae..1741d60724 100644
--- a/st2actions/st2actions/worker.py
+++ b/st2actions/st2actions/worker.py
@@ -34,10 +34,7 @@
from st2common.transport import queues
-__all__ = [
- 'ActionExecutionDispatcher',
- 'get_worker'
-]
+__all__ = ["ActionExecutionDispatcher", "get_worker"]
LOG = logging.getLogger(__name__)
@@ -46,14 +43,14 @@
queues.ACTIONRUNNER_WORK_QUEUE,
queues.ACTIONRUNNER_CANCEL_QUEUE,
queues.ACTIONRUNNER_PAUSE_QUEUE,
- queues.ACTIONRUNNER_RESUME_QUEUE
+ queues.ACTIONRUNNER_RESUME_QUEUE,
]
ACTIONRUNNER_DISPATCHABLE_STATES = [
action_constants.LIVEACTION_STATUS_SCHEDULED,
action_constants.LIVEACTION_STATUS_CANCELING,
action_constants.LIVEACTION_STATUS_PAUSING,
- action_constants.LIVEACTION_STATUS_RESUMING
+ action_constants.LIVEACTION_STATUS_RESUMING,
]
@@ -83,41 +80,54 @@ def process(self, liveaction):
"""
if liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED:
- LOG.info('%s is not executing %s (id=%s) with "%s" status.',
- self.__class__.__name__, type(liveaction), liveaction.id, liveaction.status)
+ LOG.info(
+ '%s is not executing %s (id=%s) with "%s" status.',
+ self.__class__.__name__,
+ type(liveaction),
+ liveaction.id,
+ liveaction.status,
+ )
if not liveaction.result:
updated_liveaction = action_utils.update_liveaction_status(
status=liveaction.status,
- result={'message': 'Action execution canceled by user.'},
- liveaction_id=liveaction.id)
+ result={"message": "Action execution canceled by user."},
+ liveaction_id=liveaction.id,
+ )
executions.update_execution(updated_liveaction)
return
if liveaction.status not in ACTIONRUNNER_DISPATCHABLE_STATES:
- LOG.info('%s is not dispatching %s (id=%s) with "%s" status.',
- self.__class__.__name__, type(liveaction), liveaction.id, liveaction.status)
+ LOG.info(
+ '%s is not dispatching %s (id=%s) with "%s" status.',
+ self.__class__.__name__,
+ type(liveaction),
+ liveaction.id,
+ liveaction.status,
+ )
return
try:
liveaction_db = action_utils.get_liveaction_by_id(liveaction.id)
except StackStormDBObjectNotFoundError:
- LOG.exception('Failed to find liveaction %s in the database.', liveaction.id)
+ LOG.exception(
+ "Failed to find liveaction %s in the database.", liveaction.id
+ )
raise
if liveaction.status != liveaction_db.status:
LOG.warning(
- 'The status of liveaction %s has changed from %s to %s '
- 'while in the queue waiting for processing.',
+ "The status of liveaction %s has changed from %s to %s "
+ "while in the queue waiting for processing.",
liveaction.id,
liveaction.status,
- liveaction_db.status
+ liveaction_db.status,
)
dispatchers = {
action_constants.LIVEACTION_STATUS_SCHEDULED: self._run_action,
action_constants.LIVEACTION_STATUS_CANCELING: self._cancel_action,
action_constants.LIVEACTION_STATUS_PAUSING: self._pause_action,
- action_constants.LIVEACTION_STATUS_RESUMING: self._resume_action
+ action_constants.LIVEACTION_STATUS_RESUMING: self._resume_action,
}
return dispatchers[liveaction.status](liveaction)
@@ -130,7 +140,7 @@ def shutdown(self):
try:
executions.abandon_execution_if_incomplete(liveaction_id=liveaction_id)
except:
- LOG.exception('Failed to abandon liveaction %s.', liveaction_id)
+ LOG.exception("Failed to abandon liveaction %s.", liveaction_id)
def _run_action(self, liveaction_db):
# stamp liveaction with process_info
@@ -140,35 +150,49 @@ def _run_action(self, liveaction_db):
liveaction_db = action_utils.update_liveaction_status(
status=action_constants.LIVEACTION_STATUS_RUNNING,
runner_info=runner_info,
- liveaction_id=liveaction_db.id)
+ liveaction_id=liveaction_db.id,
+ )
self._running_liveactions.add(liveaction_db.id)
action_execution_db = executions.update_execution(liveaction_db)
# Launch action
- extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db}
- LOG.audit('Launching action execution.', extra=extra)
+ extra = {
+ "action_execution_db": action_execution_db,
+ "liveaction_db": liveaction_db,
+ }
+ LOG.audit("Launching action execution.", extra=extra)
# the extra field will not be shown in non-audit logs so temporarily log at info.
- LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.',
- action_execution_db.id, liveaction_db.id, liveaction_db.status)
-
- extra = {'liveaction_db': liveaction_db}
+ LOG.info(
+ 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.',
+ action_execution_db.id,
+ liveaction_db.id,
+ liveaction_db.status,
+ )
+
+ extra = {"liveaction_db": liveaction_db}
try:
result = self.container.dispatch(liveaction_db)
- LOG.debug('Runner dispatch produced result: %s', result)
+ LOG.debug("Runner dispatch produced result: %s", result)
if not result and not liveaction_db.action_is_workflow:
- raise ActionRunnerException('Failed to execute action.')
+ raise ActionRunnerException("Failed to execute action.")
except:
_, ex, tb = sys.exc_info()
- extra['error'] = str(ex)
- LOG.info('Action "%s" failed: %s' % (liveaction_db.action, str(ex)), extra=extra)
+ extra["error"] = str(ex)
+ LOG.info(
+ 'Action "%s" failed: %s' % (liveaction_db.action, str(ex)), extra=extra
+ )
liveaction_db = action_utils.update_liveaction_status(
status=action_constants.LIVEACTION_STATUS_FAILED,
liveaction_id=liveaction_db.id,
- result={'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))})
+ result={
+ "error": str(ex),
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ },
+ )
executions.update_execution(liveaction_db)
raise
finally:
@@ -182,66 +206,98 @@ def _run_action(self, liveaction_db):
def _cancel_action(self, liveaction_db):
action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id))
- extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db}
- LOG.audit('Canceling action execution.', extra=extra)
+ extra = {
+ "action_execution_db": action_execution_db,
+ "liveaction_db": liveaction_db,
+ }
+ LOG.audit("Canceling action execution.", extra=extra)
# the extra field will not be shown in non-audit logs so temporarily log at info.
- LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.',
- action_execution_db.id, liveaction_db.id, liveaction_db.status)
+ LOG.info(
+ 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.',
+ action_execution_db.id,
+ liveaction_db.id,
+ liveaction_db.status,
+ )
try:
result = self.container.dispatch(liveaction_db)
- LOG.debug('Runner dispatch produced result: %s', result)
+ LOG.debug("Runner dispatch produced result: %s", result)
except:
_, ex, tb = sys.exc_info()
- extra['error'] = str(ex)
- LOG.info('Failed to cancel action execution %s.' % (liveaction_db.id), extra=extra)
+ extra["error"] = str(ex)
+ LOG.info(
+ "Failed to cancel action execution %s." % (liveaction_db.id),
+ extra=extra,
+ )
raise
return result
def _pause_action(self, liveaction_db):
action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id))
- extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db}
- LOG.audit('Pausing action execution.', extra=extra)
+ extra = {
+ "action_execution_db": action_execution_db,
+ "liveaction_db": liveaction_db,
+ }
+ LOG.audit("Pausing action execution.", extra=extra)
# the extra field will not be shown in non-audit logs so temporarily log at info.
- LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.',
- action_execution_db.id, liveaction_db.id, liveaction_db.status)
+ LOG.info(
+ 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.',
+ action_execution_db.id,
+ liveaction_db.id,
+ liveaction_db.status,
+ )
try:
result = self.container.dispatch(liveaction_db)
- LOG.debug('Runner dispatch produced result: %s', result)
+ LOG.debug("Runner dispatch produced result: %s", result)
except:
_, ex, tb = sys.exc_info()
- extra['error'] = str(ex)
- LOG.info('Failed to pause action execution %s.' % (liveaction_db.id), extra=extra)
+ extra["error"] = str(ex)
+ LOG.info(
+ "Failed to pause action execution %s." % (liveaction_db.id), extra=extra
+ )
raise
return result
def _resume_action(self, liveaction_db):
action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id))
- extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db}
- LOG.audit('Resuming action execution.', extra=extra)
+ extra = {
+ "action_execution_db": action_execution_db,
+ "liveaction_db": liveaction_db,
+ }
+ LOG.audit("Resuming action execution.", extra=extra)
# the extra field will not be shown in non-audit logs so temporarily log at info.
- LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.',
- action_execution_db.id, liveaction_db.id, liveaction_db.status)
+ LOG.info(
+ 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.',
+ action_execution_db.id,
+ liveaction_db.id,
+ liveaction_db.status,
+ )
try:
result = self.container.dispatch(liveaction_db)
- LOG.debug('Runner dispatch produced result: %s', result)
+ LOG.debug("Runner dispatch produced result: %s", result)
except:
_, ex, tb = sys.exc_info()
- extra['error'] = str(ex)
- LOG.info('Failed to resume action execution %s.' % (liveaction_db.id), extra=extra)
+ extra["error"] = str(ex)
+ LOG.info(
+ "Failed to resume action execution %s." % (liveaction_db.id),
+ extra=extra,
+ )
raise
# Cascade the resume upstream if action execution is child of an orquesta workflow.
# The action service request_resume function is not used here because we do not want
# other peer subworkflows to be resumed.
- if 'orquesta' in action_execution_db.context and 'parent' in action_execution_db.context:
+ if (
+ "orquesta" in action_execution_db.context
+ and "parent" in action_execution_db.context
+ ):
wf_svc.handle_action_execution_resume(action_execution_db)
return result
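`process` above routes each liveaction to `_run_action`, `_cancel_action`, `_pause_action`, or `_resume_action` through a status-keyed dict rather than an if/elif chain. A compact sketch of that dispatch-table shape, with illustrative status strings in place of `action_constants`:

```python
# Illustrative status constants; st2 defines these in action_constants.
SCHEDULED, CANCELING, PAUSING, RESUMING = (
    "scheduled",
    "canceling",
    "pausing",
    "resuming",
)


class Dispatcher:
    """Sketch of the status -> handler table in ActionExecutionDispatcher."""

    def _run(self, liveaction):
        return "run %s" % liveaction

    def _cancel(self, liveaction):
        return "cancel %s" % liveaction

    def _pause(self, liveaction):
        return "pause %s" % liveaction

    def _resume(self, liveaction):
        return "resume %s" % liveaction

    def process(self, liveaction, status):
        dispatchers = {
            SCHEDULED: self._run,
            CANCELING: self._cancel,
            PAUSING: self._pause,
            RESUMING: self._resume,
        }
        return dispatchers[status](liveaction)


print(Dispatcher().process("la-1", SCHEDULED))  # -> run la-1
```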
diff --git a/st2actions/st2actions/workflows/config.py b/st2actions/st2actions/workflows/config.py
index 0d2556f67a..6854323ddd 100644
--- a/st2actions/st2actions/workflows/config.py
+++ b/st2actions/st2actions/workflows/config.py
@@ -23,8 +23,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=sys_constants.VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=sys_constants.VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
@@ -43,13 +46,13 @@ def _register_common_opts():
def _register_service_opts():
wf_engine_opts = [
cfg.StrOpt(
- 'logging',
- default='/etc/st2/logging.workflowengine.conf',
- help='Location of the logging configuration file.'
+ "logging",
+ default="/etc/st2/logging.workflowengine.conf",
+ help="Location of the logging configuration file.",
)
]
- cfg.CONF.register_opts(wf_engine_opts, group='workflow_engine')
+ cfg.CONF.register_opts(wf_engine_opts, group="workflow_engine")
register_opts()
diff --git a/st2actions/st2actions/workflows/workflows.py b/st2actions/st2actions/workflows/workflows.py
index 0351998025..2151c7d440 100644
--- a/st2actions/st2actions/workflows/workflows.py
+++ b/st2actions/st2actions/workflows/workflows.py
@@ -37,17 +37,16 @@
WORKFLOW_EXECUTION_QUEUES = [
queues.WORKFLOW_EXECUTION_WORK_QUEUE,
queues.WORKFLOW_EXECUTION_RESUME_QUEUE,
- queues.WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE
+ queues.WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE,
]
class WorkflowExecutionHandler(consumers.VariableMessageHandler):
-
def __init__(self, connection, queues):
super(WorkflowExecutionHandler, self).__init__(connection, queues)
def handle_workflow_execution_with_instrumentation(wf_ex_db):
- with metrics.CounterWithTimer(key='orquesta.workflow.executions'):
+ with metrics.CounterWithTimer(key="orquesta.workflow.executions"):
return self.handle_workflow_execution(wf_ex_db=wf_ex_db)
def handle_action_execution_with_instrumentation(ac_ex_db):
@@ -55,27 +54,27 @@ def handle_action_execution_with_instrumentation(ac_ex_db):
if not wf_svc.is_action_execution_under_workflow_context(ac_ex_db):
return
- with metrics.CounterWithTimer(key='orquesta.action.executions'):
+ with metrics.CounterWithTimer(key="orquesta.action.executions"):
return self.handle_action_execution(ac_ex_db=ac_ex_db)
self.message_types = {
wf_db_models.WorkflowExecutionDB: handle_workflow_execution_with_instrumentation,
- ex_db_models.ActionExecutionDB: handle_action_execution_with_instrumentation
+ ex_db_models.ActionExecutionDB: handle_action_execution_with_instrumentation,
}
def get_queue_consumer(self, connection, queues):
# We want to use a special ActionsQueueConsumer which uses 2 dispatcher pools
return consumers.VariableMessageQueueConsumer(
- connection=connection,
- queues=queues,
- handler=self
+ connection=connection, queues=queues, handler=self
)
def process(self, message):
handler_function = self.message_types.get(type(message), None)
if not handler_function:
- msg = 'Handler function for message type "%s" is not defined.' % type(message)
+ msg = 'Handler function for message type "%s" is not defined.' % type(
+ message
+ )
raise ValueError(msg)
try:
@@ -90,43 +89,45 @@ def process(self, message):
def fail_workflow_execution(self, message, exception):
# Prepare attributes based on message type.
if isinstance(message, wf_db_models.WorkflowExecutionDB):
- msg_type = 'workflow'
+ msg_type = "workflow"
wf_ex_db = message
wf_ex_id = str(wf_ex_db.id)
task = None
else:
- msg_type = 'task'
+ msg_type = "task"
ac_ex_db = message
- wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id']
- task_ex_id = ac_ex_db.context['orquesta']['task_execution_id']
+ wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"]
+ task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"]
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id)
task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id)
- task = {'id': task_ex_db.task_id, 'route': task_ex_db.task_route}
+ task = {"id": task_ex_db.task_id, "route": task_ex_db.task_route}
# Log the error.
- msg = 'Unknown error while processing %s execution. %s: %s'
+ msg = "Unknown error while processing %s execution. %s: %s"
wf_svc.update_progress(
wf_ex_db,
msg % (msg_type, exception.__class__.__name__, str(exception)),
- severity='error'
+ severity="error",
)
# Fail the task execution so it's marked correctly in the
# conductor state to allow for task rerun if needed.
if isinstance(message, ex_db_models.ActionExecutionDB):
msg = 'Unknown error while processing %s execution. Failing task execution "%s".'
- wf_svc.update_progress(wf_ex_db, msg % (msg_type, task_ex_id), severity='error')
+ wf_svc.update_progress(
+ wf_ex_db, msg % (msg_type, task_ex_id), severity="error"
+ )
wf_svc.update_task_execution(task_ex_id, ac_const.LIVEACTION_STATUS_FAILED)
wf_svc.update_task_state(task_ex_id, ac_const.LIVEACTION_STATUS_FAILED)
# Fail the workflow execution.
msg = 'Unknown error while processing %s execution. Failing workflow execution "%s".'
- wf_svc.update_progress(wf_ex_db, msg % (msg_type, wf_ex_id), severity='error')
+ wf_svc.update_progress(wf_ex_db, msg % (msg_type, wf_ex_id), severity="error")
wf_svc.fail_workflow_execution(wf_ex_id, exception, task=task)
def handle_workflow_execution(self, wf_ex_db):
# Request the next set of tasks to execute.
- wf_svc.update_progress(wf_ex_db, 'Processing request for workflow execution.')
+ wf_svc.update_progress(wf_ex_db, "Processing request for workflow execution.")
wf_svc.request_next_tasks(wf_ex_db)
def handle_action_execution(self, ac_ex_db):
@@ -135,16 +136,17 @@ def handle_action_execution(self, ac_ex_db):
return
# Get related record identifiers.
- wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id']
- task_ex_id = ac_ex_db.context['orquesta']['task_execution_id']
+ wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"]
+ task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"]
# Get execution records for logging purposes.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id)
task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id)
- msg = (
- 'Action execution "%s" for task "%s" is updated and in "%s" state.' %
- (str(ac_ex_db.id), task_ex_db.task_id, ac_ex_db.status)
+ msg = 'Action execution "%s" for task "%s" is updated and in "%s" state.' % (
+ str(ac_ex_db.id),
+ task_ex_db.task_id,
+ ac_ex_db.status,
)
wf_svc.update_progress(wf_ex_db, msg)
@@ -152,9 +154,13 @@ def handle_action_execution(self, ac_ex_db):
if task_ex_db.status in statuses.COMPLETED_STATUSES:
msg = (
'Action execution "%s" for task "%s", route "%s", is not processed '
- 'because task execution "%s" is already in completed state "%s".' % (
- str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route),
- str(task_ex_db.id), task_ex_db.status
+ 'because task execution "%s" is already in completed state "%s".'
+ % (
+ str(ac_ex_db.id),
+ task_ex_db.task_id,
+ str(task_ex_db.task_route),
+ str(task_ex_db.id),
+ task_ex_db.status,
)
)
wf_svc.update_progress(wf_ex_db, msg)
@@ -175,7 +181,7 @@ def handle_action_execution(self, ac_ex_db):
return
# Apply post run policies.
- lv_ac_db = lv_db_access.LiveAction.get_by_id(ac_ex_db.liveaction['id'])
+ lv_ac_db = lv_db_access.LiveAction.get_by_id(ac_ex_db.liveaction["id"])
pc_svc.apply_post_run_policies(lv_ac_db)
# Process completion of the action execution.
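`WorkflowExecutionHandler` above keys its handlers on the *type* of the incoming message rather than its status, and raises if an unknown model arrives. A self-contained sketch of that routing; both DB classes here are empty stand-ins for the real models:

```python
class WorkflowExecutionDB:
    """Stand-in for wf_db_models.WorkflowExecutionDB."""


class ActionExecutionDB:
    """Stand-in for ex_db_models.ActionExecutionDB."""


def handle_workflow_execution(message):
    return "workflow execution handled"


def handle_action_execution(message):
    return "action execution handled"


# type(message) -> handler, as built in WorkflowExecutionHandler.__init__.
message_types = {
    WorkflowExecutionDB: handle_workflow_execution,
    ActionExecutionDB: handle_action_execution,
}


def process(message):
    handler_function = message_types.get(type(message), None)
    if not handler_function:
        raise ValueError(
            'Handler function for message type "%s" is not defined.' % type(message)
        )
    return handler_function(message)


print(process(ActionExecutionDB()))  # -> action execution handled
```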
diff --git a/st2actions/tests/unit/policies/test_base.py b/st2actions/tests/unit/policies/test_base.py
index fcf3aef40d..2e5003d89c 100644
--- a/st2actions/tests/unit/policies/test_base.py
+++ b/st2actions/tests/unit/policies/test_base.py
@@ -17,6 +17,7 @@
import mock
from st2tests import config as test_config
+
test_config.parse_args()
import st2common
@@ -32,28 +33,21 @@
from st2tests.fixturesloader import FixturesLoader
-__all__ = [
- 'SchedulerPoliciesTestCase',
- 'NotifierPoliciesTestCase'
-]
+__all__ = ["SchedulerPoliciesTestCase", "NotifierPoliciesTestCase"]
-PACK = 'generic'
+PACK = "generic"
TEST_FIXTURES_1 = {
- 'actions': [
- 'action1.yaml'
+ "actions": ["action1.yaml"],
+ "policies": [
+ "policy_4.yaml",
],
- 'policies': [
- 'policy_4.yaml',
- ]
}
TEST_FIXTURES_2 = {
- 'actions': [
- 'action1.yaml'
+ "actions": ["action1.yaml"],
+ "policies": [
+ "policy_1.yaml",
],
- 'policies': [
- 'policy_1.yaml',
- ]
}
@@ -73,15 +67,14 @@ def setUp(self):
register_policy_types(st2common)
loader = FixturesLoader()
- models = loader.save_fixtures_to_db(fixtures_pack=PACK,
- fixtures_dict=TEST_FIXTURES_2)
+ models = loader.save_fixtures_to_db(
+ fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES_2
+ )
# Policy with "post_run" application
- self.policy_db = models['policies']['policy_1.yaml']
+ self.policy_db = models["policies"]["policy_1.yaml"]
- @mock.patch.object(
- policies, 'get_driver',
- mock.MagicMock(return_value=None))
+ @mock.patch.object(policies, "get_driver", mock.MagicMock(return_value=None))
def test_disabled_policy_not_applied_on_pre_run(self):
##########
# First test a scenario where policy is enabled
@@ -91,7 +84,9 @@ def test_disabled_policy_not_applied_on_pre_run(self):
# Post run hasn't been called yet, call count should be 0
self.assertEqual(policies.get_driver.call_count, 0)
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
policy_service.apply_pre_run_policies(live_action_db)
@@ -108,7 +103,9 @@ def test_disabled_policy_not_applied_on_pre_run(self):
self.assertEqual(policies.get_driver.call_count, 0)
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
policy_service.apply_pre_run_policies(live_action_db)
@@ -133,15 +130,14 @@ def setUp(self):
register_policy_types(st2common)
loader = FixturesLoader()
- models = loader.save_fixtures_to_db(fixtures_pack=PACK,
- fixtures_dict=TEST_FIXTURES_1)
+ models = loader.save_fixtures_to_db(
+ fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES_1
+ )
# Policy with "post_run" application
- self.policy_db = models['policies']['policy_4.yaml']
+ self.policy_db = models["policies"]["policy_4.yaml"]
- @mock.patch.object(
- policies, 'get_driver',
- mock.MagicMock(return_value=None))
+ @mock.patch.object(policies, "get_driver", mock.MagicMock(return_value=None))
def test_disabled_policy_not_applied_on_post_run(self):
##########
# First test a scenario where policy is enabled
@@ -151,7 +147,9 @@ def test_disabled_policy_not_applied_on_post_run(self):
# Post run hasn't been called yet, call count should be 0
self.assertEqual(policies.get_driver.call_count, 0)
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
policy_service.apply_post_run_policies(live_action_db)
@@ -168,7 +166,9 @@ def test_disabled_policy_not_applied_on_post_run(self):
self.assertEqual(policies.get_driver.call_count, 0)
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
policy_service.apply_post_run_policies(live_action_db)
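The decorator black keeps joining onto one line throughout this file is plain `unittest.mock`. A self-contained sketch of the patching pattern, using a stand-in `svc` holder rather than the st2 `policies` module:

    from unittest import mock

    class svc:  # hypothetical stand-in for the patched module
        @staticmethod
        def get_driver():
            return "real-driver"

    @mock.patch.object(svc, "get_driver", mock.MagicMock(return_value=None))
    def check():
        assert svc.get_driver() is None        # attribute replaced for the call's duration
        assert svc.get_driver.call_count == 1  # the MagicMock records invocations

    check()
    assert svc.get_driver() == "real-driver"   # patch is undone when check() returns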
diff --git a/st2actions/tests/unit/policies/test_concurrency.py b/st2actions/tests/unit/policies/test_concurrency.py
index 670c38d839..f22a0303cd 100644
--- a/st2actions/tests/unit/policies/test_concurrency.py
+++ b/st2actions/tests/unit/policies/test_concurrency.py
@@ -42,40 +42,40 @@
from st2tests.mocks.runners import runner
-__all__ = [
- 'ConcurrencyPolicyTestCase'
-]
+__all__ = ["ConcurrencyPolicyTestCase"]
-PACK = 'generic'
+PACK = "generic"
TEST_FIXTURES = {
- 'actions': [
- 'action1.yaml',
- 'action2.yaml'
- ],
- 'policies': [
- 'policy_1.yaml',
- 'policy_5.yaml'
- ]
+ "actions": ["action1.yaml", "action2.yaml"],
+ "policies": ["policy_1.yaml", "policy_5.yaml"],
}
-NON_EMPTY_RESULT = 'non-empty'
-MOCK_RUN_RETURN_VALUE = (action_constants.LIVEACTION_STATUS_RUNNING, NON_EMPTY_RESULT, None)
+NON_EMPTY_RESULT = "non-empty"
+MOCK_RUN_RETURN_VALUE = (
+ action_constants.LIVEACTION_STATUS_RUNNING,
+ NON_EMPTY_RESULT,
+ None,
+)
SCHEDULED_STATES = [
action_constants.LIVEACTION_STATUS_SCHEDULED,
action_constants.LIVEACTION_STATUS_RUNNING,
- action_constants.LIVEACTION_STATUS_SUCCEEDED
+ action_constants.LIVEACTION_STATUS_SUCCEEDED,
]
-@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner()))
-@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
-@mock.patch.object(
- CUDPublisher, 'publish_update',
- mock.MagicMock(side_effect=MockExecutionPublisher.publish_update))
+@mock.patch(
+ "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
+@mock.patch(
+ "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(return_value=None))
+ CUDPublisher,
+ "publish_update",
+ mock.MagicMock(side_effect=MockExecutionPublisher.publish_update),
+)
+@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None))
class ConcurrencyPolicyTestCase(EventletTestCase, ExecutionDbTestCase):
@classmethod
def setUpClass(cls):
@@ -93,8 +93,7 @@ def setUpClass(cls):
register_policy_types(st2common)
loader = FixturesLoader()
- loader.save_fixtures_to_db(fixtures_pack=PACK,
- fixtures_dict=TEST_FIXTURES)
+ loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
@classmethod
def tearDownClass(cls):
@@ -106,10 +105,15 @@ def tearDownClass(cls):
# NOTE: This monkey patch needs to happen again here because during tests this
# method somehow gets unpatched (a test doing reload() or similar)
- @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
+ @mock.patch(
+ "st2actions.container.base.get_runner",
+ mock.Mock(return_value=runner.get_runner()),
+ )
def tearDown(self):
for liveaction in LiveAction.get_all():
- action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED)
+ action_service.update_status(
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
+ )
@staticmethod
def _process_scheduling_queue():
@@ -117,64 +121,82 @@ def _process_scheduling_queue():
scheduling_queue.get_handler()._handle_execution(queued_req)
@mock.patch.object(
- runner.MockActionRunner, 'run',
- mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE))
+ runner.MockActionRunner,
+ "run",
+ mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE),
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(
+ side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state
+ ),
+ )
def test_over_threshold_delay_executions(self):
# Ensure the concurrency policy is accurate.
- policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency')
- self.assertGreater(policy_db.parameters['threshold'], 0)
+ policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency")
+ self.assertGreater(policy_db.parameters["threshold"], 0)
# Launch action executions until the expected threshold is reached.
- for i in range(0, policy_db.parameters['threshold']):
- parameters = {'actionstr': 'foo-' + str(i)}
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters=parameters)
+ for i in range(0, policy_db.parameters["threshold"]):
+ parameters = {"actionstr": "foo-" + str(i)}
+ liveaction = LiveActionDB(action="wolfpack.action-1", parameters=parameters)
action_service.request(liveaction)
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Check the number of action executions in scheduled state.
- scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES]
- self.assertEqual(len(scheduled), policy_db.parameters['threshold'])
+ scheduled = [
+ item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES
+ ]
+ self.assertEqual(len(scheduled), policy_db.parameters["threshold"])
# Assert the correct number of published states and action executions. This is to avoid
# duplicate executions caused by accidental publishing of state in the concurrency policies.
# num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running'])
expected_num_exec = len(scheduled)
expected_num_pubs = expected_num_exec * 3
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Execution is expected to be delayed since concurrency threshold is reached.
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo-last'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo-last"}
+ )
liveaction, _ = action_service.request(liveaction)
expected_num_pubs += 1 # Tally requested state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Since states are being processed async, wait for the liveaction to go into delayed state.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_DELAYED
+ )
expected_num_exec += 0 # This request will not be scheduled for execution.
expected_num_pubs += 0 # The delayed status change should not be published.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Mark one of the scheduled/running execution as completed.
action_service.update_status(
- scheduled[0],
- action_constants.LIVEACTION_STATUS_SUCCEEDED,
- publish=True
+ scheduled[0], action_constants.LIVEACTION_STATUS_SUCCEEDED, publish=True
)
expected_num_pubs += 1 # Tally succeeded state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
@@ -185,52 +207,74 @@ def test_over_threshold_delay_executions(self):
# Since states are being processed async, wait for the liveaction to be scheduled.
liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES)
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Check the status changes.
execution = ActionExecution.get(liveaction__id=str(liveaction.id))
- expected_status_changes = ['requested', 'delayed', 'requested', 'scheduled', 'running']
- actual_status_changes = [entry['status'] for entry in execution.log]
+ expected_status_changes = [
+ "requested",
+ "delayed",
+ "requested",
+ "scheduled",
+ "running",
+ ]
+ actual_status_changes = [entry["status"] for entry in execution.log]
self.assertListEqual(actual_status_changes, expected_status_changes)
@mock.patch.object(
- runner.MockActionRunner, 'run',
- mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE))
+ runner.MockActionRunner,
+ "run",
+ mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE),
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(
+ side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state
+ ),
+ )
def test_over_threshold_cancel_executions(self):
- policy_db = Policy.get_by_ref('wolfpack.action-2.concurrency.cancel')
- self.assertEqual(policy_db.parameters['action'], 'cancel')
- self.assertGreater(policy_db.parameters['threshold'], 0)
+ policy_db = Policy.get_by_ref("wolfpack.action-2.concurrency.cancel")
+ self.assertEqual(policy_db.parameters["action"], "cancel")
+ self.assertGreater(policy_db.parameters["threshold"], 0)
# Launch action executions until the expected threshold is reached.
- for i in range(0, policy_db.parameters['threshold']):
- parameters = {'actionstr': 'foo-' + str(i)}
- liveaction = LiveActionDB(action='wolfpack.action-2', parameters=parameters)
+ for i in range(0, policy_db.parameters["threshold"]):
+ parameters = {"actionstr": "foo-" + str(i)}
+ liveaction = LiveActionDB(action="wolfpack.action-2", parameters=parameters)
action_service.request(liveaction)
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Check the number of action executions in scheduled state.
- scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES]
- self.assertEqual(len(scheduled), policy_db.parameters['threshold'])
+ scheduled = [
+ item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES
+ ]
+ self.assertEqual(len(scheduled), policy_db.parameters["threshold"])
# duplicate executions caused by accidental publishing of state in the concurrency policies.
# num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running'])
expected_num_exec = len(scheduled)
expected_num_pubs = expected_num_exec * 3
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Execution is expected to be canceled since concurrency threshold is reached.
- liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-2", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
expected_num_pubs += 1 # Tally requested state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
@@ -240,67 +284,91 @@ def test_over_threshold_cancel_executions(self):
LiveActionPublisher.publish_state.assert_has_calls(calls)
expected_num_pubs += 2 # Tally canceling and canceled state changes.
expected_num_exec += 0 # This request will not be scheduled for execution.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Assert the action is canceled.
liveaction = LiveAction.get_by_id(str(liveaction.id))
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED)
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
@mock.patch.object(
- runner.MockActionRunner, 'run',
- mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE))
+ runner.MockActionRunner,
+ "run",
+ mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE),
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(
+ side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state
+ ),
+ )
def test_on_cancellation(self):
- policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency')
- self.assertGreater(policy_db.parameters['threshold'], 0)
+ policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency")
+ self.assertGreater(policy_db.parameters["threshold"], 0)
# Launch action executions until the expected threshold is reached.
- for i in range(0, policy_db.parameters['threshold']):
- parameters = {'actionstr': 'foo-' + str(i)}
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters=parameters)
+ for i in range(0, policy_db.parameters["threshold"]):
+ parameters = {"actionstr": "foo-" + str(i)}
+ liveaction = LiveActionDB(action="wolfpack.action-1", parameters=parameters)
action_service.request(liveaction)
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Check the number of action executions in scheduled state.
- scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES]
- self.assertEqual(len(scheduled), policy_db.parameters['threshold'])
+ scheduled = [
+ item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES
+ ]
+ self.assertEqual(len(scheduled), policy_db.parameters["threshold"])
# duplicate executions caused by accidental publishing of state in the concurrency policies.
# num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running'])
expected_num_exec = len(scheduled)
expected_num_pubs = expected_num_exec * 3
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Execution is expected to be delayed since concurrency threshold is reached.
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
expected_num_pubs += 1 # Tally requested state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Since states are being processed async, wait for the liveaction to go into delayed state.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_DELAYED
+ )
expected_num_exec += 0 # This request will not be scheduled for execution.
expected_num_pubs += 0 # The delayed status change should not be published.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Cancel execution.
- action_service.request_cancellation(scheduled[0], 'stanley')
+ action_service.request_cancellation(scheduled[0], "stanley")
expected_num_pubs += 2 # Tally the canceling and canceled states.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
@@ -312,5 +380,7 @@ def test_on_cancellation(self):
# Execution is expected to be rescheduled.
liveaction = LiveAction.get_by_id(str(liveaction.id))
self.assertIn(liveaction.status, SCHEDULED_STATES)
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
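The bookkeeping these assertions repeat is simple arithmetic: per the comments in this file, each execution that reaches the scheduled state publishes three status changes (requested, scheduled, running), while a delayed one publishes only its initial requested state. A sketch of the tally under an assumed threshold of 3 (pure arithmetic, no st2 imports):

    threshold = 3                              # hypothetical concurrency threshold
    expected_num_exec = threshold              # one runner invocation per scheduled action
    expected_num_pubs = expected_num_exec * 3  # requested + scheduled + running
    expected_num_pubs += 1                     # over-threshold request: requested state only
    assert (expected_num_exec, expected_num_pubs) == (3, 10)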
diff --git a/st2actions/tests/unit/policies/test_concurrency_by_attr.py b/st2actions/tests/unit/policies/test_concurrency_by_attr.py
index b576e3a669..98cfc3a4dc 100644
--- a/st2actions/tests/unit/policies/test_concurrency_by_attr.py
+++ b/st2actions/tests/unit/policies/test_concurrency_by_attr.py
@@ -39,42 +39,41 @@
from st2tests.mocks.runners import runner
from six.moves import range
-__all__ = [
- 'ConcurrencyByAttributePolicyTestCase'
-]
+__all__ = ["ConcurrencyByAttributePolicyTestCase"]
-PACK = 'generic'
+PACK = "generic"
TEST_FIXTURES = {
- 'actions': [
- 'action1.yaml',
- 'action2.yaml'
- ],
- 'policies': [
- 'policy_3.yaml',
- 'policy_7.yaml'
- ]
+ "actions": ["action1.yaml", "action2.yaml"],
+ "policies": ["policy_3.yaml", "policy_7.yaml"],
}
-NON_EMPTY_RESULT = 'non-empty'
-MOCK_RUN_RETURN_VALUE = (action_constants.LIVEACTION_STATUS_RUNNING, NON_EMPTY_RESULT, None)
+NON_EMPTY_RESULT = "non-empty"
+MOCK_RUN_RETURN_VALUE = (
+ action_constants.LIVEACTION_STATUS_RUNNING,
+ NON_EMPTY_RESULT,
+ None,
+)
SCHEDULED_STATES = [
action_constants.LIVEACTION_STATUS_SCHEDULED,
action_constants.LIVEACTION_STATUS_RUNNING,
- action_constants.LIVEACTION_STATUS_SUCCEEDED
+ action_constants.LIVEACTION_STATUS_SUCCEEDED,
]
-@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner()))
-@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
+@mock.patch(
+ "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
+@mock.patch(
+ "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
@mock.patch.object(
- CUDPublisher, 'publish_update',
- mock.MagicMock(side_effect=MockExecutionPublisher.publish_update))
-@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(return_value=None))
+ CUDPublisher,
+ "publish_update",
+ mock.MagicMock(side_effect=MockExecutionPublisher.publish_update),
+)
+@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None))
class ConcurrencyByAttributePolicyTestCase(EventletTestCase, ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
EventletTestCase.setUpClass()
@@ -91,8 +90,7 @@ def setUpClass(cls):
register_policy_types(st2common)
loader = FixturesLoader()
- loader.save_fixtures_to_db(fixtures_pack=PACK,
- fixtures_dict=TEST_FIXTURES)
+ loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
@classmethod
def tearDownClass(cls):
@@ -104,10 +102,15 @@ def tearDownClass(cls):
# NOTE: This monkey patch needs to happen again here because during tests this
# method somehow gets unpatched (a test doing reload() or similar)
- @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
+ @mock.patch(
+ "st2actions.container.base.get_runner",
+ mock.Mock(return_value=runner.get_runner()),
+ )
def tearDown(self):
for liveaction in LiveAction.get_all():
- action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED)
+ action_service.update_status(
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
+ )
@staticmethod
def _process_scheduling_queue():
@@ -115,58 +118,80 @@ def _process_scheduling_queue():
scheduling_queue.get_handler()._handle_execution(queued_req)
@mock.patch.object(
- runner.MockActionRunner, 'run',
- mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE))
+ runner.MockActionRunner,
+ "run",
+ mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE),
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(
+ side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state
+ ),
+ )
def test_over_threshold_delay_executions(self):
- policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency.attr')
- self.assertGreater(policy_db.parameters['threshold'], 0)
- self.assertIn('actionstr', policy_db.parameters['attributes'])
+ policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency.attr")
+ self.assertGreater(policy_db.parameters["threshold"], 0)
+ self.assertIn("actionstr", policy_db.parameters["attributes"])
# Launch action executions until the expected threshold is reached.
- for i in range(0, policy_db.parameters['threshold']):
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ for i in range(0, policy_db.parameters["threshold"]):
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
action_service.request(liveaction)
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Check the number of action executions in scheduled state.
- scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES]
- self.assertEqual(len(scheduled), policy_db.parameters['threshold'])
+ scheduled = [
+ item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES
+ ]
+ self.assertEqual(len(scheduled), policy_db.parameters["threshold"])
# Assert the correct number of published states and action executions. This is to avoid
# duplicate executions caused by accidental publishing of state in the concurrency policies.
# num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running'])
expected_num_exec = len(scheduled)
expected_num_pubs = expected_num_exec * 3
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Execution is expected to be delayed since concurrency threshold is reached.
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
expected_num_pubs += 1 # Tally requested state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Since states are being processed asynchronously, wait for the
# liveaction to go into delayed state.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_DELAYED
+ )
expected_num_exec += 0 # This request will not be scheduled for execution.
expected_num_pubs += 0 # The delayed status change should not be published.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Execution is expected to be scheduled since concurrency threshold is not reached.
# The execution with actionstr "foo" is over the threshold but actionstr "bar" is not.
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'bar'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "bar"}
+ )
liveaction, _ = action_service.request(liveaction)
# Run the scheduler to schedule action executions.
@@ -177,18 +202,20 @@ def test_over_threshold_delay_executions(self):
liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES)
expected_num_exec += 1 # This request is expected to be executed.
expected_num_pubs += 3 # Tally requested, scheduled, and running state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Mark one of the execution as completed.
action_service.update_status(
- scheduled[0],
- action_constants.LIVEACTION_STATUS_SUCCEEDED,
- publish=True
+ scheduled[0], action_constants.LIVEACTION_STATUS_SUCCEEDED, publish=True
)
expected_num_pubs += 1 # Tally succeeded state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
@@ -197,47 +224,65 @@ def test_over_threshold_delay_executions(self):
liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES)
expected_num_exec += 1 # The delayed request is expected to be executed.
expected_num_pubs += 2 # Tally scheduled and running state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
@mock.patch.object(
- runner.MockActionRunner, 'run',
- mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE))
+ runner.MockActionRunner,
+ "run",
+ mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE),
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(
+ side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state
+ ),
+ )
def test_over_threshold_cancel_executions(self):
- policy_db = Policy.get_by_ref('wolfpack.action-2.concurrency.attr.cancel')
- self.assertEqual(policy_db.parameters['action'], 'cancel')
- self.assertGreater(policy_db.parameters['threshold'], 0)
- self.assertIn('actionstr', policy_db.parameters['attributes'])
+ policy_db = Policy.get_by_ref("wolfpack.action-2.concurrency.attr.cancel")
+ self.assertEqual(policy_db.parameters["action"], "cancel")
+ self.assertGreater(policy_db.parameters["threshold"], 0)
+ self.assertIn("actionstr", policy_db.parameters["attributes"])
# Launch action executions until the expected threshold is reached.
- for i in range(0, policy_db.parameters['threshold']):
- liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'})
+ for i in range(0, policy_db.parameters["threshold"]):
+ liveaction = LiveActionDB(
+ action="wolfpack.action-2", parameters={"actionstr": "foo"}
+ )
action_service.request(liveaction)
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Check the number of action executions in scheduled state.
- scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES]
- self.assertEqual(len(scheduled), policy_db.parameters['threshold'])
+ scheduled = [
+ item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES
+ ]
+ self.assertEqual(len(scheduled), policy_db.parameters["threshold"])
# Assert the correct number of published states and action executions. This is to avoid
# duplicate executions caused by accidental publishing of state in the concurrency policies.
# num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running'])
expected_num_exec = len(scheduled)
expected_num_pubs = expected_num_exec * 3
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Execution is expected to be delayed since concurrency threshold is reached.
- liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-2", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
expected_num_pubs += 1 # Tally requested state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
@@ -247,7 +292,9 @@ def test_over_threshold_cancel_executions(self):
LiveActionPublisher.publish_state.assert_has_calls(calls)
expected_num_pubs += 2 # Tally canceling and canceled state changes.
expected_num_exec += 0 # This request will not be scheduled for execution.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Assert the action is canceled.
@@ -255,58 +302,80 @@ def test_over_threshold_cancel_executions(self):
self.assertEqual(canceled.status, action_constants.LIVEACTION_STATUS_CANCELED)
@mock.patch.object(
- runner.MockActionRunner, 'run',
- mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE))
+ runner.MockActionRunner,
+ "run",
+ mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE),
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(
+ side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state
+ ),
+ )
def test_on_cancellation(self):
- policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency.attr')
- self.assertGreater(policy_db.parameters['threshold'], 0)
- self.assertIn('actionstr', policy_db.parameters['attributes'])
+ policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency.attr")
+ self.assertGreater(policy_db.parameters["threshold"], 0)
+ self.assertIn("actionstr", policy_db.parameters["attributes"])
# Launch action executions until the expected threshold is reached.
- for i in range(0, policy_db.parameters['threshold']):
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ for i in range(0, policy_db.parameters["threshold"]):
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
action_service.request(liveaction)
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Check the number of action executions in scheduled state.
- scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES]
- self.assertEqual(len(scheduled), policy_db.parameters['threshold'])
+ scheduled = [
+ item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES
+ ]
+ self.assertEqual(len(scheduled), policy_db.parameters["threshold"])
# duplicate executions caused by accidental publishing of state in the concurrency policies.
# num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running'])
expected_num_exec = len(scheduled)
expected_num_pubs = expected_num_exec * 3
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Execution is expected to be delayed since concurrency threshold is reached.
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
expected_num_pubs += 1 # Tally requested state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
# Since states are being processed asynchronously, wait for the
# liveaction to go into delayed state.
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_DELAYED
+ )
delayed = liveaction
expected_num_exec += 0 # This request will not be scheduled for execution.
expected_num_pubs += 0 # The delayed status change should not be published.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Execution is expected to be scheduled since concurrency threshold is not reached.
# The execution with actionstr "foo" is over the threshold but actionstr "bar" is not.
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'bar'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "bar"}
+ )
liveaction, _ = action_service.request(liveaction)
# Run the scheduler to schedule action executions.
@@ -317,13 +386,17 @@ def test_on_cancellation(self):
liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES)
expected_num_exec += 1 # This request is expected to be executed.
expected_num_pubs += 3 # Tally requested, scheduled, and running states.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Cancel execution.
- action_service.request_cancellation(scheduled[0], 'stanley')
+ action_service.request_cancellation(scheduled[0], "stanley")
expected_num_pubs += 2 # Tally the canceling and canceled states.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
# Run the scheduler to schedule action executions.
self._process_scheduling_queue()
@@ -331,7 +404,9 @@ def test_on_cancellation(self):
# Once capacity is freed up, the delayed execution is published as requested again.
expected_num_exec += 1 # The delayed request is expected to be executed.
expected_num_pubs += 2 # Tally scheduled and running state.
- self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count)
+ self.assertEqual(
+ expected_num_pubs, LiveActionPublisher.publish_state.call_count
+ )
self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count)
# Since states are being processed asynchronously, wait for the
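Two black behaviors recur throughout this file: a literal that fits within the default 88-column limit is collapsed onto one line, while a pre-existing trailing comma (the "magic trailing comma") forces a literal to stay exploded. A small illustration, assuming black's defaults:

    fits = {"actions": ["action1.yaml"], "policies": ["policy_3.yaml"]}  # collapsed
    kept_exploded = [
        "policy_3.yaml",
        "policy_7.yaml",  # trailing comma: one element per line is preserved
    ]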
diff --git a/st2actions/tests/unit/policies/test_retry_policy.py b/st2actions/tests/unit/policies/test_retry_policy.py
index 6b6f0f0cc4..21371c6a02 100644
--- a/st2actions/tests/unit/policies/test_retry_policy.py
+++ b/st2actions/tests/unit/policies/test_retry_policy.py
@@ -35,19 +35,10 @@
from st2tests.base import CleanDbTestCase
from st2tests.fixturesloader import FixturesLoader
-__all__ = [
- 'RetryPolicyTestCase'
-]
+__all__ = ["RetryPolicyTestCase"]
-PACK = 'generic'
-TEST_FIXTURES = {
- 'actions': [
- 'action1.yaml'
- ],
- 'policies': [
- 'policy_4.yaml'
- ]
-}
+PACK = "generic"
+TEST_FIXTURES = {"actions": ["action1.yaml"], "policies": ["policy_4.yaml"]}
class RetryPolicyTestCase(CleanDbTestCase):
@@ -66,18 +57,21 @@ def setUp(self):
register_policy_types(st2actions)
loader = FixturesLoader()
- models = loader.save_fixtures_to_db(fixtures_pack=PACK,
- fixtures_dict=TEST_FIXTURES)
+ models = loader.save_fixtures_to_db(
+ fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES
+ )
# Instantiate policy applicator we will use in the tests
- policy_db = models['policies']['policy_4.yaml']
- retry_on = policy_db.parameters['retry_on']
- max_retry_count = policy_db.parameters['max_retry_count']
- self.policy = ExecutionRetryPolicyApplicator(policy_ref='test_policy',
- policy_type='action.retry',
- retry_on=retry_on,
- max_retry_count=max_retry_count,
- delay=0)
+ policy_db = models["policies"]["policy_4.yaml"]
+ retry_on = policy_db.parameters["retry_on"]
+ max_retry_count = policy_db.parameters["max_retry_count"]
+ self.policy = ExecutionRetryPolicyApplicator(
+ policy_ref="test_policy",
+ policy_type="action.retry",
+ retry_on=retry_on,
+ max_retry_count=max_retry_count,
+ delay=0,
+ )
def test_retry_on_timeout_no_retry_since_no_timeout_reached(self):
# Verify initial state
@@ -85,7 +79,9 @@ def test_retry_on_timeout_no_retry_since_no_timeout_reached(self):
self.assertSequenceEqual(ActionExecution.get_all(), [])
# Start a mock action which succeeds
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
live_action_db.status = LIVEACTION_STATUS_SUCCEEDED
@@ -110,7 +106,9 @@ def test_retry_on_timeout_first_retry_is_successful(self):
self.assertSequenceEqual(ActionExecution.get_all(), [])
# Start a mock action which times out
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
live_action_db.status = LIVEACTION_STATUS_TIMED_OUT
@@ -130,14 +128,16 @@ def test_retry_on_timeout_first_retry_is_successful(self):
self.assertEqual(action_execution_dbs[1].status, LIVEACTION_STATUS_REQUESTED)
# Verify retried execution contains policy related context
- original_liveaction_id = action_execution_dbs[0].liveaction['id']
+ original_liveaction_id = action_execution_dbs[0].liveaction["id"]
context = action_execution_dbs[1].context
- self.assertIn('policies', context)
- self.assertEqual(context['policies']['retry']['retry_count'], 1)
- self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy')
- self.assertEqual(context['policies']['retry']['retried_liveaction_id'],
- original_liveaction_id)
+ self.assertIn("policies", context)
+ self.assertEqual(context["policies"]["retry"]["retry_count"], 1)
+ self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy")
+ self.assertEqual(
+ context["policies"]["retry"]["retried_liveaction_id"],
+ original_liveaction_id,
+ )
# Simulate success of the second action so it shouldn't be retried anymore
live_action_db = live_action_dbs[1]
@@ -161,7 +161,9 @@ def test_retry_on_timeout_policy_is_retried_twice(self):
self.assertSequenceEqual(ActionExecution.get_all(), [])
# Start a mock action which times out
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
live_action_db.status = LIVEACTION_STATUS_TIMED_OUT
@@ -181,14 +183,16 @@ def test_retry_on_timeout_policy_is_retried_twice(self):
self.assertEqual(action_execution_dbs[1].status, LIVEACTION_STATUS_REQUESTED)
# Verify retried execution contains policy related context
- original_liveaction_id = action_execution_dbs[0].liveaction['id']
+ original_liveaction_id = action_execution_dbs[0].liveaction["id"]
context = action_execution_dbs[1].context
- self.assertIn('policies', context)
- self.assertEqual(context['policies']['retry']['retry_count'], 1)
- self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy')
- self.assertEqual(context['policies']['retry']['retried_liveaction_id'],
- original_liveaction_id)
+ self.assertIn("policies", context)
+ self.assertEqual(context["policies"]["retry"]["retry_count"], 1)
+ self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy")
+ self.assertEqual(
+ context["policies"]["retry"]["retried_liveaction_id"],
+ original_liveaction_id,
+ )
# Simulate timeout of second action which should cause another retry
live_action_db = live_action_dbs[1]
@@ -212,14 +216,16 @@ def test_retry_on_timeout_policy_is_retried_twice(self):
self.assertEqual(action_execution_dbs[2].status, LIVEACTION_STATUS_REQUESTED)
# Verify retried execution contains policy related context
- original_liveaction_id = action_execution_dbs[1].liveaction['id']
+ original_liveaction_id = action_execution_dbs[1].liveaction["id"]
context = action_execution_dbs[2].context
- self.assertIn('policies', context)
- self.assertEqual(context['policies']['retry']['retry_count'], 2)
- self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy')
- self.assertEqual(context['policies']['retry']['retried_liveaction_id'],
- original_liveaction_id)
+ self.assertIn("policies", context)
+ self.assertEqual(context["policies"]["retry"]["retry_count"], 2)
+ self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy")
+ self.assertEqual(
+ context["policies"]["retry"]["retried_liveaction_id"],
+ original_liveaction_id,
+ )
def test_retry_on_timeout_max_retries_reached(self):
# Verify initial state
@@ -227,12 +233,14 @@ def test_retry_on_timeout_max_retries_reached(self):
self.assertSequenceEqual(ActionExecution.get_all(), [])
# Start a mock action which times out
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
live_action_db.status = LIVEACTION_STATUS_TIMED_OUT
- live_action_db.context['policies'] = {}
- live_action_db.context['policies']['retry'] = {'retry_count': 2}
+ live_action_db.context["policies"] = {}
+ live_action_db.context["policies"]["retry"] = {"retry_count": 2}
execution_db.status = LIVEACTION_STATUS_TIMED_OUT
LiveAction.add_or_update(live_action_db)
ActionExecution.add_or_update(execution_db)
@@ -248,8 +256,10 @@ def test_retry_on_timeout_max_retries_reached(self):
self.assertEqual(action_execution_dbs[0].status, LIVEACTION_STATUS_TIMED_OUT)
@mock.patch.object(
- trace_service, 'get_trace_db_by_live_action',
- mock.MagicMock(return_value=(None, None)))
+ trace_service,
+ "get_trace_db_by_live_action",
+ mock.MagicMock(return_value=(None, None)),
+ )
def test_no_retry_on_workflow_task(self):
# Verify initial state
self.assertSequenceEqual(LiveAction.get_all(), [])
@@ -257,9 +267,9 @@ def test_no_retry_on_workflow_task(self):
# Start a mock action which times out
live_action_db = LiveActionDB(
- action='wolfpack.action-1',
- parameters={'actionstr': 'foo'},
- context={'parent': {'execution_id': 'abcde'}}
+ action="wolfpack.action-1",
+ parameters={"actionstr": "foo"},
+ context={"parent": {"execution_id": "abcde"}},
)
live_action_db, execution_db = action_service.request(live_action_db)
@@ -268,7 +278,7 @@ def test_no_retry_on_workflow_task(self):
# Expire the workflow instance.
live_action_db.status = LIVEACTION_STATUS_TIMED_OUT
- live_action_db.context['policies'] = {}
+ live_action_db.context["policies"] = {}
execution_db.status = LIVEACTION_STATUS_TIMED_OUT
LiveAction.add_or_update(live_action_db)
ActionExecution.add_or_update(execution_db)
@@ -297,10 +307,12 @@ def test_no_retry_on_non_applicable_statuses(self):
LIVEACTION_STATUS_CANCELED,
]
- action_ref = 'wolfpack.action-1'
+ action_ref = "wolfpack.action-1"
for status in non_retry_statuses:
- liveaction = LiveActionDB(action=action_ref, parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action=action_ref, parameters={"actionstr": "foo"}
+ )
live_action_db, execution_db = action_service.request(liveaction)
live_action_db.status = status
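The assertions in this file pin down the shape of the retry context that a retried execution carries. Reconstructed from those assertions, with an illustrative placeholder id:

    original_liveaction_id = "000000000000000000000000"  # placeholder ObjectId string

    context = {
        "policies": {
            "retry": {
                "retry_count": 1,                 # bumped on every retry
                "applied_policy": "test_policy",  # the policy_ref given to the applicator
                "retried_liveaction_id": original_liveaction_id,
            }
        }
    }
    assert context["policies"]["retry"]["retry_count"] == 1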
diff --git a/st2actions/tests/unit/test_action_runner_worker.py b/st2actions/tests/unit/test_action_runner_worker.py
index 4f2494a431..1d0c7bbbd0 100644
--- a/st2actions/tests/unit/test_action_runner_worker.py
+++ b/st2actions/tests/unit/test_action_runner_worker.py
@@ -21,11 +21,10 @@
from st2common.models.db.liveaction import LiveActionDB
from st2tests import config as test_config
+
test_config.parse_args()
-__all__ = [
- 'ActionsQueueConsumerTestCase'
-]
+__all__ = ["ActionsQueueConsumerTestCase"]
class ActionsQueueConsumerTestCase(TestCase):
@@ -38,7 +37,9 @@ def test_process_right_dispatcher_is_used(self):
consumer._workflows_dispatcher = Mock()
consumer._actions_dispatcher = Mock()
- body = LiveActionDB(status='scheduled', action='core.local', action_is_workflow=False)
+ body = LiveActionDB(
+ status="scheduled", action="core.local", action_is_workflow=False
+ )
message = Mock()
consumer.process(body=body, message=message)
@@ -49,7 +50,9 @@ def test_process_right_dispatcher_is_used(self):
consumer._workflows_dispatcher = Mock()
consumer._actions_dispatcher = Mock()
- body = LiveActionDB(status='scheduled', action='core.local', action_is_workflow=True)
+ body = LiveActionDB(
+ status="scheduled", action="core.local", action_is_workflow=True
+ )
message = Mock()
consumer.process(body=body, message=message)
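The behavior under test is a two-way dispatch keyed on `action_is_workflow`. A stand-alone sketch with `Mock` stand-ins (the attribute names mirror the test, but `consumer` here is not the st2 class):

    from unittest.mock import Mock

    consumer = Mock()
    body = Mock(action_is_workflow=True)

    dispatcher = (
        consumer._workflows_dispatcher
        if body.action_is_workflow
        else consumer._actions_dispatcher
    )
    dispatcher.dispatch(body)

    assert consumer._workflows_dispatcher.dispatch.called
    assert not consumer._actions_dispatcher.dispatch.called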
diff --git a/st2actions/tests/unit/test_actions_registrar.py b/st2actions/tests/unit/test_actions_registrar.py
index c4d2771268..cc9da33299 100644
--- a/st2actions/tests/unit/test_actions_registrar.py
+++ b/st2actions/tests/unit/test_actions_registrar.py
@@ -31,18 +31,24 @@
import st2tests.fixturesloader as fixtures_loader
from st2tests.fixturesloader import get_fixtures_base_path
-MOCK_RUNNER_TYPE_DB = RunnerTypeDB(name='run-local', runner_module='st2.runners.local')
+MOCK_RUNNER_TYPE_DB = RunnerTypeDB(name="run-local", runner_module="st2.runners.local")
# NOTE: We need to perform this patching because test fixtures are located outside of the packs
# base paths directory. This will never happen outside the context of test fixtures.
-@mock.patch('st2common.content.utils.get_pack_base_path',
- mock.Mock(return_value=os.path.join(get_fixtures_base_path(), 'generic')))
+@mock.patch(
+ "st2common.content.utils.get_pack_base_path",
+ mock.Mock(return_value=os.path.join(get_fixtures_base_path(), "generic")),
+)
class ActionsRegistrarTest(tests_base.DbTestCase):
-
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True))
- @mock.patch.object(action_validator, 'get_runner_model',
- mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ action_validator,
+ "get_runner_model",
+ mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB),
+ )
def test_register_all_actions(self):
try:
packs_base_path = fixtures_loader.get_fixtures_base_path()
@@ -50,111 +56,157 @@ def test_register_all_actions(self):
actions_registrar.register_actions(packs_base_paths=[packs_base_path])
except Exception as e:
print(six.text_type(e))
- self.fail('All actions must be registered without exceptions.')
+ self.fail("All actions must be registered without exceptions.")
else:
all_actions_in_db = Action.get_all()
self.assertTrue(len(all_actions_in_db) > 0)
# Assert metadata_file field is populated
- expected_path = 'actions/action-with-no-parameters.yaml'
+ expected_path = "actions/action-with-no-parameters.yaml"
self.assertEqual(all_actions_in_db[0].metadata_file, expected_path)
def test_register_actions_from_bad_pack(self):
packs_base_path = tests_base.get_fixtures_path()
try:
actions_registrar.register_actions(packs_base_paths=[packs_base_path])
- self.fail('Should have thrown.')
+ self.fail("Should have thrown.")
except:
pass
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True))
- @mock.patch.object(action_validator, 'get_runner_model',
- mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ action_validator,
+ "get_runner_model",
+ mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB),
+ )
def test_pack_name_missing(self):
registrar = actions_registrar.ActionsRegistrar()
loader = fixtures_loader.FixturesLoader()
action_file = loader.get_fixture_file_path_abs(
- 'generic', 'actions', 'action_3_pack_missing.yaml')
- registrar._register_action('dummy', action_file)
+ "generic", "actions", "action_3_pack_missing.yaml"
+ )
+ registrar._register_action("dummy", action_file)
action_name = None
- with open(action_file, 'r') as fd:
+ with open(action_file, "r") as fd:
content = yaml.safe_load(fd)
- action_name = str(content['name'])
+ action_name = str(content["name"])
action_db = Action.get_by_name(action_name)
- expected_msg = 'Content pack must be set to dummy'
- self.assertEqual(action_db.pack, 'dummy', expected_msg)
+ expected_msg = "Content pack must be set to dummy"
+ self.assertEqual(action_db.pack, "dummy", expected_msg)
Action.delete(action_db)
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True))
- @mock.patch.object(action_validator, 'get_runner_model',
- mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ action_validator,
+ "get_runner_model",
+ mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB),
+ )
def test_register_action_with_no_params(self):
registrar = actions_registrar.ActionsRegistrar()
loader = fixtures_loader.FixturesLoader()
action_file = loader.get_fixture_file_path_abs(
- 'generic', 'actions', 'action-with-no-parameters.yaml')
-
- self.assertEqual(registrar._register_action('dummy', action_file), None)
-
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True))
- @mock.patch.object(action_validator, 'get_runner_model',
- mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB))
+ "generic", "actions", "action-with-no-parameters.yaml"
+ )
+
+ self.assertEqual(registrar._register_action("dummy", action_file), None)
+
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ action_validator,
+ "get_runner_model",
+ mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB),
+ )
def test_register_action_invalid_parameter_type_attribute(self):
registrar = actions_registrar.ActionsRegistrar()
loader = fixtures_loader.FixturesLoader()
action_file = loader.get_fixture_file_path_abs(
- 'generic', 'actions', 'action_invalid_param_type.yaml')
-
- expected_msg = '\'list\' is not valid under any of the given schema'
- self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg,
- registrar._register_action,
- 'dummy', action_file)
-
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True))
- @mock.patch.object(action_validator, 'get_runner_model',
- mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB))
+ "generic", "actions", "action_invalid_param_type.yaml"
+ )
+
+ expected_msg = "'list' is not valid under any of the given schema"
+ self.assertRaisesRegexp(
+ jsonschema.ValidationError,
+ expected_msg,
+ registrar._register_action,
+ "dummy",
+ action_file,
+ )
+
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ action_validator,
+ "get_runner_model",
+ mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB),
+ )
def test_register_action_invalid_parameter_name(self):
registrar = actions_registrar.ActionsRegistrar()
loader = fixtures_loader.FixturesLoader()
action_file = loader.get_fixture_file_path_abs(
- 'generic', 'actions', 'action_invalid_parameter_name.yaml')
-
- expected_msg = ('Parameter name "action-name" is invalid. Valid characters for '
- 'parameter name are')
- self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg,
- registrar._register_action,
- 'generic', action_file)
-
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True))
- @mock.patch.object(action_validator, 'get_runner_model',
- mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB))
+ "generic", "actions", "action_invalid_parameter_name.yaml"
+ )
+
+ expected_msg = (
+ 'Parameter name "action-name" is invalid. Valid characters for '
+ "parameter name are"
+ )
+ self.assertRaisesRegexp(
+ jsonschema.ValidationError,
+ expected_msg,
+ registrar._register_action,
+ "generic",
+ action_file,
+ )
+
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ action_validator,
+ "get_runner_model",
+ mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB),
+ )
def test_invalid_params_schema(self):
registrar = actions_registrar.ActionsRegistrar()
loader = fixtures_loader.FixturesLoader()
action_file = loader.get_fixture_file_path_abs(
- 'generic', 'actions', 'action-invalid-schema-params.yaml')
+ "generic", "actions", "action-invalid-schema-params.yaml"
+ )
try:
- registrar._register_action('generic', action_file)
- self.fail('Invalid action schema. Should have failed.')
+ registrar._register_action("generic", action_file)
+ self.fail("Invalid action schema. Should have failed.")
except jsonschema.ValidationError:
pass
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True))
- @mock.patch.object(action_validator, 'get_runner_model',
- mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ action_validator,
+ "get_runner_model",
+ mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB),
+ )
def test_action_update(self):
registrar = actions_registrar.ActionsRegistrar()
loader = fixtures_loader.FixturesLoader()
action_file = loader.get_fixture_file_path_abs(
- 'generic', 'actions', 'action1.yaml')
- registrar._register_action('wolfpack', action_file)
+ "generic", "actions", "action1.yaml"
+ )
+ registrar._register_action("wolfpack", action_file)
# Try registering again. This should not throw errors.
- registrar._register_action('wolfpack', action_file)
+ registrar._register_action("wolfpack", action_file)
action_name = None
- with open(action_file, 'r') as fd:
+ with open(action_file, "r") as fd:
content = yaml.safe_load(fd)
- action_name = str(content['name'])
+ action_name = str(content["name"])
action_db = Action.get_by_name(action_name)
- expected_msg = 'Content pack must be set to wolfpack'
- self.assertEqual(action_db.pack, 'wolfpack', expected_msg)
+ expected_msg = "Content pack must be set to wolfpack"
+ self.assertEqual(action_db.pack, "wolfpack", expected_msg)
Action.delete(action_db)
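One idiom worth noting from this file: the tests read the registered action's name back out of the YAML fixture with PyYAML's `safe_load`. A minimal sketch (the path is hypothetical; the tests resolve it via `FixturesLoader` instead):

    import yaml

    with open("actions/action1.yaml", "r") as fd:
        content = yaml.safe_load(fd)
    action_name = str(content["name"])

Separately, the `assertRaisesRegexp` seen above is the long-deprecated alias of `assertRaisesRegex`, presumably retained here for Python 2 compatibility.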
diff --git a/st2actions/tests/unit/test_async_runner.py b/st2actions/tests/unit/test_async_runner.py
index 0409202903..31258fae4e 100644
--- a/st2actions/tests/unit/test_async_runner.py
+++ b/st2actions/tests/unit/test_async_runner.py
@@ -14,15 +14,16 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except:
import json
from st2common.runners.base import AsyncActionRunner
-from st2common.constants.action import (LIVEACTION_STATUS_RUNNING)
+from st2common.constants.action import LIVEACTION_STATUS_RUNNING
-RAISE_PROPERTY = 'raise'
+RAISE_PROPERTY = "raise"
def get_runner():
@@ -31,7 +32,7 @@ def get_runner():
class AsyncTestRunner(AsyncActionRunner):
def __init__(self):
- super(AsyncTestRunner, self).__init__(runner_id='1')
+ super(AsyncTestRunner, self).__init__(runner_id="1")
self.pre_run_called = False
self.run_called = False
self.post_run_called = False
@@ -43,14 +44,11 @@ def run(self, action_params):
self.run_called = True
result = {}
if self.runner_parameters.get(RAISE_PROPERTY, False):
- raise Exception('Raise required.')
+ raise Exception("Raise required.")
else:
- result = {
- 'ran': True,
- 'action_params': action_params
- }
+ result = {"ran": True, "action_params": action_params}
- return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'})
+ return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"})
def post_run(self, status, result):
self.post_run_called = True
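
For context on what the reformatted AsyncTestRunner exercises: an async runner's run() hands back a (status, result, context) triple, with the result JSON-serialized. A standalone sketch of that contract, outside StackStorm (the status string below is an assumed stand-in for the real constant imported from st2common.constants.action):

import json

LIVEACTION_STATUS_RUNNING = "running"  # assumed literal, not the st2 import

def run(action_params, raise_error=False):
    # Mirrors AsyncTestRunner.run() above: either raise on demand, or return
    # the running status, a JSON-encoded result, and a context dict.
    if raise_error:
        raise Exception("Raise required.")
    result = {"ran": True, "action_params": action_params}
    return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"})

status, result, context = run({"cmd": "date"})
assert status == "running" and json.loads(result)["ran"] is True
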
diff --git a/st2actions/tests/unit/test_execution_cancellation.py b/st2actions/tests/unit/test_execution_cancellation.py
index 6a130e2fe7..e6c51159ef 100644
--- a/st2actions/tests/unit/test_execution_cancellation.py
+++ b/st2actions/tests/unit/test_execution_cancellation.py
@@ -22,6 +22,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2common.constants import action as action_constants
@@ -42,35 +43,32 @@
from st2tests.mocks.liveaction import MockLiveActionPublisherNonBlocking
from st2tests.mocks.runners import runner
-__all__ = [
- 'ExecutionCancellationTestCase'
-]
+__all__ = ["ExecutionCancellationTestCase"]
-TEST_FIXTURES = {
- 'actions': [
- 'action1.yaml'
- ]
-}
+TEST_FIXTURES = {"actions": ["action1.yaml"]}
-PACK = 'generic'
+PACK = "generic"
LOADER = FixturesLoader()
FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
-@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner()))
-@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
-@mock.patch.object(
- CUDPublisher, 'publish_update',
- mock.MagicMock(side_effect=MockExecutionPublisher.publish_update))
+@mock.patch(
+ "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
+@mock.patch(
+ "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(return_value=None))
+ CUDPublisher,
+ "publish_update",
+ mock.MagicMock(side_effect=MockExecutionPublisher.publish_update),
+)
+@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None))
class ExecutionCancellationTestCase(ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(ExecutionCancellationTestCase, cls).setUpClass()
- for _, fixture in six.iteritems(FIXTURES['actions']):
+ for _, fixture in six.iteritems(FIXTURES["actions"]):
instance = ActionAPI(**fixture)
Action.add_or_update(ActionAPI.to_model(instance))
@@ -80,62 +78,84 @@ def tearDown(self):
# Ensure all liveactions are canceled at end of each test.
for liveaction in LiveAction.get_all():
action_service.update_status(
- liveaction, action_constants.LIVEACTION_STATUS_CANCELED)
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
+ )
@classmethod
def get_runner_class(cls, runner_name):
return runners.get_runner(runner_name).__class__
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state))
- @mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner()))
- @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state),
+ )
+ @mock.patch(
+ "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner())
+ )
+ @mock.patch(
+ "st2actions.container.base.get_runner",
+ mock.Mock(return_value=runner.get_runner()),
+ )
def test_basic_cancel(self):
- runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, 'foobar', None)
+ runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, "foobar", None)
mock_runner_run = mock.Mock(return_value=runner_run_result)
- with mock.patch.object(runner.MockActionRunner, 'run', mock_runner_run):
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ with mock.patch.object(runner.MockActionRunner, "run", mock_runner_run):
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
liveaction = self._wait_on_status(
- liveaction,
- action_constants.LIVEACTION_STATUS_RUNNING
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
)
# Cancel execution.
action_service.request_cancellation(liveaction, cfg.CONF.system_user.user)
liveaction = self._wait_on_status(
- liveaction,
- action_constants.LIVEACTION_STATUS_CANCELED
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
)
@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create))
+ CUDPublisher,
+ "publish_create",
+ mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create),
+ )
@mock.patch.object(
- CUDPublisher, 'publish_update',
- mock.MagicMock(side_effect=MockExecutionPublisher.publish_update))
+ CUDPublisher,
+ "publish_update",
+ mock.MagicMock(side_effect=MockExecutionPublisher.publish_update),
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state),
+ )
@mock.patch.object(
- runners.ActionRunner, 'cancel',
- mock.MagicMock(side_effect=Exception('Mock cancellation failure.')))
- @mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner()))
- @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
+ runners.ActionRunner,
+ "cancel",
+ mock.MagicMock(side_effect=Exception("Mock cancellation failure.")),
+ )
+ @mock.patch(
+ "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner())
+ )
+ @mock.patch(
+ "st2actions.container.base.get_runner",
+ mock.Mock(return_value=runner.get_runner()),
+ )
def test_failed_cancel(self):
- runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, 'foobar', None)
+ runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, "foobar", None)
mock_runner_run = mock.Mock(return_value=runner_run_result)
- with mock.patch.object(runner.MockActionRunner, 'run', mock_runner_run):
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ with mock.patch.object(runner.MockActionRunner, "run", mock_runner_run):
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
liveaction = self._wait_on_status(
- liveaction,
- action_constants.LIVEACTION_STATUS_RUNNING
+ liveaction, action_constants.LIVEACTION_STATUS_RUNNING
)
# Cancel execution.
@@ -144,22 +164,28 @@ def test_failed_cancel(self):
# Cancellation failed and execution state remains "canceling".
runners.ActionRunner.cancel.assert_called_once_with()
liveaction = LiveAction.get_by_id(str(liveaction.id))
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING
+ )
@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(return_value=None))
+ CUDPublisher, "publish_create", mock.MagicMock(return_value=None)
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(return_value=None))
+ LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None)
+ )
@mock.patch.object(
- runners.ActionRunner, 'cancel',
- mock.MagicMock(return_value=None))
+ runners.ActionRunner, "cancel", mock.MagicMock(return_value=None)
+ )
def test_noop_cancel(self):
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED
+ )
# Cancel execution.
action_service.request_cancellation(liveaction, cfg.CONF.system_user.user)
@@ -171,22 +197,28 @@ def test_noop_cancel(self):
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED)
@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(return_value=None))
+ CUDPublisher, "publish_create", mock.MagicMock(return_value=None)
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(return_value=None))
+ LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None)
+ )
@mock.patch.object(
- runners.ActionRunner, 'cancel',
- mock.MagicMock(return_value=None))
+ runners.ActionRunner, "cancel", mock.MagicMock(return_value=None)
+ )
def test_cancel_delayed_execution(self):
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED
+ )
# Manually update the liveaction from requested to delayed to mock concurrency policy.
- action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED)
+ action_service.update_status(
+ liveaction, action_constants.LIVEACTION_STATUS_DELAYED
+ )
liveaction = LiveAction.get_by_id(str(liveaction.id))
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_DELAYED)
@@ -200,27 +232,33 @@ def test_cancel_delayed_execution(self):
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED)
@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(return_value=None))
+ CUDPublisher, "publish_create", mock.MagicMock(return_value=None)
+ )
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(return_value=None))
+ LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None)
+ )
@mock.patch.object(
- trace_service, 'get_trace_db_by_live_action',
- mock.MagicMock(return_value=(None, None)))
+ trace_service,
+ "get_trace_db_by_live_action",
+ mock.MagicMock(return_value=(None, None)),
+ )
def test_cancel_delayed_execution_with_parent(self):
liveaction = LiveActionDB(
- action='wolfpack.action-1',
- parameters={'actionstr': 'foo'},
- context={'parent': {'execution_id': uuid.uuid4().hex}}
+ action="wolfpack.action-1",
+ parameters={"actionstr": "foo"},
+ context={"parent": {"execution_id": uuid.uuid4().hex}},
)
liveaction, _ = action_service.request(liveaction)
liveaction = LiveAction.get_by_id(str(liveaction.id))
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED
+ )
# Manually update the liveaction from requested to delayed to mock concurrency policy.
- action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED)
+ action_service.update_status(
+ liveaction, action_constants.LIVEACTION_STATUS_DELAYED
+ )
liveaction = LiveAction.get_by_id(str(liveaction.id))
self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_DELAYED)
@@ -230,4 +268,6 @@ def test_cancel_delayed_execution_with_parent(self):
# Cancel is only called when liveaction is still in running state.
# Otherwise, the cancellation is only a state change.
liveaction = LiveAction.get_by_id(str(liveaction.id))
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING
+ )
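
The dominant change in this file is the layout of stacked @mock.patch.object decorators: once a decorator call no longer fits on one line, black places the patch target, attribute name, and replacement mock on separate lines with a trailing comma. A self-contained illustration using plain unittest.mock (the Publisher class here is hypothetical, not st2's CUDPublisher):

from unittest import mock


class Publisher:
    @staticmethod
    def publish_state(state):
        raise RuntimeError("should be patched out in tests")


@mock.patch.object(
    Publisher,
    "publish_state",
    mock.MagicMock(return_value=None),
)
def check_noop_publish():
    # While the patch is active, the attribute is the MagicMock, so the call
    # is recorded instead of raising.
    Publisher.publish_state("canceled")
    Publisher.publish_state.assert_called_once_with("canceled")


check_noop_publish()
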
diff --git a/st2actions/tests/unit/test_executions.py b/st2actions/tests/unit/test_executions.py
index f143631e42..64bde6b654 100644
--- a/st2actions/tests/unit/test_executions.py
+++ b/st2actions/tests/unit/test_executions.py
@@ -20,6 +20,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
import st2common.bootstrap.runnersregistrar as runners_registrar
@@ -53,47 +54,57 @@
@mock.patch.object(
- LocalShellCommandRunner, 'run',
- mock.MagicMock(return_value=(action_constants.LIVEACTION_STATUS_FAILED, 'Non-empty', None)))
+ LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(
+ return_value=(action_constants.LIVEACTION_STATUS_FAILED, "Non-empty", None)
+ ),
+)
@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create))
+ CUDPublisher,
+ "publish_create",
+ mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state),
+)
class TestActionExecutionHistoryWorker(ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(TestActionExecutionHistoryWorker, cls).setUpClass()
runners_registrar.register_runners()
- action_local = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS['actions']['local']))
+ action_local = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]))
Action.add_or_update(ActionAPI.to_model(action_local))
- action_chain = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS['actions']['chain']))
- action_chain.entry_point = fixture.PATH + '/chain.yaml'
+ action_chain = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"]))
+ action_chain.entry_point = fixture.PATH + "/chain.yaml"
Action.add_or_update(ActionAPI.to_model(action_chain))
def tearDown(self):
- MOCK_FAIL_EXECUTION_CREATE = False # noqa
+ MOCK_FAIL_EXECUTION_CREATE = False # noqa
super(TestActionExecutionHistoryWorker, self).tearDown()
def test_basic_execution(self):
- liveaction = LiveActionDB(action='executions.local', parameters={'cmd': 'uname -a'})
+ liveaction = LiveActionDB(
+ action="executions.local", parameters={"cmd": "uname -a"}
+ )
liveaction, _ = action_service.request(liveaction)
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_FAILED
+ )
execution = self._get_action_execution(
- liveaction__id=str(liveaction.id),
- raise_exception=True
+ liveaction__id=str(liveaction.id), raise_exception=True
)
self.assertDictEqual(execution.trigger, {})
self.assertDictEqual(execution.trigger_type, {})
self.assertDictEqual(execution.trigger_instance, {})
self.assertDictEqual(execution.rule, {})
- action = action_utils.get_action_by_ref('executions.local')
+ action = action_utils.get_action_by_ref("executions.local")
self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action)))
- runner = RunnerType.get_by_name(action.runner_type['name'])
+ runner = RunnerType.get_by_name(action.runner_type["name"])
self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner)))
liveaction = LiveAction.get_by_id(str(liveaction.id))
self.assertEqual(execution.start_timestamp, liveaction.start_timestamp)
@@ -101,26 +112,27 @@ def test_basic_execution(self):
self.assertEqual(execution.result, liveaction.result)
self.assertEqual(execution.status, liveaction.status)
self.assertEqual(execution.context, liveaction.context)
- self.assertEqual(execution.liveaction['callback'], liveaction.callback)
- self.assertEqual(execution.liveaction['action'], liveaction.action)
+ self.assertEqual(execution.liveaction["callback"], liveaction.callback)
+ self.assertEqual(execution.liveaction["action"], liveaction.action)
def test_basic_execution_history_create_failed(self):
- MOCK_FAIL_EXECUTION_CREATE = True # noqa
+ MOCK_FAIL_EXECUTION_CREATE = True # noqa
self.test_basic_execution()
def test_chained_executions(self):
- liveaction = LiveActionDB(action='executions.chain')
+ liveaction = LiveActionDB(action="executions.chain")
liveaction, _ = action_service.request(liveaction)
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_FAILED
+ )
execution = self._get_action_execution(
- liveaction__id=str(liveaction.id),
- raise_exception=True
+ liveaction__id=str(liveaction.id), raise_exception=True
)
- action = action_utils.get_action_by_ref('executions.chain')
+ action = action_utils.get_action_by_ref("executions.chain")
self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action)))
- runner = RunnerType.get_by_name(action.runner_type['name'])
+ runner = RunnerType.get_by_name(action.runner_type["name"])
self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner)))
liveaction = LiveAction.get_by_id(str(liveaction.id))
self.assertEqual(execution.start_timestamp, liveaction.start_timestamp)
@@ -128,56 +140,69 @@ def test_chained_executions(self):
self.assertEqual(execution.result, liveaction.result)
self.assertEqual(execution.status, liveaction.status)
self.assertEqual(execution.context, liveaction.context)
- self.assertEqual(execution.liveaction['callback'], liveaction.callback)
- self.assertEqual(execution.liveaction['action'], liveaction.action)
+ self.assertEqual(execution.liveaction["callback"], liveaction.callback)
+ self.assertEqual(execution.liveaction["action"], liveaction.action)
self.assertGreater(len(execution.children), 0)
for child in execution.children:
record = ActionExecution.get(id=child, raise_exception=True)
self.assertEqual(record.parent, str(execution.id))
- self.assertEqual(record.action['name'], 'local')
- self.assertEqual(record.runner['name'], 'local-shell-cmd')
+ self.assertEqual(record.action["name"], "local")
+ self.assertEqual(record.runner["name"], "local-shell-cmd")
def test_triggered_execution(self):
docs = {
- 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']),
- 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']),
- 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']),
- 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance'])}
+ "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]),
+ "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]),
+ "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]),
+ "trigger_instance": copy.deepcopy(fixture.ARTIFACTS["trigger_instance"]),
+ }
# Trigger an action execution.
trigger_type = TriggerType.add_or_update(
- TriggerTypeAPI.to_model(TriggerTypeAPI(**docs['trigger_type'])))
- trigger = Trigger.add_or_update(TriggerAPI.to_model(TriggerAPI(**docs['trigger'])))
- rule = RuleAPI.to_model(RuleAPI(**docs['rule']))
+ TriggerTypeAPI.to_model(TriggerTypeAPI(**docs["trigger_type"]))
+ )
+ trigger = Trigger.add_or_update(
+ TriggerAPI.to_model(TriggerAPI(**docs["trigger"]))
+ )
+ rule = RuleAPI.to_model(RuleAPI(**docs["rule"]))
rule.trigger = reference.get_str_resource_ref_from_model(trigger)
rule = Rule.add_or_update(rule)
trigger_instance = TriggerInstance.add_or_update(
- TriggerInstanceAPI.to_model(TriggerInstanceAPI(**docs['trigger_instance'])))
+ TriggerInstanceAPI.to_model(TriggerInstanceAPI(**docs["trigger_instance"]))
+ )
trace_service.add_or_update_given_trace_context(
- trace_context={'trace_tag': 'test_triggered_execution_trace'},
- trigger_instances=[str(trigger_instance.id)])
+ trace_context={"trace_tag": "test_triggered_execution_trace"},
+ trigger_instances=[str(trigger_instance.id)],
+ )
enforcer = RuleEnforcer(trigger_instance, rule)
enforcer.enforce()
# Wait for the action execution to complete and then confirm outcome.
- liveaction = LiveAction.get(context__trigger_instance__id=str(trigger_instance.id))
+ liveaction = LiveAction.get(
+ context__trigger_instance__id=str(trigger_instance.id)
+ )
self.assertIsNotNone(liveaction)
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_FAILED
+ )
execution = self._get_action_execution(
- liveaction__id=str(liveaction.id),
- raise_exception=True
+ liveaction__id=str(liveaction.id), raise_exception=True
)
self.assertDictEqual(execution.trigger, vars(TriggerAPI.from_model(trigger)))
- self.assertDictEqual(execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type)))
- self.assertDictEqual(execution.trigger_instance,
- vars(TriggerInstanceAPI.from_model(trigger_instance)))
+ self.assertDictEqual(
+ execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type))
+ )
+ self.assertDictEqual(
+ execution.trigger_instance,
+ vars(TriggerInstanceAPI.from_model(trigger_instance)),
+ )
self.assertDictEqual(execution.rule, vars(RuleAPI.from_model(rule)))
action = action_utils.get_action_by_ref(liveaction.action)
self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action)))
- runner = RunnerType.get_by_name(action.runner_type['name'])
+ runner = RunnerType.get_by_name(action.runner_type["name"])
self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner)))
liveaction = LiveAction.get_by_id(str(liveaction.id))
self.assertEqual(execution.start_timestamp, liveaction.start_timestamp)
@@ -185,8 +210,8 @@ def test_triggered_execution(self):
self.assertEqual(execution.result, liveaction.result)
self.assertEqual(execution.status, liveaction.status)
self.assertEqual(execution.context, liveaction.context)
- self.assertEqual(execution.liveaction['callback'], liveaction.callback)
- self.assertEqual(execution.liveaction['action'], liveaction.action)
+ self.assertEqual(execution.liveaction["callback"], liveaction.callback)
+ self.assertEqual(execution.liveaction["action"], liveaction.action)
def _get_action_execution(self, **kwargs):
return ActionExecution.get(**kwargs)
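
These tests repeatedly call self._wait_on_status(liveaction, <state>), which blocks until the live action reaches the expected state. A generic sketch of that polling pattern (names and timeouts are illustrative; this is not the st2tests implementation):

import time

def wait_on_status(get_obj, expected_status, timeout=5.0, interval=0.1):
    # Re-fetch the object until its status matches or the deadline passes.
    deadline = time.time() + timeout
    obj = get_obj()
    while obj.status != expected_status:
        if time.time() > deadline:
            raise AssertionError(
                f"status {obj.status!r} never became {expected_status!r}"
            )
        time.sleep(interval)
        obj = get_obj()
    return obj

class _Demo:
    status = "failed"

assert wait_on_status(lambda: _Demo(), "failed").status == "failed"
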
diff --git a/st2actions/tests/unit/test_notifier.py b/st2actions/tests/unit/test_notifier.py
index fa1af31ca8..b648d7fad3 100644
--- a/st2actions/tests/unit/test_notifier.py
+++ b/st2actions/tests/unit/test_notifier.py
@@ -20,6 +20,7 @@
import mock
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2actions.notifier.notifier import Notifier
@@ -41,77 +42,96 @@
from st2common.util import isotime
from st2tests.base import CleanDbTestCase
-ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][0]
-NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][1]
-MOCK_EXECUTION = ActionExecutionDB(id=bson.ObjectId(), result={'stdout': 'stuff happens'})
+ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][0]
+NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][1]
+MOCK_EXECUTION = ActionExecutionDB(
+ id=bson.ObjectId(), result={"stdout": "stuff happens"}
+)
class NotifierTestCase(CleanDbTestCase):
-
class MockDispatcher(object):
def __init__(self, tester):
self.tester = tester
self.notify_trigger = ResourceReference.to_string_reference(
- pack=NOTIFY_TRIGGER_TYPE['pack'],
- name=NOTIFY_TRIGGER_TYPE['name'])
+ pack=NOTIFY_TRIGGER_TYPE["pack"], name=NOTIFY_TRIGGER_TYPE["name"]
+ )
self.action_trigger = ResourceReference.to_string_reference(
- pack=ACTION_TRIGGER_TYPE['pack'],
- name=ACTION_TRIGGER_TYPE['name'])
+ pack=ACTION_TRIGGER_TYPE["pack"], name=ACTION_TRIGGER_TYPE["name"]
+ )
def dispatch(self, *args, **kwargs):
try:
self.tester.assertEqual(len(args), 1)
- self.tester.assertTrue('payload' in kwargs)
- payload = kwargs['payload']
+ self.tester.assertTrue("payload" in kwargs)
+ payload = kwargs["payload"]
if args[0] == self.notify_trigger:
- self.tester.assertEqual(payload['status'], 'succeeded')
- self.tester.assertTrue('execution_id' in payload)
- self.tester.assertEqual(payload['execution_id'], str(MOCK_EXECUTION.id))
- self.tester.assertTrue('start_timestamp' in payload)
- self.tester.assertTrue('end_timestamp' in payload)
- self.tester.assertEqual('core.local', payload['action_ref'])
- self.tester.assertEqual('Action succeeded.', payload['message'])
- self.tester.assertTrue('data' in payload)
- self.tester.assertTrue('local-shell-cmd', payload['runner_ref'])
+ self.tester.assertEqual(payload["status"], "succeeded")
+ self.tester.assertTrue("execution_id" in payload)
+ self.tester.assertEqual(
+ payload["execution_id"], str(MOCK_EXECUTION.id)
+ )
+ self.tester.assertTrue("start_timestamp" in payload)
+ self.tester.assertTrue("end_timestamp" in payload)
+ self.tester.assertEqual("core.local", payload["action_ref"])
+ self.tester.assertEqual("Action succeeded.", payload["message"])
+ self.tester.assertTrue("data" in payload)
+ self.tester.assertTrue("local-shell-cmd", payload["runner_ref"])
if args[0] == self.action_trigger:
- self.tester.assertEqual(payload['status'], 'succeeded')
- self.tester.assertTrue('execution_id' in payload)
- self.tester.assertEqual(payload['execution_id'], str(MOCK_EXECUTION.id))
- self.tester.assertTrue('start_timestamp' in payload)
- self.tester.assertEqual('core.local', payload['action_name'])
- self.tester.assertEqual('core.local', payload['action_ref'])
- self.tester.assertTrue('result' in payload)
- self.tester.assertTrue('parameters' in payload)
- self.tester.assertTrue('local-shell-cmd', payload['runner_ref'])
+ self.tester.assertEqual(payload["status"], "succeeded")
+ self.tester.assertTrue("execution_id" in payload)
+ self.tester.assertEqual(
+ payload["execution_id"], str(MOCK_EXECUTION.id)
+ )
+ self.tester.assertTrue("start_timestamp" in payload)
+ self.tester.assertEqual("core.local", payload["action_name"])
+ self.tester.assertEqual("core.local", payload["action_ref"])
+ self.tester.assertTrue("result" in payload)
+ self.tester.assertTrue("parameters" in payload)
+ self.tester.assertTrue("local-shell-cmd", payload["runner_ref"])
except Exception:
- self.tester.fail('Test failed')
-
- @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock(
- return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'},
- parameters={})))
- @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock(
- return_value=RunnerTypeDB(name='foo', runner_parameters={})))
- @mock.patch.object(Action, 'get_by_ref', mock.MagicMock(
- return_value={'runner_type': {'name': 'local-shell-cmd'}}))
- @mock.patch.object(Policy, 'query', mock.MagicMock(
- return_value=[]))
- @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={}))
+ self.tester.fail("Test failed")
+
+ @mock.patch(
+ "st2common.util.action_db.get_action_by_ref",
+ mock.MagicMock(
+ return_value=ActionDB(
+ pack="core",
+ name="local",
+ runner_type={"name": "local-shell-cmd"},
+ parameters={},
+ )
+ ),
+ )
+ @mock.patch(
+ "st2common.util.action_db.get_runnertype_by_name",
+ mock.MagicMock(return_value=RunnerTypeDB(name="foo", runner_parameters={})),
+ )
+ @mock.patch.object(
+ Action,
+ "get_by_ref",
+ mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}),
+ )
+ @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[]))
+ @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={}))
def test_notify_triggers(self):
- liveaction_db = LiveActionDB(action='core.local')
+ liveaction_db = LiveActionDB(action="core.local")
liveaction_db.id = bson.ObjectId()
- liveaction_db.description = ''
- liveaction_db.status = 'succeeded'
+ liveaction_db.description = ""
+ liveaction_db.status = "succeeded"
liveaction_db.parameters = {}
- on_success = NotificationSubSchema(message='Action succeeded.')
- on_failure = NotificationSubSchema(message='Action failed.')
- liveaction_db.notify = NotificationSchema(on_success=on_success,
- on_failure=on_failure)
+ on_success = NotificationSubSchema(message="Action succeeded.")
+ on_failure = NotificationSubSchema(message="Action failed.")
+ liveaction_db.notify = NotificationSchema(
+ on_success=on_success, on_failure=on_failure
+ )
liveaction_db.start_timestamp = date_utils.get_datetime_utc_now()
- liveaction_db.end_timestamp = \
- (liveaction_db.start_timestamp + datetime.timedelta(seconds=50))
+ liveaction_db.end_timestamp = (
+ liveaction_db.start_timestamp + datetime.timedelta(seconds=50)
+ )
LiveAction.add_or_update(liveaction_db)
execution = MOCK_EXECUTION
@@ -122,26 +142,39 @@ def test_notify_triggers(self):
notifier = Notifier(connection=None, queues=[], trigger_dispatcher=dispatcher)
notifier.process(execution)
- @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock(
- return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'},
- parameters={})))
- @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock(
- return_value=RunnerTypeDB(name='foo', runner_parameters={})))
- @mock.patch.object(Action, 'get_by_ref', mock.MagicMock(
- return_value={'runner_type': {'name': 'local-shell-cmd'}}))
- @mock.patch.object(Policy, 'query', mock.MagicMock(
- return_value=[]))
- @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={}))
+ @mock.patch(
+ "st2common.util.action_db.get_action_by_ref",
+ mock.MagicMock(
+ return_value=ActionDB(
+ pack="core",
+ name="local",
+ runner_type={"name": "local-shell-cmd"},
+ parameters={},
+ )
+ ),
+ )
+ @mock.patch(
+ "st2common.util.action_db.get_runnertype_by_name",
+ mock.MagicMock(return_value=RunnerTypeDB(name="foo", runner_parameters={})),
+ )
+ @mock.patch.object(
+ Action,
+ "get_by_ref",
+ mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}),
+ )
+ @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[]))
+ @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={}))
def test_notify_triggers_end_timestamp_none(self):
- liveaction_db = LiveActionDB(action='core.local')
+ liveaction_db = LiveActionDB(action="core.local")
liveaction_db.id = bson.ObjectId()
- liveaction_db.description = ''
- liveaction_db.status = 'succeeded'
+ liveaction_db.description = ""
+ liveaction_db.status = "succeeded"
liveaction_db.parameters = {}
- on_success = NotificationSubSchema(message='Action succeeded.')
- on_failure = NotificationSubSchema(message='Action failed.')
- liveaction_db.notify = NotificationSchema(on_success=on_success,
- on_failure=on_failure)
+ on_success = NotificationSubSchema(message="Action succeeded.")
+ on_failure = NotificationSubSchema(message="Action failed.")
+ liveaction_db.notify = NotificationSchema(
+ on_success=on_success, on_failure=on_failure
+ )
liveaction_db.start_timestamp = date_utils.get_datetime_utc_now()
# This tests for end_timestamp being set to None, which can happen when a policy cancels
@@ -159,30 +192,48 @@ def test_notify_triggers_end_timestamp_none(self):
notifier = Notifier(connection=None, queues=[], trigger_dispatcher=dispatcher)
notifier.process(execution)
- @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock(
- return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'})))
- @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock(
- return_value=RunnerTypeDB(name='foo', runner_parameters={'runner_foo': 'foo'})))
- @mock.patch.object(Action, 'get_by_ref', mock.MagicMock(
- return_value={'runner_type': {'name': 'local-shell-cmd'}}))
- @mock.patch.object(Policy, 'query', mock.MagicMock(
- return_value=[]))
- @mock.patch.object(Notifier, '_post_generic_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={}))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ @mock.patch(
+ "st2common.util.action_db.get_action_by_ref",
+ mock.MagicMock(
+ return_value=ActionDB(
+ pack="core", name="local", runner_type={"name": "local-shell-cmd"}
+ )
+ ),
+ )
+ @mock.patch(
+ "st2common.util.action_db.get_runnertype_by_name",
+ mock.MagicMock(
+ return_value=RunnerTypeDB(
+ name="foo", runner_parameters={"runner_foo": "foo"}
+ )
+ ),
+ )
+ @mock.patch.object(
+ Action,
+ "get_by_ref",
+ mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}),
+ )
+ @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[]))
+ @mock.patch.object(
+ Notifier, "_post_generic_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={}))
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_notify_triggers_jinja_patterns(self, dispatch):
- liveaction_db = LiveActionDB(action='core.local')
+ liveaction_db = LiveActionDB(action="core.local")
liveaction_db.id = bson.ObjectId()
- liveaction_db.description = ''
- liveaction_db.status = 'succeeded'
- liveaction_db.parameters = {'cmd': 'mamma mia', 'runner_foo': 'foo'}
- on_success = NotificationSubSchema(message='Command {{action_parameters.cmd}} succeeded.',
- data={'stdout': '{{action_results.stdout}}'})
+ liveaction_db.description = ""
+ liveaction_db.status = "succeeded"
+ liveaction_db.parameters = {"cmd": "mamma mia", "runner_foo": "foo"}
+ on_success = NotificationSubSchema(
+ message="Command {{action_parameters.cmd}} succeeded.",
+ data={"stdout": "{{action_results.stdout}}"},
+ )
liveaction_db.notify = NotificationSchema(on_success=on_success)
liveaction_db.start_timestamp = date_utils.get_datetime_utc_now()
- liveaction_db.end_timestamp = \
- (liveaction_db.start_timestamp + datetime.timedelta(seconds=50))
+ liveaction_db.end_timestamp = (
+ liveaction_db.start_timestamp + datetime.timedelta(seconds=50)
+ )
LiveAction.add_or_update(liveaction_db)
@@ -192,26 +243,31 @@ def test_notify_triggers_jinja_patterns(self, dispatch):
notifier = Notifier(connection=None, queues=[])
notifier.process(execution)
- exp = {'status': 'succeeded',
- 'start_timestamp': isotime.format(liveaction_db.start_timestamp),
- 'route': 'notify.default', 'runner_ref': 'local-shell-cmd',
- 'channel': 'notify.default', 'message': u'Command mamma mia succeeded.',
- 'data': {'result': '{}', 'stdout': 'stuff happens'},
- 'action_ref': u'core.local',
- 'execution_id': str(MOCK_EXECUTION.id),
- 'end_timestamp': isotime.format(liveaction_db.end_timestamp)}
- dispatch.assert_called_once_with('core.st2.generic.notifytrigger', payload=exp,
- trace_context={})
+ exp = {
+ "status": "succeeded",
+ "start_timestamp": isotime.format(liveaction_db.start_timestamp),
+ "route": "notify.default",
+ "runner_ref": "local-shell-cmd",
+ "channel": "notify.default",
+ "message": "Command mamma mia succeeded.",
+ "data": {"result": "{}", "stdout": "stuff happens"},
+ "action_ref": "core.local",
+ "execution_id": str(MOCK_EXECUTION.id),
+ "end_timestamp": isotime.format(liveaction_db.end_timestamp),
+ }
+ dispatch.assert_called_once_with(
+ "core.st2.generic.notifytrigger", payload=exp, trace_context={}
+ )
notifier.process(execution)
- @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock(
- return_value='local-shell-cmd'))
- @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(
- return_value={}))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ @mock.patch.object(
+ Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd")
+ )
+ @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={}))
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_post_generic_trigger_emit_when_default_value_is_used(self, dispatch):
for status in LIVEACTION_STATUSES:
- liveaction_db = LiveActionDB(action='core.local')
+ liveaction_db = LiveActionDB(action="core.local")
liveaction_db.status = status
execution = MOCK_EXECUTION
execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db))
@@ -221,28 +277,34 @@ def test_post_generic_trigger_emit_when_default_value_is_used(self, dispatch):
notifier._post_generic_trigger(liveaction_db, execution)
if status in LIVEACTION_COMPLETED_STATES:
- exp = {'status': status,
- 'start_timestamp': str(liveaction_db.start_timestamp),
- 'result': {}, 'parameters': {},
- 'action_ref': u'core.local',
- 'runner_ref': 'local-shell-cmd',
- 'execution_id': str(MOCK_EXECUTION.id),
- 'action_name': u'core.local'}
- dispatch.assert_called_with('core.st2.generic.actiontrigger',
- payload=exp, trace_context={})
+ exp = {
+ "status": status,
+ "start_timestamp": str(liveaction_db.start_timestamp),
+ "result": {},
+ "parameters": {},
+ "action_ref": "core.local",
+ "runner_ref": "local-shell-cmd",
+ "execution_id": str(MOCK_EXECUTION.id),
+ "action_name": "core.local",
+ }
+ dispatch.assert_called_with(
+ "core.st2.generic.actiontrigger", payload=exp, trace_context={}
+ )
self.assertEqual(dispatch.call_count, len(LIVEACTION_COMPLETED_STATES))
- @mock.patch('oslo_config.cfg.CONF.action_sensor', mock.MagicMock(
- emit_when=['scheduled', 'pending', 'abandoned']))
- @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock(
- return_value='local-shell-cmd'))
- @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(
- return_value={}))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ @mock.patch(
+ "oslo_config.cfg.CONF.action_sensor",
+ mock.MagicMock(emit_when=["scheduled", "pending", "abandoned"]),
+ )
+ @mock.patch.object(
+ Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd")
+ )
+ @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={}))
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_post_generic_trigger_with_emit_condition(self, dispatch):
for status in LIVEACTION_STATUSES:
- liveaction_db = LiveActionDB(action='core.local')
+ liveaction_db = LiveActionDB(action="core.local")
liveaction_db.status = status
execution = MOCK_EXECUTION
execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db))
@@ -251,36 +313,45 @@ def test_post_generic_trigger_with_emit_condition(self, dispatch):
notifier = Notifier(connection=None, queues=[])
notifier._post_generic_trigger(liveaction_db, execution)
- if status in ['scheduled', 'pending', 'abandoned']:
- exp = {'status': status,
- 'start_timestamp': str(liveaction_db.start_timestamp),
- 'result': {}, 'parameters': {},
- 'action_ref': u'core.local',
- 'runner_ref': 'local-shell-cmd',
- 'execution_id': str(MOCK_EXECUTION.id),
- 'action_name': u'core.local'}
- dispatch.assert_called_with('core.st2.generic.actiontrigger',
- payload=exp, trace_context={})
+ if status in ["scheduled", "pending", "abandoned"]:
+ exp = {
+ "status": status,
+ "start_timestamp": str(liveaction_db.start_timestamp),
+ "result": {},
+ "parameters": {},
+ "action_ref": "core.local",
+ "runner_ref": "local-shell-cmd",
+ "execution_id": str(MOCK_EXECUTION.id),
+ "action_name": "core.local",
+ }
+ dispatch.assert_called_with(
+ "core.st2.generic.actiontrigger", payload=exp, trace_context={}
+ )
self.assertEqual(dispatch.call_count, 3)
- @mock.patch('oslo_config.cfg.CONF.action_sensor.enable', mock.MagicMock(
- return_value=True))
- @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock(
- return_value='local-shell-cmd'))
- @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(
- return_value={}))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
- @mock.patch('st2actions.notifier.notifier.LiveAction')
- @mock.patch('st2actions.notifier.notifier.policy_service.apply_post_run_policies', mock.Mock())
- def test_process_post_generic_notify_trigger_on_completed_state_default(self,
- mock_LiveAction, mock_dispatch):
+ @mock.patch(
+ "oslo_config.cfg.CONF.action_sensor.enable", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd")
+ )
+ @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={}))
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
+ @mock.patch("st2actions.notifier.notifier.LiveAction")
+ @mock.patch(
+ "st2actions.notifier.notifier.policy_service.apply_post_run_policies",
+ mock.Mock(),
+ )
+ def test_process_post_generic_notify_trigger_on_completed_state_default(
+ self, mock_LiveAction, mock_dispatch
+ ):
# Verify that generic action trigger is posted on all completed states when action sensor
# is enabled
for status in LIVEACTION_STATUSES:
notifier = Notifier(connection=None, queues=[])
- liveaction_db = LiveActionDB(id=bson.ObjectId(), action='core.local')
+ liveaction_db = LiveActionDB(id=bson.ObjectId(), action="core.local")
liveaction_db.status = status
execution = MOCK_EXECUTION
execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db))
@@ -292,35 +363,45 @@ def test_process_post_generic_notify_trigger_on_completed_state_default(self,
notifier.process(execution)
if status in LIVEACTION_COMPLETED_STATES:
- exp = {'status': status,
- 'start_timestamp': str(liveaction_db.start_timestamp),
- 'result': {}, 'parameters': {},
- 'action_ref': u'core.local',
- 'runner_ref': 'local-shell-cmd',
- 'execution_id': str(MOCK_EXECUTION.id),
- 'action_name': u'core.local'}
- mock_dispatch.assert_called_with('core.st2.generic.actiontrigger',
- payload=exp, trace_context={})
+ exp = {
+ "status": status,
+ "start_timestamp": str(liveaction_db.start_timestamp),
+ "result": {},
+ "parameters": {},
+ "action_ref": "core.local",
+ "runner_ref": "local-shell-cmd",
+ "execution_id": str(MOCK_EXECUTION.id),
+ "action_name": "core.local",
+ }
+ mock_dispatch.assert_called_with(
+ "core.st2.generic.actiontrigger", payload=exp, trace_context={}
+ )
self.assertEqual(mock_dispatch.call_count, len(LIVEACTION_COMPLETED_STATES))
- @mock.patch('oslo_config.cfg.CONF.action_sensor', mock.MagicMock(
- enable=True, emit_when=['scheduled', 'pending', 'abandoned']))
- @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock(
- return_value='local-shell-cmd'))
- @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(
- return_value={}))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
- @mock.patch('st2actions.notifier.notifier.LiveAction')
- @mock.patch('st2actions.notifier.notifier.policy_service.apply_post_run_policies', mock.Mock())
- def test_process_post_generic_notify_trigger_on_custom_emit_when_states(self,
- mock_LiveAction, mock_dispatch):
+ @mock.patch(
+ "oslo_config.cfg.CONF.action_sensor",
+ mock.MagicMock(enable=True, emit_when=["scheduled", "pending", "abandoned"]),
+ )
+ @mock.patch.object(
+ Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd")
+ )
+ @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={}))
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
+ @mock.patch("st2actions.notifier.notifier.LiveAction")
+ @mock.patch(
+ "st2actions.notifier.notifier.policy_service.apply_post_run_policies",
+ mock.Mock(),
+ )
+ def test_process_post_generic_notify_trigger_on_custom_emit_when_states(
+ self, mock_LiveAction, mock_dispatch
+ ):
# Verify that generic action trigger is posted on all completed states when action sensor
# is enabled
for status in LIVEACTION_STATUSES:
notifier = Notifier(connection=None, queues=[])
- liveaction_db = LiveActionDB(id=bson.ObjectId(), action='core.local')
+ liveaction_db = LiveActionDB(id=bson.ObjectId(), action="core.local")
liveaction_db.status = status
execution = MOCK_EXECUTION
execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db))
@@ -331,15 +412,19 @@ def test_process_post_generic_notify_trigger_on_custom_emit_when_states(self,
notifier = Notifier(connection=None, queues=[])
notifier.process(execution)
- if status in ['scheduled', 'pending', 'abandoned']:
- exp = {'status': status,
- 'start_timestamp': str(liveaction_db.start_timestamp),
- 'result': {}, 'parameters': {},
- 'action_ref': u'core.local',
- 'runner_ref': 'local-shell-cmd',
- 'execution_id': str(MOCK_EXECUTION.id),
- 'action_name': u'core.local'}
- mock_dispatch.assert_called_with('core.st2.generic.actiontrigger',
- payload=exp, trace_context={})
+ if status in ["scheduled", "pending", "abandoned"]:
+ exp = {
+ "status": status,
+ "start_timestamp": str(liveaction_db.start_timestamp),
+ "result": {},
+ "parameters": {},
+ "action_ref": "core.local",
+ "runner_ref": "local-shell-cmd",
+ "execution_id": str(MOCK_EXECUTION.id),
+ "action_name": "core.local",
+ }
+ mock_dispatch.assert_called_with(
+ "core.st2.generic.actiontrigger", payload=exp, trace_context={}
+ )
self.assertEqual(mock_dispatch.call_count, 3)
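
Two details worth noting in the notifier hunks. First, black also strips the redundant u"" prefix (u'core.local' becomes "core.local"), a pure no-op on Python 3; assuming black is installed, this is easy to confirm:

import black

print(black.format_str("x = u'core.local'", mode=black.Mode()))
# x = "core.local"

Second, the pre-existing self.tester.assertTrue("local-shell-cmd", payload["runner_ref"]) calls pass vacuously, because assertTrue treats the second argument as a failure message rather than a comparison operand. Black deliberately never changes semantics, so that latent assertion survives the reformat unchanged.
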
diff --git a/st2actions/tests/unit/test_parallel_ssh.py b/st2actions/tests/unit/test_parallel_ssh.py
index bf8c1df87b..67052a53e0 100644
--- a/st2actions/tests/unit/test_parallel_ssh.py
+++ b/st2actions/tests/unit/test_parallel_ssh.py
@@ -17,13 +17,14 @@
import json
import os
-from mock import (patch, Mock, MagicMock)
+from mock import patch, Mock, MagicMock
import unittest2
from st2common.runners.parallel_ssh import ParallelSSHClient
from st2common.runners.paramiko_ssh import ParamikoSSHClient
from st2common.runners.paramiko_ssh import SSHCommandTimeoutError
import st2tests.config as tests_config
+
tests_config.parse_args()
MOCK_STDERR_SUDO_PASSWORD_ERROR = """
@@ -35,251 +36,294 @@
class ParallelSSHTests(unittest2.TestCase):
-
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_connect_with_password(self):
- hosts = ['localhost', '127.0.0.1']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- password='ubuntu',
- connect=False)
+ hosts = ["localhost", "127.0.0.1"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", password="ubuntu", connect=False
+ )
client.connect()
expected_conn = {
- 'allow_agent': False,
- 'look_for_keys': False,
- 'password': 'ubuntu',
- 'username': 'ubuntu',
- 'timeout': 60,
- 'port': 22
+ "allow_agent": False,
+ "look_for_keys": False,
+ "password": "ubuntu",
+ "username": "ubuntu",
+ "timeout": 60,
+ "port": 22,
}
for host in hosts:
- expected_conn['hostname'] = host
- client._hosts_client[host].client.connect.assert_called_once_with(**expected_conn)
+ expected_conn["hostname"] = host
+ client._hosts_client[host].client.connect.assert_called_once_with(
+ **expected_conn
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_connect_with_random_ports(self):
- hosts = ['localhost:22', '127.0.0.1:55', 'st2build001']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- password='ubuntu',
- connect=False)
+ hosts = ["localhost:22", "127.0.0.1:55", "st2build001"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", password="ubuntu", connect=False
+ )
client.connect()
expected_conn = {
- 'allow_agent': False,
- 'look_for_keys': False,
- 'password': 'ubuntu',
- 'username': 'ubuntu',
- 'timeout': 60,
- 'port': 22
+ "allow_agent": False,
+ "look_for_keys": False,
+ "password": "ubuntu",
+ "username": "ubuntu",
+ "timeout": 60,
+ "port": 22,
}
for host in hosts:
hostname, port = client._get_host_port_info(host)
- expected_conn['hostname'] = hostname
- expected_conn['port'] = port
- client._hosts_client[hostname].client.connect.assert_called_once_with(**expected_conn)
+ expected_conn["hostname"] = hostname
+ expected_conn["port"] = port
+ client._hosts_client[hostname].client.connect.assert_called_once_with(
+ **expected_conn
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_connect_with_key(self):
- hosts = ['localhost', '127.0.0.1', 'st2build001']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=False)
+ hosts = ["localhost", "127.0.0.1", "st2build001"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=False
+ )
client.connect()
expected_conn = {
- 'allow_agent': False,
- 'look_for_keys': False,
- 'key_filename': '~/.ssh/id_rsa',
- 'username': 'ubuntu',
- 'timeout': 60,
- 'port': 22
+ "allow_agent": False,
+ "look_for_keys": False,
+ "key_filename": "~/.ssh/id_rsa",
+ "username": "ubuntu",
+ "timeout": 60,
+ "port": 22,
}
for host in hosts:
hostname, port = client._get_host_port_info(host)
- expected_conn['hostname'] = hostname
- expected_conn['port'] = port
- client._hosts_client[hostname].client.connect.assert_called_once_with(**expected_conn)
+ expected_conn["hostname"] = hostname
+ expected_conn["port"] = port
+ client._hosts_client[hostname].client.connect.assert_called_once_with(
+ **expected_conn
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_connect_with_bastion(self):
- hosts = ['localhost', '127.0.0.1']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- bastion_host='testing_bastion_host',
- connect=False)
+ hosts = ["localhost", "127.0.0.1"]
+ client = ParallelSSHClient(
+ hosts=hosts,
+ user="ubuntu",
+ pkey_file="~/.ssh/id_rsa",
+ bastion_host="testing_bastion_host",
+ connect=False,
+ )
client.connect()
for host in hosts:
hostname, _ = client._get_host_port_info(host)
- self.assertEqual(client._hosts_client[hostname].bastion_host, 'testing_bastion_host')
+ self.assertEqual(
+ client._hosts_client[hostname].bastion_host, "testing_bastion_host"
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, 'run', MagicMock(return_value=('/home/ubuntu', '', 0)))
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient, "run", MagicMock(return_value=("/home/ubuntu", "", 0))
+ )
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_run_command(self):
- hosts = ['localhost', '127.0.0.1', 'st2build001']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=True)
- client.run('pwd', timeout=60)
- expected_kwargs = {
- 'timeout': 60,
- 'call_line_handler_func': True
- }
+ hosts = ["localhost", "127.0.0.1", "st2build001"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True
+ )
+ client.run("pwd", timeout=60)
+ expected_kwargs = {"timeout": 60, "call_line_handler_func": True}
for host in hosts:
hostname, _ = client._get_host_port_info(host)
- client._hosts_client[hostname].run.assert_called_with('pwd', **expected_kwargs)
+ client._hosts_client[hostname].run.assert_called_with(
+ "pwd", **expected_kwargs
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_run_command_timeout(self):
# Make sure stdout and stderr is included on timeout
- hosts = ['localhost', '127.0.0.1', 'st2build001']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=True)
- mock_run = Mock(side_effect=SSHCommandTimeoutError(cmd='pwd', timeout=10,
- stdout='a',
- stderr='b',
- ssh_connect_timeout=30))
+ hosts = ["localhost", "127.0.0.1", "st2build001"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True
+ )
+ mock_run = Mock(
+ side_effect=SSHCommandTimeoutError(
+ cmd="pwd", timeout=10, stdout="a", stderr="b", ssh_connect_timeout=30
+ )
+ )
for host in hosts:
hostname, _ = client._get_host_port_info(host)
host_client = client._hosts_client[host]
host_client.run = mock_run
- results = client.run('pwd')
+ results = client.run("pwd")
for host in hosts:
result = results[host]
- self.assertEqual(result['failed'], True)
- self.assertEqual(result['stdout'], 'a')
- self.assertEqual(result['stderr'], 'b')
- self.assertEqual(result['return_code'], -9)
+ self.assertEqual(result["failed"], True)
+ self.assertEqual(result["stdout"], "a")
+ self.assertEqual(result["stderr"], "b")
+ self.assertEqual(result["return_code"], -9)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, 'put', MagicMock(return_value={}))
- @patch.object(os.path, 'exists', MagicMock(return_value=True))
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(ParamikoSSHClient, "put", MagicMock(return_value={}))
+ @patch.object(os.path, "exists", MagicMock(return_value=True))
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_put(self):
- hosts = ['localhost', '127.0.0.1', 'st2build001']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=True)
- client.put('/local/stuff', '/remote/stuff', mode=0o744)
- expected_kwargs = {
- 'mode': 0o744,
- 'mirror_local_mode': False
- }
+ hosts = ["localhost", "127.0.0.1", "st2build001"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True
+ )
+ client.put("/local/stuff", "/remote/stuff", mode=0o744)
+ expected_kwargs = {"mode": 0o744, "mirror_local_mode": False}
for host in hosts:
hostname, _ = client._get_host_port_info(host)
- client._hosts_client[hostname].put.assert_called_with('/local/stuff', '/remote/stuff',
- **expected_kwargs)
+ client._hosts_client[hostname].put.assert_called_with(
+ "/local/stuff", "/remote/stuff", **expected_kwargs
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, 'delete_file', MagicMock(return_value={}))
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(ParamikoSSHClient, "delete_file", MagicMock(return_value={}))
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_delete_file(self):
- hosts = ['localhost', '127.0.0.1', 'st2build001']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=True)
- client.delete_file('/remote/stuff')
+ hosts = ["localhost", "127.0.0.1", "st2build001"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True
+ )
+ client.delete_file("/remote/stuff")
for host in hosts:
hostname, _ = client._get_host_port_info(host)
- client._hosts_client[hostname].delete_file.assert_called_with('/remote/stuff')
+ client._hosts_client[hostname].delete_file.assert_called_with(
+ "/remote/stuff"
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, 'delete_dir', MagicMock(return_value={}))
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(ParamikoSSHClient, "delete_dir", MagicMock(return_value={}))
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_delete_dir(self):
- hosts = ['localhost', '127.0.0.1', 'st2build001']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=True)
- client.delete_dir('/remote/stuff/', force=True)
- expected_kwargs = {
- 'force': True,
- 'timeout': None
- }
+ hosts = ["localhost", "127.0.0.1", "st2build001"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True
+ )
+ client.delete_dir("/remote/stuff/", force=True)
+ expected_kwargs = {"force": True, "timeout": None}
for host in hosts:
hostname, _ = client._get_host_port_info(host)
- client._hosts_client[hostname].delete_dir.assert_called_with('/remote/stuff/',
- **expected_kwargs)
+ client._hosts_client[hostname].delete_dir.assert_called_with(
+ "/remote/stuff/", **expected_kwargs
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_host_port_info(self):
- client = ParallelSSHClient(hosts=['dummy'],
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=True)
+ client = ParallelSSHClient(
+ hosts=["dummy"], user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True
+ )
# No port case. Port should be 22.
- host_str = '1.2.3.4'
+ host_str = "1.2.3.4"
host, port = client._get_host_port_info(host_str)
self.assertEqual(host, host_str)
self.assertEqual(port, 22)
# IPv6 with square brackets with port specified.
- host_str = '[fec2::10]:55'
+ host_str = "[fec2::10]:55"
host, port = client._get_host_port_info(host_str)
- self.assertEqual(host, 'fec2::10')
+ self.assertEqual(host, "fec2::10")
self.assertEqual(port, 55)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, 'run', MagicMock(
- return_value=(json.dumps({'foo': 'bar'}), '', 0))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "run",
+ MagicMock(return_value=(json.dumps({"foo": "bar"}), "", 0)),
+ )
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
def test_run_command_json_output_transformed_to_object(self):
- hosts = ['127.0.0.1']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=True)
- results = client.run('stuff', timeout=60)
- self.assertIn('127.0.0.1', results)
- self.assertDictEqual(results['127.0.0.1']['stdout'], {'foo': 'bar'})
+ hosts = ["127.0.0.1"]
+ client = ParallelSSHClient(
+ hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True
+ )
+ results = client.run("stuff", timeout=60)
+ self.assertIn("127.0.0.1", results)
+ self.assertDictEqual(results["127.0.0.1"]["stdout"], {"foo": "bar"})
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, 'run', MagicMock(
- return_value=('', MOCK_STDERR_SUDO_PASSWORD_ERROR, 0))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "run",
+ MagicMock(return_value=("", MOCK_STDERR_SUDO_PASSWORD_ERROR, 0)),
+ )
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
def test_run_sudo_password_user_friendly_error(self):
- hosts = ['127.0.0.1']
- client = ParallelSSHClient(hosts=hosts,
- user='ubuntu',
- pkey_file='~/.ssh/id_rsa',
- connect=True,
- sudo_password=True)
- results = client.run('stuff', timeout=60)
+ hosts = ["127.0.0.1"]
+ client = ParallelSSHClient(
+ hosts=hosts,
+ user="ubuntu",
+ pkey_file="~/.ssh/id_rsa",
+ connect=True,
+ sudo_password=True,
+ )
+ results = client.run("stuff", timeout=60)
- expected_error = ('Failed executing command "stuff" on host "127.0.0.1" '
- 'Invalid sudo password provided or sudo is not configured for '
- 'this user (bar)')
+ expected_error = (
+ 'Failed executing command "stuff" on host "127.0.0.1" '
+ "Invalid sudo password provided or sudo is not configured for "
+ "this user (bar)"
+ )
- self.assertIn('127.0.0.1', results)
- self.assertEqual(results['127.0.0.1']['succeeded'], False)
- self.assertEqual(results['127.0.0.1']['failed'], True)
- self.assertIn(expected_error, results['127.0.0.1']['error'])
+ self.assertIn("127.0.0.1", results)
+ self.assertEqual(results["127.0.0.1"]["succeeded"], False)
+ self.assertEqual(results["127.0.0.1"]["failed"], True)
+ self.assertIn(expected_error, results["127.0.0.1"]["error"])
diff --git a/st2actions/tests/unit/test_paramiko_remote_script_runner.py b/st2actions/tests/unit/test_paramiko_remote_script_runner.py
index 1246f1cbe2..1bf67a9503 100644
--- a/st2actions/tests/unit/test_paramiko_remote_script_runner.py
+++ b/st2actions/tests/unit/test_paramiko_remote_script_runner.py
@@ -21,6 +21,7 @@
# XXX: There is an import dependency. Config needs to be set up
# before importing remote_script_runner classes.
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2common.util import jsonify
@@ -35,234 +36,254 @@
from st2tests.fixturesloader import FixturesLoader
-__all__ = [
- 'ParamikoScriptRunnerTestCase'
-]
+__all__ = ["ParamikoScriptRunnerTestCase"]
-FIXTURES_PACK = 'generic'
-TEST_MODELS = {
- 'actions': ['a1.yaml']
-}
+FIXTURES_PACK = "generic"
+TEST_MODELS = {"actions": ["a1.yaml"]}
-MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
-ACTION_1 = MODELS['actions']['a1.yaml']
+MODELS = FixturesLoader().load_models(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+)
+ACTION_1 = MODELS["actions"]["a1.yaml"]
class ParamikoScriptRunnerTestCase(unittest2.TestCase):
- @patch('st2common.runners.parallel_ssh.ParallelSSHClient', Mock)
- @patch.object(jsonify, 'json_loads', MagicMock(return_value={}))
- @patch.object(ParallelSSHClient, 'run', MagicMock(return_value={}))
- @patch.object(ParallelSSHClient, 'connect', MagicMock(return_value={}))
+ @patch("st2common.runners.parallel_ssh.ParallelSSHClient", Mock)
+ @patch.object(jsonify, "json_loads", MagicMock(return_value={}))
+ @patch.object(ParallelSSHClient, "run", MagicMock(return_value={}))
+ @patch.object(ParallelSSHClient, "connect", MagicMock(return_value={}))
def test_cwd_used_correctly(self):
remote_action = ParamikoRemoteScriptAction(
- 'foo-script', bson.ObjectId(),
- script_local_path_abs='/home/stanley/shiz_storm.py',
+ "foo-script",
+ bson.ObjectId(),
+ script_local_path_abs="/home/stanley/shiz_storm.py",
script_local_libs_path_abs=None,
- named_args={}, positional_args=['blank space'], env_vars={},
- on_behalf_user='svetlana', user='stanley',
- private_key='---SOME RSA KEY---',
- remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/'
+ named_args={},
+ positional_args=["blank space"],
+ env_vars={},
+ on_behalf_user="svetlana",
+ user="stanley",
+ private_key="---SOME RSA KEY---",
+ remote_dir="/tmp",
+ hosts=["127.0.0.1"],
+ cwd="/test/cwd/",
+ )
+ paramiko_runner = ParamikoRemoteScriptRunner("runner_1")
+ paramiko_runner._parallel_ssh_client = ParallelSSHClient(
+ ["127.0.0.1"], "stanley"
)
- paramiko_runner = ParamikoRemoteScriptRunner('runner_1')
- paramiko_runner._parallel_ssh_client = ParallelSSHClient(['127.0.0.1'], 'stanley')
paramiko_runner._run_script_on_remote_host(remote_action)
exp_cmd = "cd /test/cwd/ && /tmp/shiz_storm.py 'blank space'"
- ParallelSSHClient.run.assert_called_with(exp_cmd,
- timeout=None)
+ ParallelSSHClient.run.assert_called_with(exp_cmd, timeout=None)
def test_username_invalid_private_key(self):
- paramiko_runner = ParamikoRemoteScriptRunner('runner_1')
+ paramiko_runner = ParamikoRemoteScriptRunner("runner_1")
paramiko_runner.runner_parameters = {
- 'username': 'test_user',
- 'hosts': '127.0.0.1',
- 'private_key': 'invalid private key',
+ "username": "test_user",
+ "hosts": "127.0.0.1",
+ "private_key": "invalid private key",
}
paramiko_runner.context = {}
self.assertRaises(NoHostsConnectedToException, paramiko_runner.pre_run)
- @patch('st2common.runners.parallel_ssh.ParallelSSHClient', Mock)
- @patch.object(ParallelSSHClient, 'run', MagicMock(return_value={}))
- @patch.object(ParallelSSHClient, 'connect', MagicMock(return_value={}))
+ @patch("st2common.runners.parallel_ssh.ParallelSSHClient", Mock)
+ @patch.object(ParallelSSHClient, "run", MagicMock(return_value={}))
+ @patch.object(ParallelSSHClient, "connect", MagicMock(return_value={}))
def test_top_level_error_is_correctly_reported(self):
# Verify that a top-level error doesn't cause an exception to be thrown.
# In a top-level error case, the result dict doesn't contain an entry per host
- paramiko_runner = ParamikoRemoteScriptRunner('runner_1')
+ paramiko_runner = ParamikoRemoteScriptRunner("runner_1")
paramiko_runner.runner_parameters = {
- 'username': 'test_user',
- 'hosts': '127.0.0.1'
+ "username": "test_user",
+ "hosts": "127.0.0.1",
}
paramiko_runner.action = ACTION_1
- paramiko_runner.liveaction_id = 'foo'
- paramiko_runner.entry_point = 'foo'
+ paramiko_runner.liveaction_id = "foo"
+ paramiko_runner.entry_point = "foo"
paramiko_runner.context = {}
- paramiko_runner._cwd = '/tmp'
- paramiko_runner._copy_artifacts = Mock(side_effect=Exception('fail!'))
+ paramiko_runner._cwd = "/tmp"
+ paramiko_runner._copy_artifacts = Mock(side_effect=Exception("fail!"))
status, result, _ = paramiko_runner.run(action_parameters={})
self.assertEqual(status, LIVEACTION_STATUS_FAILED)
- self.assertEqual(result['failed'], True)
- self.assertEqual(result['succeeded'], False)
- self.assertIn('Failed copying content to remote boxes', result['error'])
+ self.assertEqual(result["failed"], True)
+ self.assertEqual(result["succeeded"], False)
+ self.assertIn("Failed copying content to remote boxes", result["error"])
def test_command_construction_correct_default_parameter_values_are_used(self):
runner_parameters = {}
action_db_parameters = {
- 'project': {
- 'type': 'string',
- 'default': 'st2',
- 'position': 0,
- },
- 'version': {
- 'type': 'string',
- 'position': 1,
- 'required': True
+ "project": {
+ "type": "string",
+ "default": "st2",
+ "position": 0,
},
- 'fork': {
- 'type': 'string',
- 'position': 2,
- 'default': 'StackStorm',
+ "version": {"type": "string", "position": 1, "required": True},
+ "fork": {
+ "type": "string",
+ "position": 2,
+ "default": "StackStorm",
},
- 'branch': {
- 'type': 'string',
- 'position': 3,
- 'default': 'master',
+ "branch": {
+ "type": "string",
+ "position": 3,
+ "default": "master",
},
- 'update_changelog': {
- 'type': 'boolean',
- 'position': 4,
- 'default': False
+ "update_changelog": {"type": "boolean", "position": 4, "default": False},
+ "local_repo": {
+ "type": "string",
+ "position": 5,
},
- 'local_repo': {
- 'type': 'string',
- 'position': 5,
- }
}
context = {}
- action_db = ActionDB(pack='dummy', name='action')
+ action_db = ActionDB(pack="dummy", name="action")
- runner = ParamikoRemoteScriptRunner('id')
+ runner = ParamikoRemoteScriptRunner("id")
runner.runner_parameters = {}
runner.action = action_db
# 1. All default values used
live_action_db_parameters = {
- 'project': 'st2flow',
- 'version': '3.0.0',
- 'fork': 'StackStorm',
- 'local_repo': '/tmp/repo'
+ "project": "st2flow",
+ "version": "3.0.0",
+ "fork": "StackStorm",
+ "local_repo": "/tmp/repo",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
- self.assertDictEqual(action_params, {
- 'project': 'st2flow',
- 'version': '3.0.0',
- 'fork': 'StackStorm',
- 'branch': 'master', # default value used
- 'update_changelog': False, # default value used
- 'local_repo': '/tmp/repo'
- })
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2flow",
+ "version": "3.0.0",
+ "fork": "StackStorm",
+ "branch": "master", # default value used
+ "update_changelog": False, # default value used
+ "local_repo": "/tmp/repo",
+ },
+ )
action_db.parameters = action_db_parameters
positional_args, named_args = runner._get_script_args(action_params)
named_args = runner._transform_named_args(named_args)
remote_action = ParamikoRemoteScriptAction(
- 'foo-script', 'id',
- script_local_path_abs='/tmp/script.sh',
+ "foo-script",
+ "id",
+ script_local_path_abs="/tmp/script.sh",
script_local_libs_path_abs=None,
- named_args=named_args, positional_args=positional_args, env_vars={},
- on_behalf_user='svetlana', user='stanley',
- remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/'
+ named_args=named_args,
+ positional_args=positional_args,
+ env_vars={},
+ on_behalf_user="svetlana",
+ user="stanley",
+ remote_dir="/tmp",
+ hosts=["127.0.0.1"],
+ cwd="/test/cwd/",
)
command_string = remote_action.get_full_command_string()
- expected = 'cd /test/cwd/ && /tmp/script.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo'
+ expected = "cd /test/cwd/ && /tmp/script.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo"
self.assertEqual(command_string, expected)
# 2. Some default values used
live_action_db_parameters = {
- 'project': 'st2web',
- 'version': '3.1.0',
- 'fork': 'StackStorm1',
- 'update_changelog': True,
- 'local_repo': '/tmp/repob'
+ "project": "st2web",
+ "version": "3.1.0",
+ "fork": "StackStorm1",
+ "update_changelog": True,
+ "local_repo": "/tmp/repob",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
- self.assertDictEqual(action_params, {
- 'project': 'st2web',
- 'version': '3.1.0',
- 'fork': 'StackStorm1',
- 'branch': 'master', # default value used
- 'update_changelog': True, # default value used
- 'local_repo': '/tmp/repob'
- })
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2web",
+ "version": "3.1.0",
+ "fork": "StackStorm1",
+ "branch": "master", # default value used
+ "update_changelog": True, # default value used
+ "local_repo": "/tmp/repob",
+ },
+ )
action_db.parameters = action_db_parameters
positional_args, named_args = runner._get_script_args(action_params)
named_args = runner._transform_named_args(named_args)
remote_action = ParamikoRemoteScriptAction(
- 'foo-script', 'id',
- script_local_path_abs='/tmp/script.sh',
+ "foo-script",
+ "id",
+ script_local_path_abs="/tmp/script.sh",
script_local_libs_path_abs=None,
- named_args=named_args, positional_args=positional_args, env_vars={},
- on_behalf_user='svetlana', user='stanley',
- remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/'
+ named_args=named_args,
+ positional_args=positional_args,
+ env_vars={},
+ on_behalf_user="svetlana",
+ user="stanley",
+ remote_dir="/tmp",
+ hosts=["127.0.0.1"],
+ cwd="/test/cwd/",
)
command_string = remote_action.get_full_command_string()
- expected = 'cd /test/cwd/ && /tmp/script.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob'
+ expected = "cd /test/cwd/ && /tmp/script.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob"
self.assertEqual(command_string, expected)
# 3. None is specified for a boolean parameter, so the default should be used
live_action_db_parameters = {
- 'project': 'st2rbac',
- 'version': '3.2.0',
- 'fork': 'StackStorm2',
- 'update_changelog': None,
- 'local_repo': '/tmp/repoc'
+ "project": "st2rbac",
+ "version": "3.2.0",
+ "fork": "StackStorm2",
+ "update_changelog": None,
+ "local_repo": "/tmp/repoc",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
- self.assertDictEqual(action_params, {
- 'project': 'st2rbac',
- 'version': '3.2.0',
- 'fork': 'StackStorm2',
- 'branch': 'master', # default value used
- 'update_changelog': False, # default value used
- 'local_repo': '/tmp/repoc'
- })
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2rbac",
+ "version": "3.2.0",
+ "fork": "StackStorm2",
+ "branch": "master", # default value used
+ "update_changelog": False, # default value used
+ "local_repo": "/tmp/repoc",
+ },
+ )
action_db.parameters = action_db_parameters
positional_args, named_args = runner._get_script_args(action_params)
named_args = runner._transform_named_args(named_args)
remote_action = ParamikoRemoteScriptAction(
- 'foo-script', 'id',
- script_local_path_abs='/tmp/script.sh',
+ "foo-script",
+ "id",
+ script_local_path_abs="/tmp/script.sh",
script_local_libs_path_abs=None,
- named_args=named_args, positional_args=positional_args, env_vars={},
- on_behalf_user='svetlana', user='stanley',
- remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/'
+ named_args=named_args,
+ positional_args=positional_args,
+ env_vars={},
+ on_behalf_user="svetlana",
+ user="stanley",
+ remote_dir="/tmp",
+ hosts=["127.0.0.1"],
+ cwd="/test/cwd/",
)
command_string = remote_action.get_full_command_string()
- expected = 'cd /test/cwd/ && /tmp/script.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc'
+ expected = "cd /test/cwd/ && /tmp/script.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc"
self.assertEqual(command_string, expected)
diff --git a/st2actions/tests/unit/test_paramiko_ssh.py b/st2actions/tests/unit/test_paramiko_ssh.py
index 7335f11a7e..eadc4a477a 100644
--- a/st2actions/tests/unit/test_paramiko_ssh.py
+++ b/st2actions/tests/unit/test_paramiko_ssh.py
@@ -28,363 +28,456 @@
from st2common.runners.paramiko_ssh import ParamikoSSHClient
from st2tests.fixturesloader import get_resources_base_path
import st2tests.config as tests_config
+
tests_config.parse_args()
-__all__ = [
- 'ParamikoSSHClientTestCase'
-]
+__all__ = ["ParamikoSSHClientTestCase"]
class ParamikoSSHClientTestCase(unittest2.TestCase):
-
- @patch('paramiko.SSHClient', Mock)
+ @patch("paramiko.SSHClient", Mock)
def setUp(self):
"""
Creates the object patching the actual connection.
"""
- cfg.CONF.set_override(name='ssh_key_file', override=None, group='system_user')
- cfg.CONF.set_override(name='use_ssh_config', override=False, group='ssh_runner')
- cfg.CONF.set_override(name='ssh_connect_timeout', override=30, group='ssh_runner')
-
- conn_params = {'hostname': 'dummy.host.org',
- 'port': 8822,
- 'username': 'ubuntu',
- 'key_files': '~/.ssh/ubuntu_ssh',
- 'timeout': 30}
+ cfg.CONF.set_override(name="ssh_key_file", override=None, group="system_user")
+ cfg.CONF.set_override(name="use_ssh_config", override=False, group="ssh_runner")
+ cfg.CONF.set_override(
+ name="ssh_connect_timeout", override=30, group="ssh_runner"
+ )
+
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "port": 8822,
+ "username": "ubuntu",
+ "key_files": "~/.ssh/ubuntu_ssh",
+ "timeout": 30,
+ }
self.ssh_cli = ParamikoSSHClient(**conn_params)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
- @patch('paramiko.ProxyCommand')
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
+ @patch("paramiko.ProxyCommand")
def test_set_proxycommand(self, mock_ProxyCommand):
"""
Loads proxy commands from ssh config file
"""
- ssh_config_file_path = os.path.join(get_resources_base_path(),
- 'ssh', 'dummy_ssh_config')
- cfg.CONF.set_override(name='ssh_config_file_path',
- override=ssh_config_file_path,
- group='ssh_runner')
- cfg.CONF.set_override(name='use_ssh_config', override=True,
- group='ssh_runner')
-
- conn_params = {'hostname': 'dummy.host.org', 'username': 'ubuntu', 'password': 'foo'}
+ ssh_config_file_path = os.path.join(
+ get_resources_base_path(), "ssh", "dummy_ssh_config"
+ )
+ cfg.CONF.set_override(
+ name="ssh_config_file_path",
+ override=ssh_config_file_path,
+ group="ssh_runner",
+ )
+ cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner")
+
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "foo",
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- mock_ProxyCommand.assert_called_once_with('ssh -q -W dummy.host.org:22 dummy_bastion')
-
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
- @patch('paramiko.ProxyCommand')
+ mock_ProxyCommand.assert_called_once_with(
+ "ssh -q -W dummy.host.org:22 dummy_bastion"
+ )
+
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
+ @patch("paramiko.ProxyCommand")
def test_fail_set_proxycommand(self, mock_ProxyCommand):
"""
Loads proxy commands from ssh config file
"""
- ssh_config_file_path = os.path.join(get_resources_base_path(),
- 'ssh', 'dummy_ssh_config_fail')
- cfg.CONF.set_override(name='ssh_config_file_path',
- override=ssh_config_file_path,
- group='ssh_runner')
- cfg.CONF.set_override(name='use_ssh_config',
- override=True, group='ssh_runner')
-
- conn_params = {'hostname': 'dummy.host.org'}
+ ssh_config_file_path = os.path.join(
+ get_resources_base_path(), "ssh", "dummy_ssh_config_fail"
+ )
+ cfg.CONF.set_override(
+ name="ssh_config_file_path",
+ override=ssh_config_file_path,
+ group="ssh_runner",
+ )
+ cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner")
+
+ conn_params = {"hostname": "dummy.host.org"}
mock = ParamikoSSHClient(**conn_params)
self.assertRaises(Exception, mock.connect)
mock_ProxyCommand.assert_not_called()
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_create_with_password(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'ubuntu'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "ubuntu",
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- expected_conn = {'username': 'ubuntu',
- 'password': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'timeout': 30,
- 'port': 22}
+ expected_conn = {
+ "username": "ubuntu",
+ "password": "ubuntu",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "timeout": 30,
+ "port": 22,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_deprecated_key_argument(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_files': 'id_rsa'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_files": "id_rsa",
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- expected_conn = {'username': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'key_filename': 'id_rsa',
- 'timeout': 30,
- 'port': 22}
+ expected_conn = {
+ "username": "ubuntu",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "key_filename": "id_rsa",
+ "timeout": 30,
+ "port": 22,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
def test_key_files_and_key_material_arguments_are_mutual_exclusive(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_files': 'id_rsa',
- 'key_material': 'key'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_files": "id_rsa",
+ "key_material": "key",
+ }
- expected_msg = ('key_files and key_material arguments are mutually exclusive. '
- 'Supply only one.')
+ expected_msg = (
+ "key_files and key_material arguments are mutually exclusive. "
+ "Supply only one."
+ )
client = ParamikoSSHClient(**conn_params)
- self.assertRaisesRegexp(ValueError, expected_msg,
- client.connect)
+ self.assertRaisesRegexp(ValueError, expected_msg, client.connect)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_key_material_argument(self):
- path = os.path.join(get_resources_base_path(),
- 'ssh', 'dummy_rsa')
+ path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa")
- with open(path, 'r') as fp:
+ with open(path, "r") as fp:
private_key = fp.read()
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_material': private_key}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_material": private_key,
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
pkey = paramiko.RSAKey.from_private_key(StringIO(private_key))
- expected_conn = {'username': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'pkey': pkey,
- 'timeout': 30,
- 'port': 22}
+ expected_conn = {
+ "username": "ubuntu",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "pkey": pkey,
+ "timeout": 30,
+ "port": 22,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_key_material_argument_invalid_key(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_material': 'id_rsa'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_material": "id_rsa",
+ }
mock = ParamikoSSHClient(**conn_params)
- expected_msg = 'Invalid or unsupported key type'
- self.assertRaisesRegexp(paramiko.ssh_exception.SSHException,
- expected_msg, mock.connect)
+ expected_msg = "Invalid or unsupported key type"
+ self.assertRaisesRegexp(
+ paramiko.ssh_exception.SSHException, expected_msg, mock.connect
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=True))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True)
+ )
def test_passphrase_no_key_provided(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'passphrase': 'testphrase'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "passphrase": "testphrase",
+ }
- expected_msg = 'passphrase should accompany private key material'
+ expected_msg = "passphrase should accompany private key material"
client = ParamikoSSHClient(**conn_params)
self.assertRaisesRegexp(ValueError, expected_msg, client.connect)
- @patch('paramiko.SSHClient', Mock)
+ @patch("paramiko.SSHClient", Mock)
def test_passphrase_not_provided_for_encrypted_key_file(self):
- path = os.path.join(get_resources_base_path(),
- 'ssh', 'dummy_rsa_passphrase')
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_files': path}
+ path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase")
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_files": path,
+ }
mock = ParamikoSSHClient(**conn_params)
- self.assertRaises(paramiko.ssh_exception.PasswordRequiredException, mock.connect)
-
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=True))
+ self.assertRaises(
+ paramiko.ssh_exception.PasswordRequiredException, mock.connect
+ )
+
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True)
+ )
def test_key_with_passphrase_success(self):
- path = os.path.join(get_resources_base_path(),
- 'ssh', 'dummy_rsa_passphrase')
+ path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase")
- with open(path, 'r') as fp:
+ with open(path, "r") as fp:
private_key = fp.read()
# Key material provided
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_material': private_key,
- 'passphrase': 'testphrase'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_material": private_key,
+ "passphrase": "testphrase",
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- pkey = paramiko.RSAKey.from_private_key(StringIO(private_key), 'testphrase')
- expected_conn = {'username': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'pkey': pkey,
- 'timeout': 30,
- 'port': 22}
+ pkey = paramiko.RSAKey.from_private_key(StringIO(private_key), "testphrase")
+ expected_conn = {
+ "username": "ubuntu",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "pkey": pkey,
+ "timeout": 30,
+ "port": 22,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
# Path to private key file provided
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_files': path,
- 'passphrase': 'testphrase'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_files": path,
+ "passphrase": "testphrase",
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- expected_conn = {'username': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'key_filename': path,
- 'password': 'testphrase',
- 'timeout': 30,
- 'port': 22}
+ expected_conn = {
+ "username": "ubuntu",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "key_filename": path,
+ "password": "testphrase",
+ "timeout": 30,
+ "port": 22,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=True))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True)
+ )
def test_passphrase_and_no_key(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'passphrase': 'testphrase'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "passphrase": "testphrase",
+ }
- expected_msg = 'passphrase should accompany private key material'
+ expected_msg = "passphrase should accompany private key material"
client = ParamikoSSHClient(**conn_params)
- self.assertRaisesRegexp(ValueError, expected_msg,
- client.connect)
+ self.assertRaisesRegexp(ValueError, expected_msg, client.connect)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=True))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True)
+ )
def test_incorrect_passphrase(self):
- path = os.path.join(get_resources_base_path(),
- 'ssh', 'dummy_rsa_passphrase')
+ path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase")
- with open(path, 'r') as fp:
+ with open(path, "r") as fp:
private_key = fp.read()
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_material': private_key,
- 'passphrase': 'incorrect'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_material": private_key,
+ "passphrase": "incorrect",
+ }
mock = ParamikoSSHClient(**conn_params)
- expected_msg = 'Invalid passphrase or invalid/unsupported key type'
- self.assertRaisesRegexp(paramiko.ssh_exception.SSHException,
- expected_msg, mock.connect)
-
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ expected_msg = "Invalid passphrase or invalid/unsupported key type"
+ self.assertRaisesRegexp(
+ paramiko.ssh_exception.SSHException, expected_msg, mock.connect
+ )
+
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_key_material_contains_path_not_contents(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu'}
- key_materials = [
- '~/.ssh/id_rsa',
- '/tmp/id_rsa',
- 'C:\\id_rsa'
- ]
+ conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"}
+ key_materials = ["~/.ssh/id_rsa", "/tmp/id_rsa", "C:\\id_rsa"]
- expected_msg = ('"private_key" parameter needs to contain private key data / content and '
- 'not a path')
+ expected_msg = (
+ '"private_key" parameter needs to contain private key data / content and '
+ "not a path"
+ )
for key_material in key_materials:
conn_params = conn_params.copy()
- conn_params['key_material'] = key_material
+ conn_params["key_material"] = key_material
mock = ParamikoSSHClient(**conn_params)
- self.assertRaisesRegexp(paramiko.ssh_exception.SSHException,
- expected_msg, mock.connect)
+ self.assertRaisesRegexp(
+ paramiko.ssh_exception.SSHException, expected_msg, mock.connect
+ )
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_create_with_key(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'key_files': 'id_rsa'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "key_files": "id_rsa",
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- expected_conn = {'username': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'key_filename': 'id_rsa',
- 'timeout': 30,
- 'port': 22}
+ expected_conn = {
+ "username": "ubuntu",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "key_filename": "id_rsa",
+ "timeout": 30,
+ "port": 22,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_create_with_key_via_bastion(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'bastion_host': 'bastion.host.org',
- 'username': 'ubuntu',
- 'key_files': 'id_rsa'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "bastion_host": "bastion.host.org",
+ "username": "ubuntu",
+ "key_files": "id_rsa",
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- expected_bastion_conn = {'username': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'bastion.host.org',
- 'look_for_keys': False,
- 'key_filename': 'id_rsa',
- 'timeout': 30,
- 'port': 22}
+ expected_bastion_conn = {
+ "username": "ubuntu",
+ "allow_agent": False,
+ "hostname": "bastion.host.org",
+ "look_for_keys": False,
+ "key_filename": "id_rsa",
+ "timeout": 30,
+ "port": 22,
+ }
mock.bastion_client.connect.assert_called_once_with(**expected_bastion_conn)
- expected_conn = {'username': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'key_filename': 'id_rsa',
- 'timeout': 30,
- 'port': 22,
- 'sock': mock.bastion_socket}
+ expected_conn = {
+ "username": "ubuntu",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "key_filename": "id_rsa",
+ "timeout": 30,
+ "port": 22,
+ "sock": mock.bastion_socket,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_create_with_password_and_key(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'ubuntu',
- 'key_files': 'id_rsa'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "ubuntu",
+ "key_files": "id_rsa",
+ }
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- expected_conn = {'username': 'ubuntu',
- 'password': 'ubuntu',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'key_filename': 'id_rsa',
- 'timeout': 30,
- 'port': 22}
+ expected_conn = {
+ "username": "ubuntu",
+ "password": "ubuntu",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "key_filename": "id_rsa",
+ "timeout": 30,
+ "port": 22,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_create_without_credentials(self):
"""
Initialize object with no credentials.
@@ -394,44 +487,54 @@ def test_create_without_credentials(self):
the final parameters at the last moment when we explicitly
try to connect, all the credentials should be set to None.
"""
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu'}
+ conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"}
mock = ParamikoSSHClient(**conn_params)
self.assertEqual(mock.password, None)
self.assertEqual(mock.key_material, None)
self.assertEqual(mock.key_files, None)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_create_without_credentials_use_default_key(self):
# No credentials are provided, but the default stanley ssh key exists, so it should be used
- cfg.CONF.set_override(name='ssh_key_file', override='stanley_rsa', group='system_user')
+ cfg.CONF.set_override(
+ name="ssh_key_file", override="stanley_rsa", group="system_user"
+ )
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu'}
+ conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"}
mock = ParamikoSSHClient(**conn_params)
mock.connect()
- expected_conn = {'username': 'ubuntu',
- 'hostname': 'dummy.host.org',
- 'key_filename': 'stanley_rsa',
- 'allow_agent': False,
- 'look_for_keys': False,
- 'timeout': 30,
- 'port': 22}
+ expected_conn = {
+ "username": "ubuntu",
+ "hostname": "dummy.host.org",
+ "key_filename": "stanley_rsa",
+ "allow_agent": False,
+ "look_for_keys": False,
+ "timeout": 30,
+ "port": 22,
+ }
mock.client.connect.assert_called_once_with(**expected_conn)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_consume_stdout',
- MagicMock(return_value=StringIO('')))
- @patch.object(ParamikoSSHClient, '_consume_stderr',
- MagicMock(return_value=StringIO('')))
- @patch.object(os.path, 'exists', MagicMock(return_value=True))
- @patch.object(os, 'stat', MagicMock(return_value=None))
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient, "_consume_stdout", MagicMock(return_value=StringIO(""))
+ )
+ @patch.object(
+ ParamikoSSHClient, "_consume_stderr", MagicMock(return_value=StringIO(""))
+ )
+ @patch.object(os.path, "exists", MagicMock(return_value=True))
+ @patch.object(os, "stat", MagicMock(return_value=None))
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_basic_usage_absolute_path(self):
"""
Basic execution.
@@ -443,13 +546,15 @@ def test_basic_usage_absolute_path(self):
# Connect behavior
mock.connect()
mock_cli = mock.client # The actual mocked object: SSHClient
- expected_conn = {'username': 'ubuntu',
- 'key_filename': '~/.ssh/ubuntu_ssh',
- 'allow_agent': False,
- 'hostname': 'dummy.host.org',
- 'look_for_keys': False,
- 'timeout': 28,
- 'port': 8822}
+ expected_conn = {
+ "username": "ubuntu",
+ "key_filename": "~/.ssh/ubuntu_ssh",
+ "allow_agent": False,
+ "hostname": "dummy.host.org",
+ "look_for_keys": False,
+ "timeout": 28,
+ "port": 8822,
+ }
mock_cli.connect.assert_called_once_with(**expected_conn)
mock.put(sd, sd, mirror_local_mode=False)
@@ -458,21 +563,23 @@ def test_basic_usage_absolute_path(self):
mock.run(sd)
# Make assertions over 'run' method
- mock_cli.get_transport().open_session().exec_command \
- .assert_called_once_with(sd)
+ mock_cli.get_transport().open_session().exec_command.assert_called_once_with(sd)
mock.close()
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_delete_script(self):
"""
Provide a basic test with 'delete' action.
"""
mock = self.ssh_cli
# script to execute
- sd = '/root/random_script.sh'
+ sd = "/root/random_script.sh"
mock.connect()
@@ -482,91 +589,110 @@ def test_delete_script(self):
mock.close()
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
- @patch.object(ParamikoSSHClient, 'exists', return_value=False)
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
+ @patch.object(ParamikoSSHClient, "exists", return_value=False)
def test_put_dir(self, *args):
mock = self.ssh_cli
mock.connect()
- local_dir = os.path.join(get_resources_base_path(), 'packs')
- mock.put_dir(local_path=local_dir, remote_path='/tmp')
+ local_dir = os.path.join(get_resources_base_path(), "packs")
+ mock.put_dir(local_path=local_dir, remote_path="/tmp")
mock_cli = mock.client # The actual mocked object: SSHClient
# Assert that expected dirs are created on remote box.
- calls = [call('/tmp/packs/pythonactions'), call('/tmp/packs/pythonactions/actions')]
+ calls = [
+ call("/tmp/packs/pythonactions"),
+ call("/tmp/packs/pythonactions/actions"),
+ ]
mock_cli.open_sftp().mkdir.assert_has_calls(calls, any_order=True)
# Assert that expected files are copied to remote box.
- local_file = os.path.join(get_resources_base_path(),
- 'packs/pythonactions/actions/pascal_row.py')
- remote_file = os.path.join('/tmp', 'packs/pythonactions/actions/pascal_row.py')
+ local_file = os.path.join(
+ get_resources_base_path(), "packs/pythonactions/actions/pascal_row.py"
+ )
+ remote_file = os.path.join("/tmp", "packs/pythonactions/actions/pascal_row.py")
calls = [call(local_file, remote_file)]
mock_cli.open_sftp().put.assert_has_calls(calls, any_order=True)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_consume_stdout(self):
# Test that utf-8 decoding of ``stdout`` still works when a CHUNK_SIZE read splits a
# multi-byte utf-8 character in the middle. We should wait to collect all the bytes
# and only then decode.
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu'}
+ conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"}
mock = ParamikoSSHClient(**conn_params)
mock.CHUNK_SIZE = 1
chan = Mock()
chan.recv_ready.side_effect = [True, True, True, True, False]
- chan.recv.side_effect = [b'\xF0', b'\x90', b'\x8D', b'\x88']
+ chan.recv.side_effect = [b"\xF0", b"\x90", b"\x8D", b"\x88"]
try:
- b'\xF0'.decode('utf-8')
- self.fail('Test fixture is not right.')
+ b"\xF0".decode("utf-8")
+ self.fail("Test fixture is not right.")
except UnicodeDecodeError:
pass
stdout = mock._consume_stdout(chan)
- self.assertEqual(u'\U00010348', stdout.getvalue())
+ self.assertEqual("\U00010348", stdout.getvalue())
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_consume_stderr(self):
# Test that utf-8 decoding of ``stderr`` still works when a CHUNK_SIZE read splits a
# multi-byte utf-8 character in the middle. We should wait to collect all the bytes
# and only then decode.
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu'}
+ conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"}
mock = ParamikoSSHClient(**conn_params)
mock.CHUNK_SIZE = 1
chan = Mock()
chan.recv_stderr_ready.side_effect = [True, True, True, True, False]
- chan.recv_stderr.side_effect = [b'\xF0', b'\x90', b'\x8D', b'\x88']
+ chan.recv_stderr.side_effect = [b"\xF0", b"\x90", b"\x8D", b"\x88"]
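+ # the same 4-byte utf-8 sequence for U+10348, this time arriving on the stderr channel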
try:
- b'\xF0'.decode('utf-8')
- self.fail('Test fixture is not right.')
+ b"\xF0".decode("utf-8")
+ self.fail("Test fixture is not right.")
except UnicodeDecodeError:
pass
stderr = mock._consume_stderr(chan)
- self.assertEqual(u'\U00010348', stderr.getvalue())
-
- @patch('paramiko.SSHClient', Mock)
- @patch.object(ParamikoSSHClient, '_consume_stdout',
- MagicMock(return_value=StringIO('')))
- @patch.object(ParamikoSSHClient, '_consume_stderr',
- MagicMock(return_value=StringIO('')))
- @patch.object(os.path, 'exists', MagicMock(return_value=True))
- @patch.object(os, 'stat', MagicMock(return_value=None))
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ self.assertEqual("\U00010348", stderr.getvalue())
+
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(
+ ParamikoSSHClient, "_consume_stdout", MagicMock(return_value=StringIO(""))
+ )
+ @patch.object(
+ ParamikoSSHClient, "_consume_stderr", MagicMock(return_value=StringIO(""))
+ )
+ @patch.object(os.path, "exists", MagicMock(return_value=True))
+ @patch.object(os, "stat", MagicMock(return_value=None))
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_sftp_connection_is_only_established_if_required(self):
# Verify that the SFTP connection is established lazily, only if and when needed.
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu', 'password': 'ubuntu'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "ubuntu",
+ }
# Verify the sftp connection and client haven't been established yet
client = ParamikoSSHClient(**conn_params)
@@ -577,7 +703,7 @@ def test_sftp_connection_is_only_established_if_required(self):
# The run method doesn't require sftp access, so it shouldn't establish a connection
client = ParamikoSSHClient(**conn_params)
client.connect()
- client.run(cmd='whoami')
+ client.run(cmd="whoami")
self.assertIsNone(client.sftp_client)
@@ -585,7 +711,7 @@ def test_sftp_connection_is_only_established_if_required(self):
# put
client = ParamikoSSHClient(**conn_params)
client.connect()
- path = '/root/random_script.sh'
+ path = "/root/random_script.sh"
client.put(path, path, mirror_local_mode=False)
self.assertIsNotNone(client.sftp_client)
@@ -593,14 +719,14 @@ def test_sftp_connection_is_only_established_if_required(self):
# exists
client = ParamikoSSHClient(**conn_params)
client.connect()
- client.exists('/root/somepath.txt')
+ client.exists("/root/somepath.txt")
self.assertIsNotNone(client.sftp_client)
# mkdir
client = ParamikoSSHClient(**conn_params)
client.connect()
- client.mkdir('/root/somedirfoo')
+ client.mkdir("/root/somedirfoo")
self.assertIsNotNone(client.sftp_client)
@@ -614,26 +740,26 @@ def test_sftp_connection_is_only_established_if_required(self):
# Verify SFTP connection is closed if it's opened
client = ParamikoSSHClient(**conn_params)
client.connect()
- client.mkdir('/root/somedirfoo')
+ client.mkdir("/root/somedirfoo")
self.assertIsNotNone(client.sftp_client)
client.close()
self.assertEqual(client.sftp_client.close.call_count, 1)
- @patch('paramiko.SSHClient', Mock)
- @patch.object(os.path, 'exists', MagicMock(return_value=True))
- @patch.object(os, 'stat', MagicMock(return_value=None))
+ @patch("paramiko.SSHClient", Mock)
+ @patch.object(os.path, "exists", MagicMock(return_value=True))
+ @patch.object(os, "stat", MagicMock(return_value=None))
def test_handle_stdout_and_stderr_line_funcs(self):
mock_handle_stdout_line_func = mock.Mock()
mock_handle_stderr_line_func = mock.Mock()
conn_params = {
- 'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'ubuntu',
- 'handle_stdout_line_func': mock_handle_stdout_line_func,
- 'handle_stderr_line_func': mock_handle_stderr_line_func
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "ubuntu",
+ "handle_stdout_line_func": mock_handle_stdout_line_func,
+ "handle_stderr_line_func": mock_handle_stderr_line_func,
}
client = ParamikoSSHClient(**conn_params)
client.connect()
@@ -654,6 +780,7 @@ def mock_recv_ready():
return True
return False
+
return mock_recv_ready
def mock_recv_stderr_ready_factory(chan):
@@ -665,12 +792,13 @@ def mock_recv_stderr_ready():
return True
return False
+
return mock_recv_stderr_ready
mock_chan.recv_ready = mock_recv_ready_factory(mock_chan)
mock_chan.recv_stderr_ready = mock_recv_stderr_ready_factory(mock_chan)
- mock_chan.recv.return_value = 'stdout 1\nstdout 2\nstdout 3'
- mock_chan.recv_stderr.return_value = 'stderr 1\nstderr 2\nstderr 3'
+ mock_chan.recv.return_value = "stdout 1\nstdout 2\nstdout 3"
+ mock_chan.recv_stderr.return_value = "stderr 1\nstderr 2\nstderr 3"
# call_line_handler_func is False so handler functions shouldn't be called
client.run(cmd='echo "test"', call_line_handler_func=False)
@@ -686,132 +814,176 @@ def mock_recv_stderr_ready():
client.run(cmd='echo "test"', call_line_handler_func=True)
self.assertEqual(mock_handle_stdout_line_func.call_count, 3)
- self.assertEqual(mock_handle_stdout_line_func.call_args_list[0][1]['line'], 'stdout 1\n')
- self.assertEqual(mock_handle_stdout_line_func.call_args_list[1][1]['line'], 'stdout 2\n')
- self.assertEqual(mock_handle_stdout_line_func.call_args_list[2][1]['line'], 'stdout 3\n')
+ self.assertEqual(
+ mock_handle_stdout_line_func.call_args_list[0][1]["line"], "stdout 1\n"
+ )
+ self.assertEqual(
+ mock_handle_stdout_line_func.call_args_list[1][1]["line"], "stdout 2\n"
+ )
+ self.assertEqual(
+ mock_handle_stdout_line_func.call_args_list[2][1]["line"], "stdout 3\n"
+ )
self.assertEqual(mock_handle_stderr_line_func.call_count, 3)
- self.assertEqual(mock_handle_stdout_line_func.call_args_list[0][1]['line'], 'stdout 1\n')
- self.assertEqual(mock_handle_stdout_line_func.call_args_list[1][1]['line'], 'stdout 2\n')
- self.assertEqual(mock_handle_stdout_line_func.call_args_list[2][1]['line'], 'stdout 3\n')
-
- @patch('paramiko.SSHClient')
+ self.assertEqual(
+ mock_handle_stderr_line_func.call_args_list[0][1]["line"], "stderr 1\n"
+ )
+ self.assertEqual(
+ mock_handle_stderr_line_func.call_args_list[1][1]["line"], "stderr 2\n"
+ )
+ self.assertEqual(
+ mock_handle_stderr_line_func.call_args_list[2][1]["line"], "stderr 3\n"
+ )
+
+ @patch("paramiko.SSHClient")
def test_use_ssh_config_port_value_provided_in_the_config(self, mock_sshclient):
- cfg.CONF.set_override(name='use_ssh_config', override=True, group='ssh_runner')
+ cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner")
- ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh', 'empty_config')
- cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path,
- group='ssh_runner')
+ ssh_config_file_path = os.path.join(
+ get_resources_base_path(), "ssh", "empty_config"
+ )
+ cfg.CONF.set_override(
+ name="ssh_config_file_path",
+ override=ssh_config_file_path,
+ group="ssh_runner",
+ )
# 1. Default port is used (not explicitly provided)
mock_client = mock.Mock()
mock_sshclient.return_value = mock_client
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass',
- 'timeout': '600'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ "timeout": "600",
+ }
ssh_client = ParamikoSSHClient(**conn_params)
ssh_client.connect()
call_kwargs = mock_client.connect.call_args[1]
- self.assertEqual(call_kwargs['port'], 22)
+ self.assertEqual(call_kwargs["port"], 22)
mock_client = mock.Mock()
mock_sshclient.return_value = mock_client
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass',
- 'port': None,
- 'timeout': '600'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ "port": None,
+ "timeout": "600",
+ }
ssh_client = ParamikoSSHClient(**conn_params)
ssh_client.connect()
call_kwargs = mock_client.connect.call_args[1]
- self.assertEqual(call_kwargs['port'], 22)
+ self.assertEqual(call_kwargs["port"], 22)
# 2. Default port is used (explicitly provided)
mock_client = mock.Mock()
mock_sshclient.return_value = mock_client
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass',
- 'port': DEFAULT_SSH_PORT,
- 'timeout': '600'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ "port": DEFAULT_SSH_PORT,
+ "timeout": "600",
+ }
ssh_client = ParamikoSSHClient(**conn_params)
ssh_client.connect()
call_kwargs = mock_client.connect.call_args[1]
- self.assertEqual(call_kwargs['port'], DEFAULT_SSH_PORT)
- self.assertEqual(call_kwargs['port'], 22)
+ self.assertEqual(call_kwargs["port"], DEFAULT_SSH_PORT)
+ self.assertEqual(call_kwargs["port"], 22)
# 3. Custom port is used (explicitly provided)
mock_client = mock.Mock()
mock_sshclient.return_value = mock_client
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass',
- 'port': 5555,
- 'timeout': '600'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ "port": 5555,
+ "timeout": "600",
+ }
ssh_client = ParamikoSSHClient(**conn_params)
ssh_client.connect()
call_kwargs = mock_client.connect.call_args[1]
- self.assertEqual(call_kwargs['port'], 5555)
+ self.assertEqual(call_kwargs["port"], 5555)
# 4. Custom port is specified in the ssh config (it has precedence over default port)
- ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh',
- 'ssh_config_custom_port')
- cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path,
- group='ssh_runner')
+ ssh_config_file_path = os.path.join(
+ get_resources_base_path(), "ssh", "ssh_config_custom_port"
+ )
+ cfg.CONF.set_override(
+ name="ssh_config_file_path",
+ override=ssh_config_file_path,
+ group="ssh_runner",
+ )
mock_client = mock.Mock()
mock_sshclient.return_value = mock_client
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ }
ssh_client = ParamikoSSHClient(**conn_params)
ssh_client.connect()
call_kwargs = mock_client.connect.call_args[1]
- self.assertEqual(call_kwargs['port'], 6677)
+ self.assertEqual(call_kwargs["port"], 6677)
mock_client = mock.Mock()
mock_sshclient.return_value = mock_client
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass',
- 'port': DEFAULT_SSH_PORT}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ "port": DEFAULT_SSH_PORT,
+ }
ssh_client = ParamikoSSHClient(**conn_params)
ssh_client.connect()
call_kwargs = mock_client.connect.call_args[1]
- self.assertEqual(call_kwargs['port'], 6677)
+ self.assertEqual(call_kwargs["port"], 6677)
# 5. Custom port is specified in the ssh config, but one is also provided via runner parameter
# (the runner parameter takes precedence)
- ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh',
- 'ssh_config_custom_port')
- cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path,
- group='ssh_runner')
+ ssh_config_file_path = os.path.join(
+ get_resources_base_path(), "ssh", "ssh_config_custom_port"
+ )
+ cfg.CONF.set_override(
+ name="ssh_config_file_path",
+ override=ssh_config_file_path,
+ group="ssh_runner",
+ )
mock_client = mock.Mock()
mock_sshclient.return_value = mock_client
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass',
- 'port': 9999}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ "port": 9999,
+ }
ssh_client = ParamikoSSHClient(**conn_params)
ssh_client.connect()
call_kwargs = mock_client.connect.call_args[1]
- self.assertEqual(call_kwargs['port'], 9999)
+ self.assertEqual(call_kwargs["port"], 9999)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_socket_closed(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass',
- 'timeout': '600'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ "timeout": "600",
+ }
ssh_client = ParamikoSSHClient(**conn_params)
# Make sure .close() doesn't actually call anything real
@@ -840,13 +1012,18 @@ def test_socket_closed(self):
self.assertEqual(ssh_client.bastion_socket.close.call_count, 1)
self.assertEqual(ssh_client.bastion_client.close.call_count, 1)
- @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase',
- MagicMock(return_value=False))
+ @patch.object(
+ ParamikoSSHClient,
+ "_is_key_file_needs_passphrase",
+ MagicMock(return_value=False),
+ )
def test_socket_not_closed_if_none(self):
- conn_params = {'hostname': 'dummy.host.org',
- 'username': 'ubuntu',
- 'password': 'pass',
- 'timeout': '600'}
+ conn_params = {
+ "hostname": "dummy.host.org",
+ "username": "ubuntu",
+ "password": "pass",
+ "timeout": "600",
+ }
ssh_client = ParamikoSSHClient(**conn_params)
# Make sure .close() doesn't actually call anything real
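
The numbered cases in the hunks above pin down the SSH port precedence these client tests rely on. A minimal sketch of that resolution order, assuming a hypothetical resolve_port helper (the real logic lives in ParamikoSSHClient.connect()):

DEFAULT_SSH_PORT = 22


def resolve_port(runner_port=None, ssh_config_port=None):
    # A non-default runner parameter wins (cases 3 and 5).
    if runner_port is not None and runner_port != DEFAULT_SSH_PORT:
        return runner_port
    # Otherwise a Port entry from the ssh config file applies (case 4),
    # even when the default port was passed explicitly.
    if ssh_config_port is not None:
        return ssh_config_port
    # Fall back to the default port (cases 1 and 2).
    return DEFAULT_SSH_PORT


assert resolve_port() == 22
assert resolve_port(runner_port=5555) == 5555
assert resolve_port(ssh_config_port=6677) == 6677
assert resolve_port(runner_port=DEFAULT_SSH_PORT, ssh_config_port=6677) == 6677
assert resolve_port(runner_port=9999, ssh_config_port=6677) == 9999
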
diff --git a/st2actions/tests/unit/test_paramiko_ssh_runner.py b/st2actions/tests/unit/test_paramiko_ssh_runner.py
index 8467264c5b..f42746d602 100644
--- a/st2actions/tests/unit/test_paramiko_ssh_runner.py
+++ b/st2actions/tests/unit/test_paramiko_ssh_runner.py
@@ -29,6 +29,7 @@
import st2tests.config as tests_config
from st2tests.fixturesloader import get_resources_base_path
+
tests_config.parse_args()
@@ -38,195 +39,192 @@ def run(self):
class ParamikoSSHRunnerTestCase(unittest2.TestCase):
- @mock.patch('st2common.runners.paramiko_ssh_runner.ParallelSSHClient')
+ @mock.patch("st2common.runners.paramiko_ssh_runner.ParallelSSHClient")
def test_pre_run(self, mock_client):
# Test case which verifies that ParallelSSHClient is instantiated with the correct arguments
- private_key_path = os.path.join(get_resources_base_path(), 'ssh', 'dummy_rsa')
+ private_key_path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa")
- with open(private_key_path, 'r') as fp:
+ with open(private_key_path, "r") as fp:
private_key = fp.read()
# Username and password provided
- runner = Runner('id')
+ runner = Runner("id")
runner.context = {}
runner_parameters = {
- RUNNER_HOSTS: 'localhost',
- RUNNER_USERNAME: 'someuser1',
- RUNNER_PASSWORD: 'somepassword'
+ RUNNER_HOSTS: "localhost",
+ RUNNER_USERNAME: "someuser1",
+ RUNNER_PASSWORD: "somepassword",
}
runner.runner_parameters = runner_parameters
runner.pre_run()
expected_kwargs = {
- 'hosts': ['localhost'],
- 'user': 'someuser1',
- 'password': 'somepassword',
- 'port': None,
- 'concurrency': 1,
- 'bastion_host': None,
- 'raise_on_any_error': False,
- 'connect': True,
- 'handle_stdout_line_func': mock.ANY,
- 'handle_stderr_line_func': mock.ANY
+ "hosts": ["localhost"],
+ "user": "someuser1",
+ "password": "somepassword",
+ "port": None,
+ "concurrency": 1,
+ "bastion_host": None,
+ "raise_on_any_error": False,
+ "connect": True,
+ "handle_stdout_line_func": mock.ANY,
+ "handle_stderr_line_func": mock.ANY,
}
mock_client.assert_called_with(**expected_kwargs)
# Private key provided as raw key material
- runner = Runner('id')
+ runner = Runner("id")
runner.context = {}
runner_parameters = {
- RUNNER_HOSTS: 'localhost',
- RUNNER_USERNAME: 'someuser2',
+ RUNNER_HOSTS: "localhost",
+ RUNNER_USERNAME: "someuser2",
RUNNER_PRIVATE_KEY: private_key,
- RUNNER_SSH_PORT: 22
+ RUNNER_SSH_PORT: 22,
}
runner.runner_parameters = runner_parameters
runner.pre_run()
expected_kwargs = {
- 'hosts': ['localhost'],
- 'user': 'someuser2',
- 'pkey_material': private_key,
- 'port': 22,
- 'concurrency': 1,
- 'bastion_host': None,
- 'raise_on_any_error': False,
- 'connect': True,
- 'handle_stdout_line_func': mock.ANY,
- 'handle_stderr_line_func': mock.ANY
+ "hosts": ["localhost"],
+ "user": "someuser2",
+ "pkey_material": private_key,
+ "port": 22,
+ "concurrency": 1,
+ "bastion_host": None,
+ "raise_on_any_error": False,
+ "connect": True,
+ "handle_stdout_line_func": mock.ANY,
+ "handle_stderr_line_func": mock.ANY,
}
mock_client.assert_called_with(**expected_kwargs)
# Private key provided as raw key material + passphrase
- runner = Runner('id')
+ runner = Runner("id")
runner.context = {}
runner_parameters = {
- RUNNER_HOSTS: 'localhost21',
- RUNNER_USERNAME: 'someuser21',
+ RUNNER_HOSTS: "localhost21",
+ RUNNER_USERNAME: "someuser21",
RUNNER_PRIVATE_KEY: private_key,
- RUNNER_PASSPHRASE: 'passphrase21',
- RUNNER_SSH_PORT: 22
+ RUNNER_PASSPHRASE: "passphrase21",
+ RUNNER_SSH_PORT: 22,
}
runner.runner_parameters = runner_parameters
runner.pre_run()
expected_kwargs = {
- 'hosts': ['localhost21'],
- 'user': 'someuser21',
- 'pkey_material': private_key,
- 'passphrase': 'passphrase21',
- 'port': 22,
- 'concurrency': 1,
- 'bastion_host': None,
- 'raise_on_any_error': False,
- 'connect': True,
- 'handle_stdout_line_func': mock.ANY,
- 'handle_stderr_line_func': mock.ANY
+ "hosts": ["localhost21"],
+ "user": "someuser21",
+ "pkey_material": private_key,
+ "passphrase": "passphrase21",
+ "port": 22,
+ "concurrency": 1,
+ "bastion_host": None,
+ "raise_on_any_error": False,
+ "connect": True,
+ "handle_stdout_line_func": mock.ANY,
+ "handle_stderr_line_func": mock.ANY,
}
mock_client.assert_called_with(**expected_kwargs)
# Private key provided as path to the private key file
- runner = Runner('id')
+ runner = Runner("id")
runner.context = {}
runner_parameters = {
- RUNNER_HOSTS: 'localhost',
- RUNNER_USERNAME: 'someuser3',
+ RUNNER_HOSTS: "localhost",
+ RUNNER_USERNAME: "someuser3",
RUNNER_PRIVATE_KEY: private_key_path,
- RUNNER_SSH_PORT: 22
+ RUNNER_SSH_PORT: 22,
}
runner.runner_parameters = runner_parameters
runner.pre_run()
expected_kwargs = {
- 'hosts': ['localhost'],
- 'user': 'someuser3',
- 'pkey_file': private_key_path,
- 'port': 22,
- 'concurrency': 1,
- 'bastion_host': None,
- 'raise_on_any_error': False,
- 'connect': True,
- 'handle_stdout_line_func': mock.ANY,
- 'handle_stderr_line_func': mock.ANY
+ "hosts": ["localhost"],
+ "user": "someuser3",
+ "pkey_file": private_key_path,
+ "port": 22,
+ "concurrency": 1,
+ "bastion_host": None,
+ "raise_on_any_error": False,
+ "connect": True,
+ "handle_stdout_line_func": mock.ANY,
+ "handle_stderr_line_func": mock.ANY,
}
mock_client.assert_called_with(**expected_kwargs)
# Private key provided as path to the private key file + passphrase
- runner = Runner('id')
+ runner = Runner("id")
runner.context = {}
runner_parameters = {
- RUNNER_HOSTS: 'localhost31',
- RUNNER_USERNAME: 'someuser31',
+ RUNNER_HOSTS: "localhost31",
+ RUNNER_USERNAME: "someuser31",
RUNNER_PRIVATE_KEY: private_key_path,
- RUNNER_PASSPHRASE: 'passphrase31',
- RUNNER_SSH_PORT: 22
+ RUNNER_PASSPHRASE: "passphrase31",
+ RUNNER_SSH_PORT: 22,
}
runner.runner_parameters = runner_parameters
runner.pre_run()
expected_kwargs = {
- 'hosts': ['localhost31'],
- 'user': 'someuser31',
- 'pkey_file': private_key_path,
- 'passphrase': 'passphrase31',
- 'port': 22,
- 'concurrency': 1,
- 'bastion_host': None,
- 'raise_on_any_error': False,
- 'connect': True,
- 'handle_stdout_line_func': mock.ANY,
- 'handle_stderr_line_func': mock.ANY
+ "hosts": ["localhost31"],
+ "user": "someuser31",
+ "pkey_file": private_key_path,
+ "passphrase": "passphrase31",
+ "port": 22,
+ "concurrency": 1,
+ "bastion_host": None,
+ "raise_on_any_error": False,
+ "connect": True,
+ "handle_stdout_line_func": mock.ANY,
+ "handle_stderr_line_func": mock.ANY,
}
mock_client.assert_called_with(**expected_kwargs)
# No password or private key provided; should default to the system user's private key
- runner = Runner('id')
+ runner = Runner("id")
runner.context = {}
- runner_parameters = {
- RUNNER_HOSTS: 'localhost4',
- RUNNER_SSH_PORT: 22
- }
+ runner_parameters = {RUNNER_HOSTS: "localhost4", RUNNER_SSH_PORT: 22}
runner.runner_parameters = runner_parameters
runner.pre_run()
expected_kwargs = {
- 'hosts': ['localhost4'],
- 'user': None,
- 'pkey_file': None,
- 'port': 22,
- 'concurrency': 1,
- 'bastion_host': None,
- 'raise_on_any_error': False,
- 'connect': True,
- 'handle_stdout_line_func': mock.ANY,
- 'handle_stderr_line_func': mock.ANY
+ "hosts": ["localhost4"],
+ "user": None,
+ "pkey_file": None,
+ "port": 22,
+ "concurrency": 1,
+ "bastion_host": None,
+ "raise_on_any_error": False,
+ "connect": True,
+ "handle_stdout_line_func": mock.ANY,
+ "handle_stderr_line_func": mock.ANY,
}
mock_client.assert_called_with(**expected_kwargs)
- @mock.patch('st2common.runners.paramiko_ssh_runner.ParallelSSHClient')
+ @mock.patch("st2common.runners.paramiko_ssh_runner.ParallelSSHClient")
def test_post_run(self, mock_client):
# Verify that the SSH connections are closed on post_run
- runner = Runner('id')
+ runner = Runner("id")
runner.context = {}
runner_parameters = {
- RUNNER_HOSTS: 'localhost',
- RUNNER_USERNAME: 'someuser1',
- RUNNER_PASSWORD: 'somepassword'
+ RUNNER_HOSTS: "localhost",
+ RUNNER_USERNAME: "someuser1",
+ RUNNER_PASSWORD: "somepassword",
}
runner.runner_parameters = runner_parameters
runner.pre_run()
expected_kwargs = {
- 'hosts': ['localhost'],
- 'user': 'someuser1',
- 'password': 'somepassword',
- 'port': None,
- 'concurrency': 1,
- 'bastion_host': None,
- 'raise_on_any_error': False,
- 'connect': True,
- 'handle_stdout_line_func': mock.ANY,
- 'handle_stderr_line_func': mock.ANY
+ "hosts": ["localhost"],
+ "user": "someuser1",
+ "password": "somepassword",
+ "port": None,
+ "concurrency": 1,
+ "bastion_host": None,
+ "raise_on_any_error": False,
+ "connect": True,
+ "handle_stdout_line_func": mock.ANY,
+ "handle_stderr_line_func": mock.ANY,
}
mock_client.assert_called_with(**expected_kwargs)
self.assertEqual(runner._parallel_ssh_client.close.call_count, 0)
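
The expected_kwargs blocks above encode how runner parameters are mapped onto ParallelSSHClient arguments. A sketch of that credential dispatch, assuming a hypothetical credential_kwargs helper (the real mapping happens in the runner's pre_run(), and its raw-material-vs-path heuristic may differ):

import os


def credential_kwargs(password=None, private_key=None, passphrase=None):
    kwargs = {}
    if password is not None:
        # Password auth: no key-related arguments are passed (first case).
        kwargs["password"] = password
    elif private_key is not None:
        # A value naming an existing file is passed as a key file path,
        # anything else as raw key material.
        if os.path.exists(private_key):
            kwargs["pkey_file"] = private_key
        else:
            kwargs["pkey_material"] = private_key
        if passphrase is not None:
            kwargs["passphrase"] = passphrase
    else:
        # No credentials at all: fall back to the system user's default key.
        kwargs["pkey_file"] = None
    return kwargs


assert credential_kwargs(password="somepassword") == {"password": "somepassword"}
assert credential_kwargs(private_key="raw key material") == {
    "pkey_material": "raw key material"
}
assert credential_kwargs() == {"pkey_file": None}
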
diff --git a/st2actions/tests/unit/test_policies.py b/st2actions/tests/unit/test_policies.py
index 4be7af59e5..f16ffbcb0b 100644
--- a/st2actions/tests/unit/test_policies.py
+++ b/st2actions/tests/unit/test_policies.py
@@ -37,37 +37,34 @@
TEST_FIXTURES = {
- 'actions': [
- 'action1.yaml'
- ],
- 'policytypes': [
- 'fake_policy_type_1.yaml',
- 'fake_policy_type_2.yaml'
- ],
- 'policies': [
- 'policy_1.yaml',
- 'policy_2.yaml'
- ]
+ "actions": ["action1.yaml"],
+ "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"],
+ "policies": ["policy_1.yaml", "policy_2.yaml"],
}
-PACK = 'generic'
+PACK = "generic"
LOADER = FixturesLoader()
FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
@mock.patch.object(
- CUDPublisher, 'publish_update',
- mock.MagicMock(side_effect=MockExecutionPublisher.publish_update))
+ CUDPublisher,
+ "publish_update",
+ mock.MagicMock(side_effect=MockExecutionPublisher.publish_update),
+)
+@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None))
@mock.patch.object(
- CUDPublisher, 'publish_create',
- mock.MagicMock(return_value=None))
-@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state))
-@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner()))
-@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state),
+)
+@mock.patch(
+ "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
+@mock.patch(
+ "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
class SchedulingPolicyTest(ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(SchedulingPolicyTest, cls).setUpClass()
@@ -75,15 +72,15 @@ def setUpClass(cls):
# Register runners
runners_registrar.register_runners()
- for _, fixture in six.iteritems(FIXTURES['actions']):
+ for _, fixture in six.iteritems(FIXTURES["actions"]):
instance = ActionAPI(**fixture)
Action.add_or_update(ActionAPI.to_model(instance))
- for _, fixture in six.iteritems(FIXTURES['policytypes']):
+ for _, fixture in six.iteritems(FIXTURES["policytypes"]):
instance = PolicyTypeAPI(**fixture)
PolicyType.add_or_update(PolicyTypeAPI.to_model(instance))
- for _, fixture in six.iteritems(FIXTURES['policies']):
+ for _, fixture in six.iteritems(FIXTURES["policies"]):
instance = PolicyAPI(**fixture)
Policy.add_or_update(PolicyAPI.to_model(instance))
@@ -91,35 +88,54 @@ def tearDown(self):
# Ensure all liveactions are canceled at end of each test.
for liveaction in LiveAction.get_all():
action_service.update_status(
- liveaction, action_constants.LIVEACTION_STATUS_CANCELED)
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
+ )
@mock.patch.object(
- FakeConcurrencyApplicator, 'apply_before',
+ FakeConcurrencyApplicator,
+ "apply_before",
mock.MagicMock(
- side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_before))
+ side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_before
+ ),
+ )
@mock.patch.object(
- RaiseExceptionApplicator, 'apply_before',
- mock.MagicMock(
- side_effect=RaiseExceptionApplicator(None, None).apply_before))
+ RaiseExceptionApplicator,
+ "apply_before",
+ mock.MagicMock(side_effect=RaiseExceptionApplicator(None, None).apply_before),
+ )
@mock.patch.object(
- FakeConcurrencyApplicator, 'apply_after',
+ FakeConcurrencyApplicator,
+ "apply_after",
mock.MagicMock(
- side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_after))
+ side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_after
+ ),
+ )
@mock.patch.object(
- RaiseExceptionApplicator, 'apply_after',
- mock.MagicMock(
- side_effect=RaiseExceptionApplicator(None, None).apply_after))
+ RaiseExceptionApplicator,
+ "apply_after",
+ mock.MagicMock(side_effect=RaiseExceptionApplicator(None, None).apply_after),
+ )
def test_apply(self):
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED
+ )
FakeConcurrencyApplicator.apply_before.assert_called_once_with(liveaction)
RaiseExceptionApplicator.apply_before.assert_called_once_with(liveaction)
FakeConcurrencyApplicator.apply_after.assert_called_once_with(liveaction)
RaiseExceptionApplicator.apply_after.assert_called_once_with(liveaction)
- @mock.patch.object(FakeConcurrencyApplicator, 'get_threshold', mock.MagicMock(return_value=0))
+ @mock.patch.object(
+ FakeConcurrencyApplicator, "get_threshold", mock.MagicMock(return_value=0)
+ )
def test_enforce(self):
- liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'})
+ liveaction = LiveActionDB(
+ action="wolfpack.action-1", parameters={"actionstr": "foo"}
+ )
liveaction, _ = action_service.request(liveaction)
- liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED)
+ liveaction = self._wait_on_status(
+ liveaction, action_constants.LIVEACTION_STATUS_CANCELED
+ )
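
The applicator mocks above only need the two hooks the test calls. A hypothetical minimal shape of that interface, inferred from the assertions rather than from the policy framework itself:

class SketchApplicator(object):
    # Both hooks receive the liveaction and hand it back, possibly mutated
    # (e.g. delayed or canceled by a concurrency policy).
    def apply_before(self, target):
        return target

    def apply_after(self, target):
        return target
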
diff --git a/st2actions/tests/unit/test_polling_async_runner.py b/st2actions/tests/unit/test_polling_async_runner.py
index 435f7eb9b6..c48bb9aa67 100644
--- a/st2actions/tests/unit/test_polling_async_runner.py
+++ b/st2actions/tests/unit/test_polling_async_runner.py
@@ -14,15 +14,16 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except ImportError:
import json
from st2common.runners.base import PollingAsyncActionRunner
-from st2common.constants.action import (LIVEACTION_STATUS_RUNNING)
+from st2common.constants.action import LIVEACTION_STATUS_RUNNING
-RAISE_PROPERTY = 'raise'
+RAISE_PROPERTY = "raise"
def get_runner():
@@ -31,7 +32,7 @@ def get_runner():
class PollingAsyncTestRunner(PollingAsyncActionRunner):
def __init__(self):
- super(PollingAsyncTestRunner, self).__init__(runner_id='1')
+ super(PollingAsyncTestRunner, self).__init__(runner_id="1")
self.pre_run_called = False
self.run_called = False
self.post_run_called = False
@@ -43,14 +44,11 @@ def run(self, action_params):
self.run_called = True
result = {}
if self.runner_parameters.get(RAISE_PROPERTY, False):
- raise Exception('Raise required.')
+ raise Exception("Raise required.")
else:
- result = {
- 'ran': True,
- 'action_params': action_params
- }
+ result = {"ran": True, "action_params": action_params}
- return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'})
+ return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"})
def post_run(self, status, result):
self.post_run_called = True
diff --git a/st2actions/tests/unit/test_queue_consumers.py b/st2actions/tests/unit/test_queue_consumers.py
index 1550a82e22..80d3a09c26 100644
--- a/st2actions/tests/unit/test_queue_consumers.py
+++ b/st2actions/tests/unit/test_queue_consumers.py
@@ -18,6 +18,7 @@
import st2tests
import st2tests.config as tests_config
+
tests_config.parse_args()
import mock
@@ -39,16 +40,13 @@
from st2tests.base import ExecutionDbTestCase
-PACKS = [
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
-]
+PACKS = [st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core"]
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
-@mock.patch.object(executions, 'update_execution', mock.MagicMock())
-@mock.patch.object(Message, 'ack', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
+@mock.patch.object(executions, "update_execution", mock.MagicMock())
+@mock.patch.object(Message, "ack", mock.MagicMock())
class QueueConsumerTest(ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
super(QueueConsumerTest, cls).setUpClass()
@@ -58,8 +56,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -71,14 +68,16 @@ def __init__(self, *args, **kwargs):
self.scheduling_queue = scheduling_queue.get_handler()
self.dispatcher = worker.get_worker()
- def _create_liveaction_db(self, status=action_constants.LIVEACTION_STATUS_REQUESTED):
- action_db = action_utils.get_action_by_ref('core.noop')
+ def _create_liveaction_db(
+ self, status=action_constants.LIVEACTION_STATUS_REQUESTED
+ ):
+ action_db = action_utils.get_action_by_ref("core.noop")
liveaction_db = LiveActionDB(
action=action_db.ref,
parameters=None,
start_timestamp=date_utils.get_datetime_utc_now(),
- status=status
+ status=status,
)
liveaction_db = action.LiveAction.add_or_update(liveaction_db, publish=False)
@@ -91,15 +90,16 @@ def _process_request(self, liveaction_db):
queued_request = self.scheduling_queue._get_next_execution()
self.scheduling_queue._handle_execution(queued_request)
- @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value={'key': 'value'}))
+ @mock.patch.object(
+ RunnerContainer, "dispatch", mock.MagicMock(return_value={"key": "value"})
+ )
def test_execute(self):
liveaction_db = self._create_liveaction_db()
self._process_request(liveaction_db)
scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
scheduled_liveaction_db = self._wait_on_status(
- scheduled_liveaction_db,
- action_constants.LIVEACTION_STATUS_SCHEDULED
+ scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED
)
self.assertDictEqual(scheduled_liveaction_db.runner_info, {})
@@ -107,54 +107,56 @@ def test_execute(self):
dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
self.assertGreater(len(list(dispatched_liveaction_db.runner_info.keys())), 0)
self.assertEqual(
- dispatched_liveaction_db.status,
- action_constants.LIVEACTION_STATUS_RUNNING
+ dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_RUNNING
)
- @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(side_effect=Exception('Boom!')))
+ @mock.patch.object(
+ RunnerContainer, "dispatch", mock.MagicMock(side_effect=Exception("Boom!"))
+ )
def test_execute_failure(self):
liveaction_db = self._create_liveaction_db()
self._process_request(liveaction_db)
scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
scheduled_liveaction_db = self._wait_on_status(
- scheduled_liveaction_db,
- action_constants.LIVEACTION_STATUS_SCHEDULED
+ scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED
)
self.dispatcher._queue_consumer._process_message(scheduled_liveaction_db)
dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
- self.assertEqual(dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED)
+ self.assertEqual(
+ dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED
+ )
- @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value=None))
+ @mock.patch.object(RunnerContainer, "dispatch", mock.MagicMock(return_value=None))
def test_execute_no_result(self):
liveaction_db = self._create_liveaction_db()
self._process_request(liveaction_db)
scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
scheduled_liveaction_db = self._wait_on_status(
- scheduled_liveaction_db,
- action_constants.LIVEACTION_STATUS_SCHEDULED
+ scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED
)
self.dispatcher._queue_consumer._process_message(scheduled_liveaction_db)
dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
- self.assertEqual(dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED)
+ self.assertEqual(
+ dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED
+ )
- @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value=None))
+ @mock.patch.object(RunnerContainer, "dispatch", mock.MagicMock(return_value=None))
def test_execute_cancelation(self):
liveaction_db = self._create_liveaction_db()
self._process_request(liveaction_db)
scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
scheduled_liveaction_db = self._wait_on_status(
- scheduled_liveaction_db,
- action_constants.LIVEACTION_STATUS_SCHEDULED
+ scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED
)
action_utils.update_liveaction_status(
status=action_constants.LIVEACTION_STATUS_CANCELED,
- liveaction_id=liveaction_db.id
+ liveaction_id=liveaction_db.id,
)
canceled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
@@ -162,11 +164,10 @@ def test_execute_cancelation(self):
dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id)
self.assertEqual(
- dispatched_liveaction_db.status,
- action_constants.LIVEACTION_STATUS_CANCELED
+ dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_CANCELED
)
self.assertDictEqual(
dispatched_liveaction_db.result,
- {'message': 'Action execution canceled by user.'}
+ {"message": "Action execution canceled by user."},
)
diff --git a/st2actions/tests/unit/test_remote_runners.py b/st2actions/tests/unit/test_remote_runners.py
index 26d75cb5dc..7f84165dbb 100644
--- a/st2actions/tests/unit/test_remote_runners.py
+++ b/st2actions/tests/unit/test_remote_runners.py
@@ -16,6 +16,7 @@
# XXX: FabricRunner import depends on config being set up.
from __future__ import absolute_import
import st2tests.config as tests_config
+
tests_config.parse_args()
from unittest2 import TestCase
@@ -26,12 +27,20 @@
class RemoteScriptActionTestCase(TestCase):
def test_parameter_formatting(self):
# Only named args
- named_args = {'--foo1': 'bar1', '--foo2': 'bar2', '--foo3': True,
- '--foo4': False}
+ named_args = {
+ "--foo1": "bar1",
+ "--foo2": "bar2",
+ "--foo3": True,
+ "--foo4": False,
+ }
- action = RemoteScriptAction(name='foo', action_exec_id='dummy',
- script_local_path_abs='test.py',
- script_local_libs_path_abs='/',
- remote_dir='/tmp',
- named_args=named_args, positional_args=None)
- self.assertEqual(action.command, '/tmp/test.py --foo1=bar1 --foo2=bar2 --foo3')
+ action = RemoteScriptAction(
+ name="foo",
+ action_exec_id="dummy",
+ script_local_path_abs="test.py",
+ script_local_libs_path_abs="/",
+ remote_dir="/tmp",
+ named_args=named_args,
+ positional_args=None,
+ )
+ self.assertEqual(action.command, "/tmp/test.py --foo1=bar1 --foo2=bar2 --foo3")
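
The single assertion above captures the whole named-argument formatting contract for remote script actions. A sketch of that rule, assuming a hypothetical format_named_args helper:

def format_named_args(named_args):
    # String values render as --key=value, True renders the bare flag and
    # False drops the flag entirely (sorted here for a deterministic sketch).
    parts = []
    for key, value in sorted(named_args.items()):
        if value is True:
            parts.append(key)
        elif value is False:
            continue
        else:
            parts.append("%s=%s" % (key, value))
    return " ".join(parts)


assert (
    format_named_args(
        {"--foo1": "bar1", "--foo2": "bar2", "--foo3": True, "--foo4": False}
    )
    == "--foo1=bar1 --foo2=bar2 --foo3"
)
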
diff --git a/st2actions/tests/unit/test_runner_container.py b/st2actions/tests/unit/test_runner_container.py
index 3ccfb7a4ea..f17eeceb71 100644
--- a/st2actions/tests/unit/test_runner_container.py
+++ b/st2actions/tests/unit/test_runner_container.py
@@ -21,7 +21,10 @@
from st2common.constants import action as action_constants
from st2common.runners.base import get_runner
-from st2common.exceptions.actionrunner import ActionRunnerCreateError, ActionRunnerDispatchError
+from st2common.exceptions.actionrunner import (
+ ActionRunnerCreateError,
+ ActionRunnerDispatchError,
+)
from st2common.models.system.common import ResourceReference
from st2common.models.db.liveaction import LiveActionDB
from st2common.models.db.runner import RunnerTypeDB
@@ -34,6 +37,7 @@
from st2tests.base import DbTestCase
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2tests.fixturesloader import FixturesLoader
@@ -44,39 +48,43 @@
from st2actions.container.base import get_runner_container
TEST_FIXTURES = {
- 'runners': [
- 'run-local.yaml',
- 'testrunner1.yaml',
- 'testfailingrunner1.yaml',
- 'testasyncrunner1.yaml',
- 'testasyncrunner2.yaml'
+ "runners": [
+ "run-local.yaml",
+ "testrunner1.yaml",
+ "testfailingrunner1.yaml",
+ "testasyncrunner1.yaml",
+ "testasyncrunner2.yaml",
+ ],
+ "actions": [
+ "local.yaml",
+ "action1.yaml",
+ "async_action1.yaml",
+ "async_action2.yaml",
+ "action-invalid-runner.yaml",
],
- 'actions': [
- 'local.yaml',
- 'action1.yaml',
- 'async_action1.yaml',
- 'async_action2.yaml',
- 'action-invalid-runner.yaml'
- ]
}
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
NON_UTF8_RESULT = {
- 'stderr': '',
- 'stdout': '\x82\n',
- 'succeeded': True,
- 'failed': False,
- 'return_code': 0
+ "stderr": "",
+ "stdout": "\x82\n",
+ "succeeded": True,
+ "failed": False,
+ "return_code": 0,
}
from st2tests.mocks.runners import runner
from st2tests.mocks.runners import polling_async_runner
-@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner()))
-@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner()))
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch(
+ "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
+@mock.patch(
+ "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner())
+)
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class RunnerContainerTest(DbTestCase):
action_db = None
async_action_db = None
@@ -88,30 +96,38 @@ class RunnerContainerTest(DbTestCase):
def setUpClass(cls):
super(RunnerContainerTest, cls).setUpClass()
- cfg.CONF.set_override(name='validate_output_schema', override=False, group='system')
+ cfg.CONF.set_override(
+ name="validate_output_schema", override=False, group="system"
+ )
models = RunnerContainerTest.fixtures_loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES)
- RunnerContainerTest.runnertype_db = models['runners']['testrunner1.yaml']
- RunnerContainerTest.action_db = models['actions']['action1.yaml']
- RunnerContainerTest.local_action_db = models['actions']['local.yaml']
- RunnerContainerTest.async_action_db = models['actions']['async_action1.yaml']
- RunnerContainerTest.polling_async_action_db = models['actions']['async_action2.yaml']
- RunnerContainerTest.failingaction_db = models['actions']['action-invalid-runner.yaml']
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
+ RunnerContainerTest.runnertype_db = models["runners"]["testrunner1.yaml"]
+ RunnerContainerTest.action_db = models["actions"]["action1.yaml"]
+ RunnerContainerTest.local_action_db = models["actions"]["local.yaml"]
+ RunnerContainerTest.async_action_db = models["actions"]["async_action1.yaml"]
+ RunnerContainerTest.polling_async_action_db = models["actions"][
+ "async_action2.yaml"
+ ]
+ RunnerContainerTest.failingaction_db = models["actions"][
+ "action-invalid-runner.yaml"
+ ]
@classmethod
def tearDownClass(cls):
RunnerContainerTest.fixtures_loader.delete_fixtures_from_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES)
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
super(RunnerContainerTest, cls).tearDownClass()
def test_get_runner_module(self):
- runner = get_runner(name='local-shell-script')
- self.assertIsNotNone(runner, 'TestRunner must be valid.')
+ runner = get_runner(name="local-shell-script")
+ self.assertIsNotNone(runner, "TestRunner must be valid.")
def test_pre_run_runner_is_disabled(self):
runnertype_db = RunnerContainerTest.runnertype_db
- runner = get_runner(name='local-shell-cmd')
+ runner = get_runner(name="local-shell-cmd")
runner.runner_type = runnertype_db
runner.runner_type.enabled = False
@@ -119,10 +135,12 @@ def test_pre_run_runner_is_disabled(self):
expected_msg = 'Runner "test-runner-1" has been disabled by the administrator'
self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run)
- def test_created_temporary_auth_token_is_correctly_scoped_to_user_who_ran_the_action(self):
+ def test_created_temporary_auth_token_is_correctly_scoped_to_user_who_ran_the_action(
+ self,
+ ):
params = {
- 'actionstr': 'bar',
- 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED
+ "actionstr": "bar",
+ "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED,
}
global global_runner
@@ -141,15 +159,17 @@ def mock_get_runner(*args, **kwargs):
liveaction_db = self._get_failingaction_exec_db_model(params)
liveaction_db = LiveAction.add_or_update(liveaction_db)
- liveaction_db.context = {'user': 'user_joe_1'}
+ liveaction_db.context = {"user": "user_joe_1"}
executions.create_execution_object(liveaction_db)
runner_container._get_runner = mock_get_runner
- self.assertEqual(getattr(global_runner, 'auth_token', None), None)
+ self.assertEqual(getattr(global_runner, "auth_token", None), None)
runner_container.dispatch(liveaction_db)
- self.assertEqual(global_runner.auth_token.user, 'user_joe_1')
- self.assertEqual(global_runner.auth_token.metadata['service'], 'actions_container')
+ self.assertEqual(global_runner.auth_token.user, "user_joe_1")
+ self.assertEqual(
+ global_runner.auth_token.metadata["service"], "actions_container"
+ )
runner_container._get_runner = original_get_runner
@@ -160,23 +180,25 @@ def mock_get_runner(*args, **kwargs):
liveaction_db = self._get_failingaction_exec_db_model(params)
liveaction_db = LiveAction.add_or_update(liveaction_db)
- liveaction_db.context = {'user': 'user_mark_2'}
+ liveaction_db.context = {"user": "user_mark_2"}
executions.create_execution_object(liveaction_db)
original_get_runner = runner_container._get_runner
runner_container._get_runner = mock_get_runner
- self.assertEqual(getattr(global_runner, 'auth_token', None), None)
+ self.assertEqual(getattr(global_runner, "auth_token", None), None)
runner_container.dispatch(liveaction_db)
- self.assertEqual(global_runner.auth_token.user, 'user_mark_2')
- self.assertEqual(global_runner.auth_token.metadata['service'], 'actions_container')
+ self.assertEqual(global_runner.auth_token.user, "user_mark_2")
+ self.assertEqual(
+ global_runner.auth_token.metadata["service"], "actions_container"
+ )
def test_post_run_is_always_called_after_run(self):
# 1. post_run should be called on success, failure, etc.
runner_container = get_runner_container()
params = {
- 'actionstr': 'bar',
- 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED
+ "actionstr": "bar",
+ "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED,
}
liveaction_db = self._get_failingaction_exec_db_model(params)
liveaction_db = LiveAction.add_or_update(liveaction_db)
@@ -191,6 +213,7 @@ def mock_get_runner(*args, **kwargs):
runner = original_get_runner(*args, **kwargs)
global_runner = runner
return runner
+
runner_container._get_runner = mock_get_runner
# Note: We can't assert here that post_run hasn't been called yet because runner instance
@@ -200,10 +223,7 @@ def mock_get_runner(*args, **kwargs):
# 2. Verify post_run is called if run() throws
runner_container = get_runner_container()
- params = {
- 'actionstr': 'bar',
- 'raise': True
- }
+ params = {"actionstr": "bar", "raise": True}
liveaction_db = self._get_failingaction_exec_db_model(params)
liveaction_db = LiveAction.add_or_update(liveaction_db)
executions.create_execution_object(liveaction_db)
@@ -216,6 +236,7 @@ def mock_get_runner(*args, **kwargs):
runner = original_get_runner(*args, **kwargs)
global_runner = runner
return runner
+
runner_container._get_runner = mock_get_runner
# Note: We can't assert here that post_run hasn't been called yet because runner instance
@@ -225,10 +246,10 @@ def mock_get_runner(*args, **kwargs):
# 3. Verify post_run is also called if _delete_auth_token throws
runner_container = get_runner_container()
- runner_container._delete_auth_token = mock.Mock(side_effect=ValueError('throw'))
+ runner_container._delete_auth_token = mock.Mock(side_effect=ValueError("throw"))
params = {
- 'actionstr': 'bar',
- 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED
+ "actionstr": "bar",
+ "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED,
}
liveaction_db = self._get_failingaction_exec_db_model(params)
liveaction_db = LiveAction.add_or_update(liveaction_db)
@@ -242,6 +263,7 @@ def mock_get_runner(*args, **kwargs):
runner = original_get_runner(*args, **kwargs)
global_runner = runner
return runner
+
runner_container._get_runner = mock_get_runner
# Note: We can't assert here that post_run hasn't been called yet because runner instance
@@ -250,43 +272,42 @@ def mock_get_runner(*args, **kwargs):
self.assertTrue(global_runner.post_run_called)
def test_get_runner_module_fail(self):
- runnertype_db = RunnerTypeDB(name='dummy', runner_module='absent.module')
+ runnertype_db = RunnerTypeDB(name="dummy", runner_module="absent.module")
runner = None
try:
- runner = get_runner(runnertype_db.runner_module, runnertype_db.runner_module)
+ runner = get_runner(
+ runnertype_db.runner_module, runnertype_db.runner_module
+ )
except ActionRunnerCreateError:
pass
- self.assertFalse(runner, 'TestRunner must be valid.')
+ self.assertFalse(runner, "Runner should not have been created.")
def test_dispatch(self):
runner_container = get_runner_container()
- params = {
- 'actionstr': 'bar'
- }
- liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params)
+ params = {"actionstr": "bar"}
+ liveaction_db = self._get_liveaction_model(
+ RunnerContainerTest.action_db, params
+ )
liveaction_db = LiveAction.add_or_update(liveaction_db)
executions.create_execution_object(liveaction_db)
# Assert that execution ran successfully.
runner_container.dispatch(liveaction_db)
liveaction_db = LiveAction.get_by_id(liveaction_db.id)
result = liveaction_db.result
- self.assertTrue(result.get('action_params').get('actionint') == 10)
- self.assertTrue(result.get('action_params').get('actionstr') == 'bar')
+ self.assertTrue(result.get("action_params").get("actionint") == 10)
+ self.assertTrue(result.get("action_params").get("actionstr") == "bar")
# Assert that context is written correctly.
- context = {
- 'user': 'stanley',
- 'third_party_system': {
- 'ref_id': '1234'
- }
- }
+ context = {"user": "stanley", "third_party_system": {"ref_id": "1234"}}
self.assertDictEqual(liveaction_db.context, context)
def test_dispatch_unsupported_status(self):
runner_container = get_runner_container()
- params = {'actionstr': 'bar'}
- liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params)
+ params = {"actionstr": "bar"}
+ liveaction_db = self._get_liveaction_model(
+ RunnerContainerTest.action_db, params
+ )
liveaction_db = LiveAction.add_or_update(liveaction_db)
executions.create_execution_object(liveaction_db)
@@ -295,86 +316,74 @@ def test_dispatch_unsupported_status(self):
# Assert exception is raised on dispatch.
self.assertRaises(
- ActionRunnerDispatchError,
- runner_container.dispatch,
- liveaction_db
+ ActionRunnerDispatchError, runner_container.dispatch, liveaction_db
)
def test_dispatch_runner_failure(self):
runner_container = get_runner_container()
- params = {
- 'actionstr': 'bar'
- }
+ params = {"actionstr": "bar"}
liveaction_db = self._get_failingaction_exec_db_model(params)
liveaction_db = LiveAction.add_or_update(liveaction_db)
executions.create_execution_object(liveaction_db)
runner_container.dispatch(liveaction_db)
# pick up the updated liveaction_db
liveaction_db = LiveAction.get_by_id(liveaction_db.id)
- self.assertIn('error', liveaction_db.result)
- self.assertIn('traceback', liveaction_db.result)
+ self.assertIn("error", liveaction_db.result)
+ self.assertIn("traceback", liveaction_db.result)
def test_dispatch_override_default_action_params(self):
runner_container = get_runner_container()
- params = {
- 'actionstr': 'foo',
- 'actionint': 20
- }
- liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params)
+ params = {"actionstr": "foo", "actionint": 20}
+ liveaction_db = self._get_liveaction_model(
+ RunnerContainerTest.action_db, params
+ )
liveaction_db = LiveAction.add_or_update(liveaction_db)
executions.create_execution_object(liveaction_db)
# Assert that execution ran successfully.
runner_container.dispatch(liveaction_db)
liveaction_db = LiveAction.get_by_id(liveaction_db.id)
result = liveaction_db.result
- self.assertTrue(result.get('action_params').get('actionint') == 20)
- self.assertTrue(result.get('action_params').get('actionstr') == 'foo')
+ self.assertTrue(result.get("action_params").get("actionint") == 20)
+ self.assertTrue(result.get("action_params").get("actionstr") == "foo")
def test_state_db_created_for_polling_async_actions(self):
runner_container = get_runner_container()
- params = {
- 'actionstr': 'foo',
- 'actionint': 20,
- 'async_test': True
- }
+ params = {"actionstr": "foo", "actionint": 20, "async_test": True}
liveaction_db = self._get_liveaction_model(
- RunnerContainerTest.polling_async_action_db,
- params
+ RunnerContainerTest.polling_async_action_db, params
)
liveaction_db = LiveAction.add_or_update(liveaction_db)
executions.create_execution_object(liveaction_db)
# Assert that execution ran without exceptions.
- with mock.patch('st2actions.container.base.get_runner',
- mock.Mock(return_value=polling_async_runner.get_runner())):
+ with mock.patch(
+ "st2actions.container.base.get_runner",
+ mock.Mock(return_value=polling_async_runner.get_runner()),
+ ):
runner_container.dispatch(liveaction_db)
states = ActionExecutionState.get_all()
found = [state for state in states if state.execution_id == liveaction_db.id]
- self.assertTrue(len(found) > 0, 'There should be a state db object.')
- self.assertTrue(len(found) == 1, 'There should only be one state db object.')
+ self.assertTrue(len(found) > 0, "There should be a state db object.")
+ self.assertTrue(len(found) == 1, "There should only be one state db object.")
self.assertIsNotNone(found[0].query_context)
self.assertIsNotNone(found[0].query_module)
@mock.patch.object(
PollingAsyncActionRunner,
- 'is_polling_enabled',
- mock.MagicMock(return_value=False))
+ "is_polling_enabled",
+ mock.MagicMock(return_value=False),
+ )
def test_state_db_not_created_for_disabled_polling_async_actions(self):
runner_container = get_runner_container()
- params = {
- 'actionstr': 'foo',
- 'actionint': 20,
- 'async_test': True
- }
+ params = {"actionstr": "foo", "actionint": 20, "async_test": True}
liveaction_db = self._get_liveaction_model(
- RunnerContainerTest.polling_async_action_db,
- params
+ RunnerContainerTest.polling_async_action_db, params
)
liveaction_db = LiveAction.add_or_update(liveaction_db)
@@ -385,20 +394,15 @@ def test_state_db_not_created_for_disabled_polling_async_actions(self):
states = ActionExecutionState.get_all()
found = [state for state in states if state.execution_id == liveaction_db.id]
- self.assertTrue(len(found) == 0, 'There should not be a state db object.')
+ self.assertTrue(len(found) == 0, "There should not be a state db object.")
def test_state_db_not_created_for_async_actions(self):
runner_container = get_runner_container()
- params = {
- 'actionstr': 'foo',
- 'actionint': 20,
- 'async_test': True
- }
+ params = {"actionstr": "foo", "actionint": 20, "async_test": True}
liveaction_db = self._get_liveaction_model(
- RunnerContainerTest.async_action_db,
- params
+ RunnerContainerTest.async_action_db, params
)
liveaction_db = LiveAction.add_or_update(liveaction_db)
@@ -409,17 +413,21 @@ def test_state_db_not_created_for_async_actions(self):
states = ActionExecutionState.get_all()
found = [state for state in states if state.execution_id == liveaction_db.id]
- self.assertTrue(len(found) == 0, 'There should not be a state db object.')
+ self.assertTrue(len(found) == 0, "There should not be a state db object.")
def _get_liveaction_model(self, action_db, params):
status = action_constants.LIVEACTION_STATUS_REQUESTED
start_timestamp = date_utils.get_datetime_utc_now()
action_ref = ResourceReference(name=action_db.name, pack=action_db.pack).ref
parameters = params
- context = {'user': cfg.CONF.system_user.user}
- liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp,
- action=action_ref, parameters=parameters,
- context=context)
+ context = {"user": cfg.CONF.system_user.user}
+ liveaction_db = LiveActionDB(
+ status=status,
+ start_timestamp=start_timestamp,
+ action=action_ref,
+ parameters=parameters,
+ context=context,
+ )
return liveaction_db
def _get_failingaction_exec_db_model(self, params):
@@ -427,12 +435,17 @@ def _get_failingaction_exec_db_model(self, params):
start_timestamp = date_utils.get_datetime_utc_now()
action_ref = ResourceReference(
name=RunnerContainerTest.failingaction_db.name,
- pack=RunnerContainerTest.failingaction_db.pack).ref
+ pack=RunnerContainerTest.failingaction_db.pack,
+ ).ref
parameters = params
- context = {'user': cfg.CONF.system_user.user}
- liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp,
- action=action_ref, parameters=parameters,
- context=context)
+ context = {"user": cfg.CONF.system_user.user}
+ liveaction_db = LiveActionDB(
+ status=status,
+ start_timestamp=start_timestamp,
+ action=action_ref,
+ parameters=parameters,
+ context=context,
+ )
return liveaction_db
def _get_output_schema_exec_db_model(self, params):
@@ -440,10 +453,15 @@ def _get_output_schema_exec_db_model(self, params):
start_timestamp = date_utils.get_datetime_utc_now()
action_ref = ResourceReference(
name=RunnerContainerTest.schema_output_action_db.name,
- pack=RunnerContainerTest.schema_output_action_db.pack).ref
+ pack=RunnerContainerTest.schema_output_action_db.pack,
+ ).ref
parameters = params
- context = {'user': cfg.CONF.system_user.user}
- liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp,
- action=action_ref, parameters=parameters,
- context=context)
+ context = {"user": cfg.CONF.system_user.user}
+ liveaction_db = LiveActionDB(
+ status=status,
+ start_timestamp=start_timestamp,
+ action=action_ref,
+ parameters=parameters,
+ context=context,
+ )
return liveaction_db
diff --git a/st2actions/tests/unit/test_scheduler.py b/st2actions/tests/unit/test_scheduler.py
index c23568eea1..1a7d4b9beb 100644
--- a/st2actions/tests/unit/test_scheduler.py
+++ b/st2actions/tests/unit/test_scheduler.py
@@ -20,6 +20,7 @@
import eventlet
from st2tests import config as test_config
+
test_config.parse_args()
import st2common
@@ -45,31 +46,28 @@
LIVE_ACTION = {
- 'parameters': {
- 'cmd': 'echo ":dat_face:"',
+ "parameters": {
+ "cmd": 'echo ":dat_face:"',
},
- 'action': 'core.local',
- 'status': 'requested'
+ "action": "core.local",
+ "status": "requested",
}
-PACK = 'generic'
+PACK = "generic"
TEST_FIXTURES = {
- 'actions': [
- 'action1.yaml',
- 'action2.yaml'
- ],
- 'policies': [
- 'policy_3.yaml',
- 'policy_7.yaml'
- ]
+ "actions": ["action1.yaml", "action2.yaml"],
+ "policies": ["policy_3.yaml", "policy_7.yaml"],
}
@mock.patch.object(
- LiveActionPublisher, 'publish_state',
- mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state))
+ LiveActionPublisher,
+ "publish_state",
+ mock.MagicMock(
+ side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state
+ ),
+)
class ActionExecutionSchedulingQueueItemDBTest(ExecutionDbTestCase):
-
@classmethod
def setUpClass(cls):
ExecutionDbTestCase.setUpClass()
@@ -81,18 +79,21 @@ def setUpClass(cls):
register_policy_types(st2common)
loader = FixturesLoader()
- loader.save_fixtures_to_db(fixtures_pack=PACK,
- fixtures_dict=TEST_FIXTURES)
+ loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
def setUp(self):
super(ActionExecutionSchedulingQueueItemDBTest, self).setUp()
self.scheduler = scheduling.get_scheduler_entrypoint()
self.scheduling_queue = scheduling_queue.get_handler()
- def _create_liveaction_db(self, status=action_constants.LIVEACTION_STATUS_REQUESTED):
- action_ref = 'wolfpack.action-1'
- parameters = {'actionstr': 'fu'}
- liveaction_db = LiveActionDB(action=action_ref, parameters=parameters, status=status)
+ def _create_liveaction_db(
+ self, status=action_constants.LIVEACTION_STATUS_REQUESTED
+ ):
+ action_ref = "wolfpack.action-1"
+ parameters = {"actionstr": "fu"}
+ liveaction_db = LiveActionDB(
+ action=action_ref, parameters=parameters, status=status
+ )
liveaction_db = LiveAction.add_or_update(liveaction_db)
execution_service.create_execution_object(liveaction_db, publish=False)
@@ -108,7 +109,9 @@ def test_create_from_liveaction(self):
delay,
)
- delay_date = date.append_milliseconds_to_time(liveaction_db.start_timestamp, delay)
+ delay_date = date.append_milliseconds_to_time(
+ liveaction_db.start_timestamp, delay
+ )
self.assertIsInstance(schedule_q_db, ActionExecutionSchedulingQueueItemDB)
self.assertEqual(schedule_q_db.scheduled_start_timestamp, delay_date)
@@ -125,12 +128,14 @@ def test_next_execution(self):
for delay in delays:
liveaction_db = self._create_liveaction_db()
- delayed_start = date.append_milliseconds_to_time(liveaction_db.start_timestamp, delay)
+ delayed_start = date.append_milliseconds_to_time(
+ liveaction_db.start_timestamp, delay
+ )
test_case = {
- 'liveaction': liveaction_db,
- 'delay': delay,
- 'delayed_start': delayed_start
+ "liveaction": liveaction_db,
+ "delay": delay,
+ "delayed_start": delayed_start,
}
test_cases.append(test_case)
@@ -139,8 +144,8 @@ def test_next_execution(self):
schedule_q_dbs.append(
ActionExecutionSchedulingQueue.add_or_update(
self.scheduler._create_execution_queue_item_db_from_liveaction(
- test_case['liveaction'],
- test_case['delay'],
+ test_case["liveaction"],
+ test_case["delay"],
)
)
)
@@ -152,22 +157,24 @@ def test_next_execution(self):
test_case = test_cases[index]
date_mock = mock.MagicMock()
- date_mock.get_datetime_utc_now.return_value = test_case['delayed_start']
+ date_mock.get_datetime_utc_now.return_value = test_case["delayed_start"]
date_mock.append_milliseconds_to_time = date.append_milliseconds_to_time
- with mock.patch('st2actions.scheduler.handler.date', date_mock):
+ with mock.patch("st2actions.scheduler.handler.date", date_mock):
schedule_q_db = self.scheduling_queue._get_next_execution()
ActionExecutionSchedulingQueue.delete(schedule_q_db)
self.assertIsInstance(schedule_q_db, ActionExecutionSchedulingQueueItemDB)
- self.assertEqual(schedule_q_db.delay, test_case['delay'])
- self.assertEqual(schedule_q_db.liveaction_id, str(test_case['liveaction'].id))
+ self.assertEqual(schedule_q_db.delay, test_case["delay"])
+ self.assertEqual(
+ schedule_q_db.liveaction_id, str(test_case["liveaction"].id)
+ )
# NOTE: We can't directly assert on the timestamp due to delays in the code
# and timing variance
scheduled_start_timestamp = schedule_q_db.scheduled_start_timestamp
- test_case_start_timestamp = test_case['delayed_start']
- start_timestamp_diff = (scheduled_start_timestamp - test_case_start_timestamp)
+ test_case_start_timestamp = test_case["delayed_start"]
+ start_timestamp_diff = scheduled_start_timestamp - test_case_start_timestamp
self.assertTrue(start_timestamp_diff <= datetime.timedelta(seconds=1))
def test_next_executions_empty(self):
@@ -227,9 +234,11 @@ def test_garbage_collection(self):
schedule_q_db = self.scheduling_queue._get_next_execution()
self.assertIsNotNone(schedule_q_db)
- @mock.patch('st2actions.scheduler.handler.action_service')
- @mock.patch('st2actions.scheduler.handler.ActionExecutionSchedulingQueue.delete')
- def test_processing_when_task_completed(self, mock_execution_queue_delete, mock_action_service):
+ @mock.patch("st2actions.scheduler.handler.action_service")
+ @mock.patch("st2actions.scheduler.handler.ActionExecutionSchedulingQueue.delete")
+ def test_processing_when_task_completed(
+ self, mock_execution_queue_delete, mock_action_service
+ ):
self.reset()
liveaction_db = self._create_liveaction_db()
@@ -245,7 +254,7 @@ def test_processing_when_task_completed(self, mock_execution_queue_delete, mock_
mock_execution_queue_delete.assert_called_once()
ActionExecutionSchedulingQueue.delete(schedule_q_db)
- @mock.patch('st2actions.scheduler.handler.LOG')
+ @mock.patch("st2actions.scheduler.handler.LOG")
def test_failed_next_item(self, mocked_logger):
self.reset()
@@ -258,15 +267,17 @@ def test_failed_next_item(self, mocked_logger):
schedule_q_db = ActionExecutionSchedulingQueue.add_or_update(schedule_q_db)
with mock.patch(
- 'st2actions.scheduler.handler.ActionExecutionSchedulingQueue.add_or_update',
- side_effect=db_exc.StackStormDBObjectWriteConflictError(schedule_q_db)
+ "st2actions.scheduler.handler.ActionExecutionSchedulingQueue.add_or_update",
+ side_effect=db_exc.StackStormDBObjectWriteConflictError(schedule_q_db),
):
schedule_q_db = self.scheduling_queue._get_next_execution()
self.assertIsNone(schedule_q_db)
self.assertEqual(mocked_logger.info.call_count, 2)
call_args = mocked_logger.info.call_args_list[1][0]
- self.assertEqual(r'[%s] Item "%s" is already handled by another scheduler.', call_args[0])
+ self.assertEqual(
+ r'[%s] Item "%s" is already handled by another scheduler.', call_args[0]
+ )
schedule_q_db = self.scheduling_queue._get_next_execution()
self.assertIsNotNone(schedule_q_db)
@@ -288,33 +299,39 @@ def test_cleanup_policy_delayed(self):
# Manually update the liveaction to policy-delayed status.
# Using action_service.update_status will throw an exception on the
# deprecated action_constants.LIVEACTION_STATUS_POLICY_DELAYED.
- liveaction_db.status = 'policy-delayed'
+ liveaction_db.status = "policy-delayed"
liveaction_db = LiveAction.add_or_update(liveaction_db)
execution_db = execution_service.update_execution(liveaction_db)
# Check that the execution status is set to policy-delayed.
liveaction_db = LiveAction.get_by_id(str(liveaction_db.id))
- self.assertEqual(liveaction_db.status, 'policy-delayed')
+ self.assertEqual(liveaction_db.status, "policy-delayed")
execution_db = ActionExecution.get_by_id(str(execution_db.id))
- self.assertEqual(execution_db.status, 'policy-delayed')
+ self.assertEqual(execution_db.status, "policy-delayed")
# Run the clean up logic.
self.scheduling_queue._cleanup_policy_delayed()
# Check that the execution status is reset to requested.
liveaction_db = LiveAction.get_by_id(str(liveaction_db.id))
- self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_REQUESTED)
+ self.assertEqual(
+ liveaction_db.status, action_constants.LIVEACTION_STATUS_REQUESTED
+ )
execution_db = ActionExecution.get_by_id(str(execution_db.id))
- self.assertEqual(execution_db.status, action_constants.LIVEACTION_STATUS_REQUESTED)
+ self.assertEqual(
+ execution_db.status, action_constants.LIVEACTION_STATUS_REQUESTED
+ )
# The old entry should have been deleted. Since the execution is
# reset to requested, there should be a new scheduling entry.
new_schedule_q_db = self.scheduling_queue._get_next_execution()
self.assertIsNotNone(new_schedule_q_db)
self.assertNotEqual(str(schedule_q_db.id), str(new_schedule_q_db.id))
- self.assertEqual(schedule_q_db.action_execution_id, new_schedule_q_db.action_execution_id)
+ self.assertEqual(
+ schedule_q_db.action_execution_id, new_schedule_q_db.action_execution_id
+ )
self.assertEqual(schedule_q_db.liveaction_id, new_schedule_q_db.liveaction_id)
# TODO: Remove this test case for populating action_execution_id in v3.2.
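
The delay arithmetic these scheduler tests lean on reduces to a timedelta shift. A sketch of the presumed behaviour of date.append_milliseconds_to_time as used above:

import datetime


def append_milliseconds_to_time(timestamp, milliseconds):
    # Presumed behaviour: shift a datetime forward by the given offset.
    return timestamp + datetime.timedelta(milliseconds=milliseconds)


start = datetime.datetime(2021, 3, 2, 12, 0, 0)
assert append_milliseconds_to_time(start, 500) - start == datetime.timedelta(
    milliseconds=500
)
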
diff --git a/st2actions/tests/unit/test_scheduler_entrypoint.py b/st2actions/tests/unit/test_scheduler_entrypoint.py
index ddcba287e7..2bc535d99d 100644
--- a/st2actions/tests/unit/test_scheduler_entrypoint.py
+++ b/st2actions/tests/unit/test_scheduler_entrypoint.py
@@ -17,6 +17,7 @@
import mock
from st2tests import config as test_config
+
test_config.parse_args()
from st2actions.cmd.scheduler import _run_scheduler
@@ -25,32 +26,30 @@
from st2tests.base import CleanDbTestCase
-__all__ = [
- 'SchedulerServiceEntryPointTestCase'
-]
+__all__ = ["SchedulerServiceEntryPointTestCase"]
def mock_handler_run(self):
# NOTE: We use eventlet.sleep to emulate async nature of this process
eventlet.sleep(0.2)
- raise Exception('handler run exception')
+ raise Exception("handler run exception")
def mock_handler_cleanup(self):
# NOTE: We use eventlet.sleep to emulate async nature of this process
eventlet.sleep(0.2)
- raise Exception('handler clean exception')
+ raise Exception("handler clean exception")
def mock_entrypoint_start(self):
# NOTE: We use eventlet.sleep to emulate async nature of this process
eventlet.sleep(0.2)
- raise Exception('entrypoint start exception')
+ raise Exception("entrypoint start exception")
class SchedulerServiceEntryPointTestCase(CleanDbTestCase):
- @mock.patch.object(ActionExecutionSchedulingQueueHandler, 'run', mock_handler_run)
- @mock.patch('st2actions.cmd.scheduler.LOG')
+ @mock.patch.object(ActionExecutionSchedulingQueueHandler, "run", mock_handler_run)
+ @mock.patch("st2actions.cmd.scheduler.LOG")
def test_service_exits_correctly_on_fatal_exception_in_handler_run(self, mock_log):
run_thread = eventlet.spawn(_run_scheduler)
result = run_thread.wait()
@@ -58,26 +57,32 @@ def test_service_exits_correctly_on_fatal_exception_in_handler_run(self, mock_lo
self.assertEqual(result, 1)
mock_log_exception_call = mock_log.exception.call_args_list[0][0][0]
- self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call)
-
- @mock.patch.object(ActionExecutionSchedulingQueueHandler, 'cleanup', mock_handler_cleanup)
- @mock.patch('st2actions.cmd.scheduler.LOG')
- def test_service_exits_correctly_on_fatal_exception_in_handler_cleanup(self, mock_log):
+ self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call)
+
+ @mock.patch.object(
+ ActionExecutionSchedulingQueueHandler, "cleanup", mock_handler_cleanup
+ )
+ @mock.patch("st2actions.cmd.scheduler.LOG")
+ def test_service_exits_correctly_on_fatal_exception_in_handler_cleanup(
+ self, mock_log
+ ):
run_thread = eventlet.spawn(_run_scheduler)
result = run_thread.wait()
self.assertEqual(result, 1)
mock_log_exception_call = mock_log.exception.call_args_list[0][0][0]
- self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call)
+ self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call)
- @mock.patch.object(SchedulerEntrypoint, 'start', mock_entrypoint_start)
- @mock.patch('st2actions.cmd.scheduler.LOG')
- def test_service_exits_correctly_on_fatal_exception_in_entrypoint_start(self, mock_log):
+ @mock.patch.object(SchedulerEntrypoint, "start", mock_entrypoint_start)
+ @mock.patch("st2actions.cmd.scheduler.LOG")
+ def test_service_exits_correctly_on_fatal_exception_in_entrypoint_start(
+ self, mock_log
+ ):
run_thread = eventlet.spawn(_run_scheduler)
result = run_thread.wait()
self.assertEqual(result, 1)
mock_log_exception_call = mock_log.exception.call_args_list[0][0][0]
- self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call)
+ self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call)
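
The three entrypoint tests above all pin down the same behaviour: a fatal exception in any of the scheduler's green threads is logged and converted into a non-zero exit code. A sketch of that shape, assuming simplified handler/entrypoint/log collaborators:

import eventlet


def run_scheduler_sketch(handler, entrypoint, log):
    try:
        for thread in [
            eventlet.spawn(handler.run),
            eventlet.spawn(handler.cleanup),
            eventlet.spawn(entrypoint.start),
        ]:
            thread.wait()
    except Exception:
        log.exception("Scheduler unexpectedly stopped")
        return 1
    return 0
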
diff --git a/st2actions/tests/unit/test_scheduler_retry.py b/st2actions/tests/unit/test_scheduler_retry.py
index e47a2ad3eb..ad1f221df1 100644
--- a/st2actions/tests/unit/test_scheduler_retry.py
+++ b/st2actions/tests/unit/test_scheduler_retry.py
@@ -19,6 +19,7 @@
import uuid
from st2tests import config as test_config
+
test_config.parse_args()
from st2actions.scheduler import handler
@@ -27,22 +28,23 @@
from st2tests.base import CleanDbTestCase
-__all__ = [
- 'SchedulerHandlerRetryTestCase'
-]
+__all__ = ["SchedulerHandlerRetryTestCase"]
-MOCK_QUEUE_ITEM = ex_q_db.ActionExecutionSchedulingQueueItemDB(liveaction_id=uuid.uuid4().hex)
+MOCK_QUEUE_ITEM = ex_q_db.ActionExecutionSchedulingQueueItemDB(
+ liveaction_id=uuid.uuid4().hex
+)
class SchedulerHandlerRetryTestCase(CleanDbTestCase):
-
- @mock.patch.object(
- handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution',
- mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure(), MOCK_QUEUE_ITEM]))
@mock.patch.object(
- eventlet.GreenPool, 'spawn',
- mock.MagicMock(return_value=None))
+ handler.ActionExecutionSchedulingQueueHandler,
+ "_get_next_execution",
+ mock.MagicMock(
+ side_effect=[pymongo.errors.ConnectionFailure(), MOCK_QUEUE_ITEM]
+ ),
+ )
+ @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None))
def test_handler_retry_connection_error(self):
scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler()
scheduling_queue_handler.process()
@@ -52,69 +54,88 @@ def test_handler_retry_connection_error(self):
eventlet.GreenPool.spawn.assert_has_calls(calls)
@mock.patch.object(
- handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution',
- mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3))
- @mock.patch.object(
- eventlet.GreenPool, 'spawn',
- mock.MagicMock(return_value=None))
+ handler.ActionExecutionSchedulingQueueHandler,
+ "_get_next_execution",
+ mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3),
+ )
+ @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None))
def test_handler_retries_exhausted(self):
scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler()
- self.assertRaises(pymongo.errors.ConnectionFailure, scheduling_queue_handler.process)
+ self.assertRaises(
+ pymongo.errors.ConnectionFailure, scheduling_queue_handler.process
+ )
self.assertEqual(eventlet.GreenPool.spawn.call_count, 0)
@mock.patch.object(
- handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution',
- mock.MagicMock(side_effect=KeyError()))
- @mock.patch.object(
- eventlet.GreenPool, 'spawn',
- mock.MagicMock(return_value=None))
+ handler.ActionExecutionSchedulingQueueHandler,
+ "_get_next_execution",
+ mock.MagicMock(side_effect=KeyError()),
+ )
+ @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None))
def test_handler_retry_unexpected_error(self):
scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler()
self.assertRaises(KeyError, scheduling_queue_handler.process)
self.assertEqual(eventlet.GreenPool.spawn.call_count, 0)
@mock.patch.object(
- ex_q_db_access.ActionExecutionSchedulingQueue, 'query',
- mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure(), [MOCK_QUEUE_ITEM]]))
+ ex_q_db_access.ActionExecutionSchedulingQueue,
+ "query",
+ mock.MagicMock(
+ side_effect=[pymongo.errors.ConnectionFailure(), [MOCK_QUEUE_ITEM]]
+ ),
+ )
@mock.patch.object(
- ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update',
- mock.MagicMock(return_value=None))
+ ex_q_db_access.ActionExecutionSchedulingQueue,
+ "add_or_update",
+ mock.MagicMock(return_value=None),
+ )
def test_handler_gc_retry_connection_error(self):
scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler()
scheduling_queue_handler._handle_garbage_collection()
# Make sure retry occurs and that _handle_execution in process is called.
calls = [mock.call(MOCK_QUEUE_ITEM, publish=False)]
- ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.assert_has_calls(calls)
+ ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.assert_has_calls(
+ calls
+ )
@mock.patch.object(
- ex_q_db_access.ActionExecutionSchedulingQueue, 'query',
- mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3))
+ ex_q_db_access.ActionExecutionSchedulingQueue,
+ "query",
+ mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3),
+ )
@mock.patch.object(
- ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update',
- mock.MagicMock(return_value=None))
+ ex_q_db_access.ActionExecutionSchedulingQueue,
+ "add_or_update",
+ mock.MagicMock(return_value=None),
+ )
def test_handler_gc_retries_exhausted(self):
scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler()
self.assertRaises(
pymongo.errors.ConnectionFailure,
- scheduling_queue_handler._handle_garbage_collection
+ scheduling_queue_handler._handle_garbage_collection,
)
- self.assertEqual(ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0)
+ self.assertEqual(
+ ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0
+ )
@mock.patch.object(
- ex_q_db_access.ActionExecutionSchedulingQueue, 'query',
- mock.MagicMock(side_effect=KeyError()))
+ ex_q_db_access.ActionExecutionSchedulingQueue,
+ "query",
+ mock.MagicMock(side_effect=KeyError()),
+ )
@mock.patch.object(
- ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update',
- mock.MagicMock(return_value=None))
+ ex_q_db_access.ActionExecutionSchedulingQueue,
+ "add_or_update",
+ mock.MagicMock(return_value=None),
+ )
def test_handler_gc_unexpected_error(self):
scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler()
- self.assertRaises(
- KeyError,
- scheduling_queue_handler._handle_garbage_collection
- )
+ self.assertRaises(KeyError, scheduling_queue_handler._handle_garbage_collection)
- self.assertEqual(ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0)
+ self.assertEqual(
+ ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0
+ )
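
The retry tests above script transient failures through mock.MagicMock(side_effect=[...]). A self-contained sketch of that idiom (the returned value is illustrative):

    import mock
    import pymongo.errors

    # A list side_effect is consumed one item per call: exception instances
    # in the list are raised, other values are returned.
    get_next = mock.MagicMock(
        side_effect=[pymongo.errors.ConnectionFailure(), "queue-item"]
    )

    try:
        get_next()  # first call raises, emulating a transient outage
    except pymongo.errors.ConnectionFailure:
        pass

    assert get_next() == "queue-item"  # second call succeeds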
diff --git a/st2actions/tests/unit/test_worker.py b/st2actions/tests/unit/test_worker.py
index 19ffd69553..d8637b9ac7 100644
--- a/st2actions/tests/unit/test_worker.py
+++ b/st2actions/tests/unit/test_worker.py
@@ -36,16 +36,20 @@
from st2tests.fixturesloader import FixturesLoader
import st2tests.config as tests_config
from six.moves import range
+
tests_config.parse_args()
-TEST_FIXTURES = {
- 'actions': ['local.yaml']
-}
+TEST_FIXTURES = {"actions": ["local.yaml"]}
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
-NON_UTF8_RESULT = {'stderr': '', 'stdout': '\x82\n', 'succeeded': True, 'failed': False,
- 'return_code': 0}
+NON_UTF8_RESULT = {
+ "stderr": "",
+ "stdout": "\x82\n",
+ "succeeded": True,
+ "failed": False,
+ "return_code": 0,
+}
class WorkerTestCase(DbTestCase):
@@ -58,28 +62,42 @@ def setUpClass(cls):
runners_registrar.register_runners()
models = WorkerTestCase.fixtures_loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES)
- WorkerTestCase.local_action_db = models['actions']['local.yaml']
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
+ WorkerTestCase.local_action_db = models["actions"]["local.yaml"]
def _get_liveaction_model(self, action_db, params):
status = action_constants.LIVEACTION_STATUS_REQUESTED
start_timestamp = date_utils.get_datetime_utc_now()
action_ref = ResourceReference(name=action_db.name, pack=action_db.pack).ref
parameters = params
- context = {'user': cfg.CONF.system_user.user}
- liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp,
- action=action_ref, parameters=parameters,
- context=context)
+ context = {"user": cfg.CONF.system_user.user}
+ liveaction_db = LiveActionDB(
+ status=status,
+ start_timestamp=start_timestamp,
+ action=action_ref,
+ parameters=parameters,
+ context=context,
+ )
return liveaction_db
- @mock.patch.object(LocalShellCommandRunner, 'run', mock.MagicMock(
- return_value=(action_constants.LIVEACTION_STATUS_SUCCEEDED, NON_UTF8_RESULT, None)))
+ @mock.patch.object(
+ LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(
+ return_value=(
+ action_constants.LIVEACTION_STATUS_SUCCEEDED,
+ NON_UTF8_RESULT,
+ None,
+ )
+ ),
+ )
def test_non_utf8_action_result_string(self):
action_worker = actions_worker.get_worker()
- params = {
- 'cmd': "python -c 'print \"\\x82\"'"
- }
- liveaction_db = self._get_liveaction_model(WorkerTestCase.local_action_db, params)
+ params = {"cmd": "python -c 'print \"\\x82\"'"}
+ liveaction_db = self._get_liveaction_model(
+ WorkerTestCase.local_action_db, params
+ )
liveaction_db = LiveAction.add_or_update(liveaction_db)
execution_db = executions.create_execution_object(liveaction_db)
@@ -87,11 +105,15 @@ def test_non_utf8_action_result_string(self):
action_worker._run_action(liveaction_db)
except InvalidStringData:
liveaction_db = LiveAction.get_by_id(liveaction_db.id)
- self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED)
- self.assertIn('error', liveaction_db.result)
- self.assertIn('traceback', liveaction_db.result)
+ self.assertEqual(
+ liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED
+ )
+ self.assertIn("error", liveaction_db.result)
+ self.assertIn("traceback", liveaction_db.result)
execution_db = ActionExecution.get_by_id(execution_db.id)
- self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED)
+ self.assertEqual(
+ liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED
+ )
def test_worker_shutdown(self):
action_worker = actions_worker.get_worker()
@@ -107,8 +129,10 @@ def test_worker_shutdown(self):
self.assertTrue(os.path.isfile(temp_file))
# Launch the action execution in a separate thread.
- params = {'cmd': 'while [ -e \'%s\' ]; do sleep 0.1; done' % temp_file}
- liveaction_db = self._get_liveaction_model(WorkerTestCase.local_action_db, params)
+ params = {"cmd": "while [ -e '%s' ]; do sleep 0.1; done" % temp_file}
+ liveaction_db = self._get_liveaction_model(
+ WorkerTestCase.local_action_db, params
+ )
liveaction_db = LiveAction.add_or_update(liveaction_db)
executions.create_execution_object(liveaction_db)
runner_thread = eventlet.spawn(action_worker._run_action, liveaction_db)
@@ -127,8 +151,11 @@ def test_worker_shutdown(self):
# Verify that _running_liveactions is empty and the liveaction is abandoned.
self.assertEqual(len(action_worker._running_liveactions), 0)
- self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_ABANDONED,
- str(liveaction_db))
+ self.assertEqual(
+ liveaction_db.status,
+ action_constants.LIVEACTION_STATUS_ABANDONED,
+ str(liveaction_db),
+ )
# Make sure the temporary file has been deleted.
self.assertFalse(os.path.isfile(temp_file))
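
test_worker_shutdown drives the long-running action from a greenthread so the test can abandon it mid-flight. A minimal sketch of the eventlet.spawn/wait pattern it relies on:

    import eventlet

    def long_running_action():
        eventlet.sleep(0.1)  # stand-in for the blocking shell command
        return "done"

    thread = eventlet.spawn(long_running_action)  # schedule the greenthread
    assert thread.wait() == "done"  # blocks until the greenthread returns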
diff --git a/st2actions/tests/unit/test_workflow_engine.py b/st2actions/tests/unit/test_workflow_engine.py
index 916682d569..b8e4fae83f 100644
--- a/st2actions/tests/unit/test_workflow_engine.py
+++ b/st2actions/tests/unit/test_workflow_engine.py
@@ -26,6 +26,7 @@
# XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2actions.workflows import workflows
@@ -46,37 +47,45 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class WorkflowExecutionHandlerTest(st2tests.WorkflowTestCase):
-
@classmethod
def setUpClass(cls):
super(WorkflowExecutionHandlerTest, cls).setUpClass()
@@ -86,50 +95,57 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
def test_process(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_ac_ex_db)
t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id)
self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED)
# Process task2.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"}
t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0]
+ t2_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t2_ex_db.id)
+ )[0]
workflows.get_engine().process(t2_ac_ex_db)
t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id)
self.assertEqual(t2_ex_db.status, wf_statuses.SUCCEEDED)
# Process task3.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"}
t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0]
+ t3_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t3_ex_db.id)
+ )[0]
workflows.get_engine().process(t3_ac_ex_db)
t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id)
self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED)
# Assert the workflow has completed successfully with expected output.
- expected_output = {'msg': 'Stanley, All your base are belong to us!'}
+ expected_output = {"msg": "Stanley, All your base are belong to us!"}
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED)
self.assertDictEqual(wf_ex_db.output, expected_output)
@@ -137,37 +153,43 @@ def test_process(self):
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED)
@mock.patch.object(
- coordination_service.NoOpDriver, 'get_lock',
- mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar')))
+ coordination_service.NoOpDriver,
+ "get_lock",
+ mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")),
+ )
def test_process_error_handling(self):
expected_errors = [
{
- 'message': 'Execution failed. See result for details.',
- 'type': 'error',
- 'task_id': 'task1'
+ "message": "Execution failed. See result for details.",
+ "type": "error",
+ "task_id": "task1",
},
{
- 'type': 'error',
- 'message': 'ToozConnectionError: foobar',
- 'task_id': 'task1',
- 'route': 0
- }
+ "type": "error",
+ "message": "ToozConnectionError: foobar",
+ "task_id": "task1",
+ "route": 0,
+ },
]
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
workflows.get_engine().process(t1_ac_ex_db)
# Assert the task is marked as failed.
@@ -182,36 +204,42 @@ def test_process_error_handling(self):
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED)
@mock.patch.object(
- coordination_service.NoOpDriver, 'get_lock',
- mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar')))
+ coordination_service.NoOpDriver,
+ "get_lock",
+ mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")),
+ )
@mock.patch.object(
workflows.WorkflowExecutionHandler,
- 'fail_workflow_execution',
- mock.MagicMock(side_effect=Exception('Unexpected error.')))
+ "fail_workflow_execution",
+ mock.MagicMock(side_effect=Exception("Unexpected error.")),
+ )
def test_process_error_handling_has_error(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.request(lv_ac_db)
# Assert action execution is running.
lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id))
self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Process task1.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0]
+ t1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(t1_ex_db.id)
+ )[0]
self.assertRaisesRegexp(
- Exception,
- 'Unexpected error.',
- workflows.get_engine().process,
- t1_ac_ex_db
+ Exception, "Unexpected error.", workflows.get_engine().process, t1_ac_ex_db
)
- self.assertTrue(workflows.WorkflowExecutionHandler.fail_workflow_execution.called)
+ self.assertTrue(
+ workflows.WorkflowExecutionHandler.fail_workflow_execution.called
+ )
# Since error handling failed, the workflow will have status of running.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id)
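
One thing black deliberately leaves alone above is identifiers: assertRaisesRegexp survives the reformat even though, on Python 3, it is only a deprecated alias of assertRaisesRegex. For reference, the modern spelling of the same assertion:

    import unittest

    class ExampleTest(unittest.TestCase):
        def test_raises(self):
            # Equivalent to the assertRaisesRegexp call above, minus the
            # DeprecationWarning.
            with self.assertRaisesRegex(Exception, "Unexpected error."):
                raise Exception("Unexpected error.")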
diff --git a/st2api/dist_utils.py b/st2api/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2api/dist_utils.py
+++ b/st2api/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
    Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
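
The reformatted _get_link above extracts an egg name from VCS requirement lines with two regexes that cascade. A standalone sketch of that cascade (the sample line is illustrative):

    import re

    line = "git+https://github.com/example/repo.git#egg=mypkg"

    # First form: '#egg=name' followed by '&' or '@' and extra parameters.
    req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
    if not req_name:
        # Fallback: a bare '#egg=name' at the end of the line.
        req_name = re.findall(".*#egg=(.+?)$", line)

    print(req_name[0])  # -> mypkg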
diff --git a/st2api/setup.py b/st2api/setup.py
index 932f2e90f4..b0cfa24067 100644
--- a/st2api/setup.py
+++ b/st2api/setup.py
@@ -22,9 +22,9 @@
from dist_utils import apply_vagrant_workaround
from st2api import __version__
-ST2_COMPONENT = 'st2api'
+ST2_COMPONENT = "st2api"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
@@ -32,18 +32,18 @@
setup(
name=ST2_COMPONENT,
version=__version__,
- description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="{} StackStorm event-driven automation platform component".format(
+ ST2_COMPONENT
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- scripts=[
- 'bin/st2api'
- ]
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ scripts=["bin/st2api"],
)
diff --git a/st2api/st2api/__init__.py b/st2api/st2api/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/st2api/st2api/__init__.py
+++ b/st2api/st2api/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/st2api/st2api/app.py b/st2api/st2api/app.py
index 5b10e58c3f..2483d0ef9e 100644
--- a/st2api/st2api/app.py
+++ b/st2api/st2api/app.py
@@ -36,55 +36,60 @@
def setup_app(config=None):
config = config or {}
- LOG.info('Creating st2api: %s as OpenAPI app.', VERSION_STRING)
+ LOG.info("Creating st2api: %s as OpenAPI app.", VERSION_STRING)
- is_gunicorn = config.get('is_gunicorn', False)
+ is_gunicorn = config.get("is_gunicorn", False)
if is_gunicorn:
# NOTE: We only want to perform this logic in the WSGI worker
st2api_config.register_opts()
capabilities = {
- 'name': 'api',
- 'listen_host': cfg.CONF.api.host,
- 'listen_port': cfg.CONF.api.port,
- 'type': 'active'
+ "name": "api",
+ "listen_host": cfg.CONF.api.host,
+ "listen_port": cfg.CONF.api.port,
+ "type": "active",
}
        # This should be called in the gunicorn case because we only want
        # workers to connect to db, rabbitmq etc. In the standalone HTTP
        # server case, this setup would have already occurred.
- common_setup(service='api', config=st2api_config, setup_db=True,
- register_mq_exchanges=True,
- register_signal_handlers=True,
- register_internal_trigger_types=True,
- run_migrations=True,
- service_registry=True,
- capabilities=capabilities,
- config_args=config.get('config_args', None))
+ common_setup(
+ service="api",
+ config=st2api_config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ register_internal_trigger_types=True,
+ run_migrations=True,
+ service_registry=True,
+ capabilities=capabilities,
+ config_args=config.get("config_args", None),
+ )
# Additional pre-run time checks
validate_rbac_is_correctly_configured()
- router = Router(debug=cfg.CONF.api.debug, auth=cfg.CONF.auth.enable,
- is_gunicorn=is_gunicorn)
+ router = Router(
+ debug=cfg.CONF.api.debug, auth=cfg.CONF.auth.enable, is_gunicorn=is_gunicorn
+ )
- spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2')
+ spec = spec_loader.load_spec("st2common", "openapi.yaml.j2")
transforms = {
- '^/api/v1/$': ['/v1'],
- '^/api/v1/': ['/', '/v1/'],
- '^/api/v1/executions': ['/actionexecutions', '/v1/actionexecutions'],
- '^/api/exp/': ['/exp/']
+ "^/api/v1/$": ["/v1"],
+ "^/api/v1/": ["/", "/v1/"],
+ "^/api/v1/executions": ["/actionexecutions", "/v1/actionexecutions"],
+ "^/api/exp/": ["/exp/"],
}
router.add_spec(spec, transforms=transforms)
app = router.as_wsgi
# Order is important. Check middleware for detailed explanation.
- app = StreamingMiddleware(app, path_whitelist=['/v1/executions/*/output*'])
+ app = StreamingMiddleware(app, path_whitelist=["/v1/executions/*/output*"])
app = ErrorHandlingMiddleware(app)
app = CorsMiddleware(app)
app = LoggingMiddleware(app, router)
- app = ResponseInstrumentationMiddleware(app, router, service_name='api')
+ app = ResponseInstrumentationMiddleware(app, router, service_name="api")
app = RequestIDMiddleware(app)
- app = RequestInstrumentationMiddleware(app, router, service_name='api')
+ app = RequestInstrumentationMiddleware(app, router, service_name="api")
return app
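
The "Order is important" comment in setup_app is about WSGI nesting: the wrapper applied last is the first to see a request, so RequestInstrumentationMiddleware runs before RequestIDMiddleware and so on down the stack. A toy sketch of why (all names hypothetical):

    def tag(app, label):
        # Trivial WSGI middleware that records the order wrappers run in.
        def wrapper(environ, start_response):
            environ.setdefault("seen", []).append(label)
            return app(environ, start_response)
        return wrapper

    def app(environ, start_response):
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"hello"]

    app = tag(app, "inner")  # applied first, runs last
    app = tag(app, "outer")  # applied last, runs first

    environ = {}
    app(environ, lambda *args: None)
    print(environ["seen"])  # -> ['outer', 'inner']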
diff --git a/st2api/st2api/cmd/__init__.py b/st2api/st2api/cmd/__init__.py
index 4e28bca433..0b9307922a 100644
--- a/st2api/st2api/cmd/__init__.py
+++ b/st2api/st2api/cmd/__init__.py
@@ -15,4 +15,4 @@
from st2api.cmd import api
-__all__ = ['api']
+__all__ = ["api"]
diff --git a/st2api/st2api/cmd/api.py b/st2api/st2api/cmd/api.py
index 73d3520444..1cf01d0544 100644
--- a/st2api/st2api/cmd/api.py
+++ b/st2api/st2api/cmd/api.py
@@ -21,6 +21,7 @@
# See https://github.com/StackStorm/st2/issues/4832 and https://github.com/gevent/gevent/issues/1016
# for details.
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import eventlet
@@ -31,14 +32,13 @@
from st2common.service_setup import setup as common_setup
from st2common.service_setup import teardown as common_teardown
from st2api import config
+
config.register_opts()
from st2api import app
from st2api.validation import validate_rbac_is_correctly_configured
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
@@ -48,15 +48,22 @@
def _setup():
capabilities = {
- 'name': 'api',
- 'listen_host': cfg.CONF.api.host,
- 'listen_port': cfg.CONF.api.port,
- 'type': 'active'
+ "name": "api",
+ "listen_host": cfg.CONF.api.host,
+ "listen_port": cfg.CONF.api.port,
+ "type": "active",
}
- common_setup(service='api', config=config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True, register_internal_trigger_types=True,
- service_registry=True, capabilities=capabilities)
+ common_setup(
+ service="api",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ register_internal_trigger_types=True,
+ service_registry=True,
+ capabilities=capabilities,
+ )
# Additional pre-run time checks
validate_rbac_is_correctly_configured()
@@ -66,13 +73,15 @@ def _run_server():
host = cfg.CONF.api.host
port = cfg.CONF.api.port
- LOG.info('(PID=%s) ST2 API is serving on http://%s:%s.', os.getpid(), host, port)
+ LOG.info("(PID=%s) ST2 API is serving on http://%s:%s.", os.getpid(), host, port)
max_pool_size = eventlet.wsgi.DEFAULT_MAX_SIMULTANEOUS_REQUESTS
worker_pool = eventlet.GreenPool(max_pool_size)
sock = eventlet.listen((host, port))
- wsgi.server(sock, app.setup_app(), custom_pool=worker_pool, log=LOG, log_output=False)
+ wsgi.server(
+ sock, app.setup_app(), custom_pool=worker_pool, log=LOG, log_output=False
+ )
return 0
@@ -87,7 +96,7 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except Exception:
- LOG.exception('(PID=%s) ST2 API quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) ST2 API quit due to exception.", os.getpid())
return 1
finally:
_teardown()
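
The blank line black inserts after monkey_patch() visually underlines a hard ordering constraint in this module: eventlet patching must run before any import that grabs references to the stdlib networking primitives. The rule in isolation:

    # Patch before importing anything that caches references to the
    # stdlib socket/threading internals.
    import eventlet
    eventlet.monkey_patch()

    import socket  # socket.socket is now eventlet's cooperative version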
diff --git a/st2api/st2api/config.py b/st2api/st2api/config.py
index 71378da9ad..35a21d87d5 100644
--- a/st2api/st2api/config.py
+++ b/st2api/st2api/config.py
@@ -32,8 +32,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
@@ -52,32 +55,38 @@ def get_logging_config_path():
def _register_app_opts():
# Note "host", "port", "allow_origin", "mask_secrets" options are registered as part of
# st2common config since they are also used outside st2api
- static_root = os.path.join(cfg.CONF.system.base_path, 'static')
- template_path = os.path.join(BASE_DIR, 'templates/')
+ static_root = os.path.join(cfg.CONF.system.base_path, "static")
+ template_path = os.path.join(BASE_DIR, "templates/")
pecan_opts = [
cfg.StrOpt(
- 'root', default='st2api.controllers.root.RootController',
- help='Action root controller'),
- cfg.StrOpt('static_root', default=static_root),
- cfg.StrOpt('template_path', default=template_path),
- cfg.ListOpt('modules', default=['st2api']),
- cfg.BoolOpt('debug', default=False),
- cfg.BoolOpt('auth_enable', default=True),
- cfg.DictOpt('errors', default={'__force_dict__': True})
+ "root",
+ default="st2api.controllers.root.RootController",
+ help="Action root controller",
+ ),
+ cfg.StrOpt("static_root", default=static_root),
+ cfg.StrOpt("template_path", default=template_path),
+ cfg.ListOpt("modules", default=["st2api"]),
+ cfg.BoolOpt("debug", default=False),
+ cfg.BoolOpt("auth_enable", default=True),
+ cfg.DictOpt("errors", default={"__force_dict__": True}),
]
- CONF.register_opts(pecan_opts, group='api_pecan')
+ CONF.register_opts(pecan_opts, group="api_pecan")
logging_opts = [
- cfg.BoolOpt('debug', default=False),
+ cfg.BoolOpt("debug", default=False),
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.api.conf',
- help='location of the logging.conf file'),
+ "logging",
+ default="/etc/st2/logging.api.conf",
+ help="location of the logging.conf file",
+ ),
cfg.IntOpt(
- 'max_page_size', default=100,
- help='Maximum limit (page size) argument which can be '
- 'specified by the user in a query string.')
+ "max_page_size",
+ default=100,
+ help="Maximum limit (page size) argument which can be "
+ "specified by the user in a query string.",
+ ),
]
- CONF.register_opts(logging_opts, group='api')
+ CONF.register_opts(logging_opts, group="api")
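
For readers unfamiliar with oslo.config, options registered under a group the way logging_opts is above become attributes of that group once the config is parsed. A compact, self-contained sketch:

    from oslo_config import cfg

    CONF = cfg.ConfigOpts()
    CONF.register_opts(
        [cfg.IntOpt("max_page_size", default=100)],
        group="api",
    )

    CONF([])  # parse with no CLI args so the defaults apply
    print(CONF.api.max_page_size)  # -> 100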
diff --git a/st2api/st2api/controllers/base.py b/st2api/st2api/controllers/base.py
index a3f24e2f0f..e4f13d8f1e 100644
--- a/st2api/st2api/controllers/base.py
+++ b/st2api/st2api/controllers/base.py
@@ -20,9 +20,7 @@
from st2api.controllers.controller_transforms import transform_to_bool
from st2common.rbac.backends import get_rbac_backend
-__all__ = [
- 'BaseRestControllerMixin'
-]
+__all__ = ["BaseRestControllerMixin"]
class BaseRestControllerMixin(object):
@@ -41,7 +39,9 @@ def _parse_query_params(self, request):
return query_params
- def _get_query_param_value(self, request, param_name, param_type, default_value=None):
+ def _get_query_param_value(
+ self, request, param_name, param_type, default_value=None
+ ):
"""
Return a value for the provided query param and optionally cast it for boolean types.
@@ -61,7 +61,7 @@ def _get_query_param_value(self, request, param_name, param_type, default_value=
query_params = self._parse_query_params(request=request)
value = query_params.get(param_name, default_value)
- if param_type == 'bool' and isinstance(value, six.string_types):
+ if param_type == "bool" and isinstance(value, six.string_types):
value = transform_to_bool(value)
return value
diff --git a/st2api/st2api/controllers/controller_transforms.py b/st2api/st2api/controllers/controller_transforms.py
index 8afff88da6..0ca51a0a75 100644
--- a/st2api/st2api/controllers/controller_transforms.py
+++ b/st2api/st2api/controllers/controller_transforms.py
@@ -14,9 +14,7 @@
# limitations under the License.
-__all__ = [
- 'transform_to_bool'
-]
+__all__ = ["transform_to_bool"]
def transform_to_bool(value):
@@ -27,8 +25,8 @@ def transform_to_bool(value):
Any other representation will be rejected.
"""
- if value in ['1', 'true', 'True', True]:
+ if value in ["1", "true", "True", True]:
return True
- elif value in ['0', 'false', 'False', False]:
+ elif value in ["0", "false", "False", False]:
return False
raise ValueError('Invalid bool representation "%s" provided.' % value)
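
transform_to_bool is deliberately strict about spellings, which the query-param handling in base.py counts on. A short usage sketch, assuming the module above is importable:

    from st2api.controllers.controller_transforms import transform_to_bool

    assert transform_to_bool("1") is True
    assert transform_to_bool("False") is False

    try:
        transform_to_bool("yes")  # any other spelling is rejected
    except ValueError as exc:
        print(exc)  # Invalid bool representation "yes" provided.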
diff --git a/st2api/st2api/controllers/resource.py b/st2api/st2api/controllers/resource.py
index 72611a90dc..a2391ff9aa 100644
--- a/st2api/st2api/controllers/resource.py
+++ b/st2api/st2api/controllers/resource.py
@@ -35,21 +35,19 @@
LOG = logging.getLogger(__name__)
-RESERVED_QUERY_PARAMS = {
- 'id': 'id',
- 'name': 'name',
- 'sort': 'order_by'
-}
+RESERVED_QUERY_PARAMS = {"id": "id", "name": "name", "sort": "order_by"}
def split_id_value(value):
if not value or isinstance(value, (list, tuple)):
return value
- split = value.split(',')
+ split = value.split(",")
if len(split) > 100:
- raise ValueError('Maximum of 100 items can be provided for a query parameter value')
+ raise ValueError(
+ "Maximum of 100 items can be provided for a query parameter value"
+ )
return split
@@ -57,7 +55,7 @@ def split_id_value(value):
DEFAULT_FILTER_TRANSFORM_FUNCTIONS = {
    # Support for filtering on multiple ids when a comma-delimited string is provided
# (e.g. ?id=1,2,3)
- 'id': split_id_value
+ "id": split_id_value
}
@@ -65,14 +63,14 @@ def parameter_validation(validator, properties, instance, schema):
parameter_specific_schema = {
"description": "Input parameters for the action.",
"type": "object",
- "patternProperties": {
- r"^\w+$": util_schema.get_action_parameters_schema()
- },
- 'additionalProperties': False,
- "default": {}
+ "patternProperties": {r"^\w+$": util_schema.get_action_parameters_schema()},
+ "additionalProperties": False,
+ "default": {},
}
- parameter_specific_validator = util_schema.CustomValidator(parameter_specific_schema)
+ parameter_specific_validator = util_schema.CustomValidator(
+ parameter_specific_schema
+ )
for error in parameter_specific_validator.iter_errors(instance=instance):
yield error
@@ -91,18 +89,16 @@ class ResourceController(object):
# ?include_attributes filter. Those attributes need to be included because a lot of code
# depends on compound references and primary keys. In addition to that, it's needed for secrets
# masking to work, etc.
- mandatory_include_fields_retrieve = ['id']
+ mandatory_include_fields_retrieve = ["id"]
# A list of fields which are always included in the response when ?include_attributes filter is
# used. Those are things such as primary keys and similar.
- mandatory_include_fields_response = ['id']
+ mandatory_include_fields_response = ["id"]
# Default number of items returned per page if no limit is explicitly provided
default_limit = 100
- query_options = {
- 'sort': []
- }
+ query_options = {"sort": []}
# A list of optional transformation functions for user provided filter values
filter_transform_functions = {}
@@ -120,7 +116,9 @@ def __init__(self):
self.supported_filters = copy.deepcopy(self.__class__.supported_filters)
self.supported_filters.update(RESERVED_QUERY_PARAMS)
- self.filter_transform_functions = copy.deepcopy(self.__class__.filter_transform_functions)
+ self.filter_transform_functions = copy.deepcopy(
+ self.__class__.filter_transform_functions
+ )
self.filter_transform_functions.update(DEFAULT_FILTER_TRANSFORM_FUNCTIONS)
self.get_one_db_method = self._get_by_name_or_id
@@ -130,9 +128,19 @@ def __init__(self):
def max_limit(self):
return cfg.CONF.api.max_page_size
- def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=None,
- sort=None, offset=0, limit=None, query_options=None,
- from_model_kwargs=None, raw_filters=None, requester_user=None):
+ def _get_all(
+ self,
+ exclude_fields=None,
+ include_fields=None,
+ advanced_filters=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ query_options=None,
+ from_model_kwargs=None,
+ raw_filters=None,
+ requester_user=None,
+ ):
"""
:param exclude_fields: A list of object fields to exclude.
:type exclude_fields: ``list``
@@ -144,8 +152,10 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No
query_options = query_options if query_options else self.query_options
if exclude_fields and include_fields:
- msg = ('exclude_fields and include_fields arguments are mutually exclusive. '
- 'You need to provide either one or another, but not both.')
+ msg = (
+ "exclude_fields and include_fields arguments are mutually exclusive. "
+ "You need to provide either one or another, but not both."
+ )
raise ValueError(msg)
exclude_fields = self._validate_exclude_fields(exclude_fields=exclude_fields)
@@ -153,18 +163,18 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No
        # TODO: Why do we use a comma-delimited string? The user can just specify
        # multiple values using ?sort=foo&sort=bar and get a list back
- sort = sort.split(',') if sort else []
+ sort = sort.split(",") if sort else []
db_sort_values = []
for sort_key in sort:
- if sort_key.startswith('-'):
- direction = '-'
+ if sort_key.startswith("-"):
+ direction = "-"
sort_key = sort_key[1:]
- elif sort_key.startswith('+'):
- direction = '+'
+ elif sort_key.startswith("+"):
+ direction = "+"
sort_key = sort_key[1:]
else:
- direction = ''
+ direction = ""
if sort_key not in self.supported_filters:
# Skip unsupported sort key
@@ -173,12 +183,12 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No
sort_value = direction + self.supported_filters[sort_key]
db_sort_values.append(sort_value)
- default_sort_values = copy.copy(query_options.get('sort'))
- raw_filters['sort'] = db_sort_values if db_sort_values else default_sort_values
+ default_sort_values = copy.copy(query_options.get("sort"))
+ raw_filters["sort"] = db_sort_values if db_sort_values else default_sort_values
# TODO: To protect us from DoS, we need to make max_limit mandatory
offset = int(offset)
- if offset >= 2**31:
+ if offset >= 2 ** 31:
raise ValueError('Offset "%s" specified is more than 32-bit int' % (offset))
limit = validate_limit_query_param(limit=limit, requester_user=requester_user)
@@ -195,32 +205,35 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No
value_transform_function = value_transform_function or (lambda value: value)
filter_value = value_transform_function(value=filter_value)
- if k in ['id', 'name'] and isinstance(filter_value, list):
- filters[k + '__in'] = filter_value
+ if k in ["id", "name"] and isinstance(filter_value, list):
+ filters[k + "__in"] = filter_value
else:
- field_name_split = v.split('.')
+ field_name_split = v.split(".")
# Make sure filter value is a list when using "in" filter
- if field_name_split[-1] == 'in' and not isinstance(filter_value, (list, tuple)):
+ if field_name_split[-1] == "in" and not isinstance(
+ filter_value, (list, tuple)
+ ):
filter_value = [filter_value]
- filters['__'.join(field_name_split)] = filter_value
+ filters["__".join(field_name_split)] = filter_value
if advanced_filters:
- for token in advanced_filters.split(' '):
+ for token in advanced_filters.split(" "):
try:
- [k, v] = token.split(':', 1)
+ [k, v] = token.split(":", 1)
except ValueError:
raise ValueError('invalid format for filter "%s"' % token)
- path = k.split('.')
+ path = k.split(".")
try:
self.model.model._lookup_field(path)
- filters['__'.join(path)] = v
+ filters["__".join(path)] = v
except LookUpError as e:
raise ValueError(six.text_type(e))
- instances = self.access.query(exclude_fields=exclude_fields, only_fields=include_fields,
- **filters)
+ instances = self.access.query(
+ exclude_fields=exclude_fields, only_fields=include_fields, **filters
+ )
if limit == 1:
# Perform the filtering on the DB side
instances = instances.limit(limit)
@@ -228,44 +241,65 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No
from_model_kwargs = from_model_kwargs or {}
from_model_kwargs.update(self.from_model_kwargs)
- result = self.resources_model_filter(model=self.model,
- instances=instances,
- offset=offset,
- eop=eop,
- requester_user=requester_user,
- **from_model_kwargs)
+ result = self.resources_model_filter(
+ model=self.model,
+ instances=instances,
+ offset=offset,
+ eop=eop,
+ requester_user=requester_user,
+ **from_model_kwargs,
+ )
resp = Response(json=result)
- resp.headers['X-Total-Count'] = str(instances.count())
+ resp.headers["X-Total-Count"] = str(instances.count())
if limit:
- resp.headers['X-Limit'] = str(limit)
+ resp.headers["X-Limit"] = str(limit)
return resp
- def resources_model_filter(self, model, instances, requester_user=None, offset=0, eop=0,
- **from_model_kwargs):
+ def resources_model_filter(
+ self,
+ model,
+ instances,
+ requester_user=None,
+ offset=0,
+ eop=0,
+ **from_model_kwargs,
+ ):
"""
Method which converts DB objects to API objects and performs any additional filtering.
"""
result = []
for instance in instances[offset:eop]:
- item = self.resource_model_filter(model=model, instance=instance,
- requester_user=requester_user,
- **from_model_kwargs)
+ item = self.resource_model_filter(
+ model=model,
+ instance=instance,
+ requester_user=requester_user,
+ **from_model_kwargs,
+ )
result.append(item)
return result
- def resource_model_filter(self, model, instance, requester_user=None, **from_model_kwargs):
+ def resource_model_filter(
+ self, model, instance, requester_user=None, **from_model_kwargs
+ ):
"""
Method which converts DB object to API object and performs any additional filtering.
"""
item = model.from_model(instance, **from_model_kwargs)
return item
- def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=None,
- include_fields=None, from_model_kwargs=None):
+ def _get_one_by_id(
+ self,
+ id,
+ requester_user,
+ permission_type,
+ exclude_fields=None,
+ include_fields=None,
+ from_model_kwargs=None,
+ ):
"""
:param exclude_fields: A list of object fields to exclude.
:type exclude_fields: ``list``
@@ -273,14 +307,17 @@ def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=Non
:type include_fields: ``list``
"""
- instance = self._get_by_id(resource_id=id, exclude_fields=exclude_fields,
- include_fields=include_fields)
+ instance = self._get_by_id(
+ resource_id=id, exclude_fields=exclude_fields, include_fields=include_fields
+ )
if permission_type:
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=instance,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=instance,
+ permission_type=permission_type,
+ )
if not instance:
msg = 'Unable to identify resource with id "%s".' % id
@@ -289,21 +326,35 @@ def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=Non
from_model_kwargs = from_model_kwargs or {}
from_model_kwargs.update(self.from_model_kwargs)
- result = self.resource_model_filter(model=self.model, instance=instance,
- requester_user=requester_user,
- **from_model_kwargs)
+ result = self.resource_model_filter(
+ model=self.model,
+ instance=instance,
+ requester_user=requester_user,
+ **from_model_kwargs,
+ )
if not result:
- LOG.debug('Not returning the result because RBAC resource isolation is enabled and '
- 'current user doesn\'t match the resource user')
- raise ResourceAccessDeniedPermissionIsolationError(user_db=requester_user,
- resource_api_or_db=instance,
- permission_type=permission_type)
+ LOG.debug(
+ "Not returning the result because RBAC resource isolation is enabled and "
+ "current user doesn't match the resource user"
+ )
+ raise ResourceAccessDeniedPermissionIsolationError(
+ user_db=requester_user,
+ resource_api_or_db=instance,
+ permission_type=permission_type,
+ )
return result
- def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type,
- exclude_fields=None, include_fields=None, from_model_kwargs=None):
+ def _get_one_by_name_or_id(
+ self,
+ name_or_id,
+ requester_user,
+ permission_type,
+ exclude_fields=None,
+ include_fields=None,
+ from_model_kwargs=None,
+ ):
"""
:param exclude_fields: A list of object fields to exclude.
:type exclude_fields: ``list``
@@ -311,14 +362,19 @@ def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type,
:type include_fields: ``list``
"""
- instance = self._get_by_name_or_id(name_or_id=name_or_id, exclude_fields=exclude_fields,
- include_fields=include_fields)
+ instance = self._get_by_name_or_id(
+ name_or_id=name_or_id,
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ )
if permission_type:
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=instance,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=instance,
+ permission_type=permission_type,
+ )
if not instance:
msg = 'Unable to identify resource with name_or_id "%s".' % (name_or_id)
@@ -330,10 +386,14 @@ def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type,
return result
- def _get_one_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=None,
- from_model_kwargs=None):
- instance = self._get_by_pack_ref(pack_ref=pack_ref, exclude_fields=exclude_fields,
- include_fields=include_fields)
+ def _get_one_by_pack_ref(
+ self, pack_ref, exclude_fields=None, include_fields=None, from_model_kwargs=None
+ ):
+ instance = self._get_by_pack_ref(
+ pack_ref=pack_ref,
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ )
if not instance:
msg = 'Unable to identify resource with pack_ref "%s".' % (pack_ref)
@@ -347,8 +407,11 @@ def _get_one_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=Non
def _get_by_id(self, resource_id, exclude_fields=None, include_fields=None):
try:
- resource_db = self.access.get(id=resource_id, exclude_fields=exclude_fields,
- only_fields=include_fields)
+ resource_db = self.access.get(
+ id=resource_id,
+ exclude_fields=exclude_fields,
+ only_fields=include_fields,
+ )
except ValidationError:
resource_db = None
@@ -356,8 +419,11 @@ def _get_by_id(self, resource_id, exclude_fields=None, include_fields=None):
def _get_by_name(self, resource_name, exclude_fields=None, include_fields=None):
try:
- resource_db = self.access.get(name=resource_name, exclude_fields=exclude_fields,
- only_fields=include_fields)
+ resource_db = self.access.get(
+ name=resource_name,
+ exclude_fields=exclude_fields,
+ only_fields=include_fields,
+ )
except Exception:
resource_db = None
@@ -365,8 +431,9 @@ def _get_by_name(self, resource_name, exclude_fields=None, include_fields=None):
def _get_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=None):
try:
- resource_db = self.access.get(pack=pack_ref, exclude_fields=exclude_fields,
- only_fields=include_fields)
+ resource_db = self.access.get(
+ pack=pack_ref, exclude_fields=exclude_fields, only_fields=include_fields
+ )
except Exception:
resource_db = None
@@ -376,13 +443,17 @@ def _get_by_name_or_id(self, name_or_id, exclude_fields=None, include_fields=Non
"""
        Retrieve a resource object by an id or a name.
"""
- resource_db = self._get_by_id(resource_id=name_or_id, exclude_fields=exclude_fields,
- include_fields=include_fields)
+ resource_db = self._get_by_id(
+ resource_id=name_or_id,
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ )
if not resource_db:
# Try name
- resource_db = self._get_by_name(resource_name=name_or_id,
- exclude_fields=exclude_fields)
+ resource_db = self._get_by_name(
+ resource_name=name_or_id, exclude_fields=exclude_fields
+ )
if not resource_db:
msg = 'Resource with a name or id "%s" not found' % (name_or_id)
@@ -402,11 +473,16 @@ def _get_one_by_scope_and_name(self, scope, name, from_model_kwargs=None):
"""
instance = self.access.get_by_scope_and_name(scope=scope, name=name)
if not instance:
- msg = 'KeyValuePair with name: %s and scope: %s not found in db.' % (name, scope)
+ msg = "KeyValuePair with name: %s and scope: %s not found in db." % (
+ name,
+ scope,
+ )
raise StackStormDBObjectNotFoundError(msg)
from_model_kwargs = from_model_kwargs or {}
result = self.model.from_model(instance, **from_model_kwargs)
- LOG.debug('GET with scope=%s and name=%s, client_result=%s', scope, name, result)
+ LOG.debug(
+ "GET with scope=%s and name=%s, client_result=%s", scope, name, result
+ )
return result
@@ -422,7 +498,7 @@ def _validate_exclude_fields(self, exclude_fields):
for field in exclude_fields:
if field not in self.valid_exclude_attributes:
- msg = ('Invalid or unsupported exclude attribute specified: %s' % (field))
+ msg = "Invalid or unsupported exclude attribute specified: %s" % (field)
raise ValueError(msg)
return exclude_fields
@@ -438,7 +514,7 @@ def _validate_include_fields(self, include_fields):
for field in self.mandatory_include_fields_retrieve:
# Don't add mandatory field if user already requested the whole dict object (e.g. user
# requests action and action.parameters is a mandatory field)
- partial_field = field.split('.')[0]
+ partial_field = field.split(".")[0]
if partial_field in include_fields:
continue
@@ -456,20 +532,38 @@ class BaseResourceIsolationControllerMixin(object):
users).
"""
- def resources_model_filter(self, model, instances, requester_user=None, offset=0, eop=0,
- **from_model_kwargs):
+ def resources_model_filter(
+ self,
+ model,
+ instances,
+ requester_user=None,
+ offset=0,
+ eop=0,
+ **from_model_kwargs,
+ ):
# RBAC or permission isolation is disabled, bail out
if not (cfg.CONF.rbac.enable and cfg.CONF.rbac.permission_isolation):
- result = super(BaseResourceIsolationControllerMixin, self).resources_model_filter(
- model=model, instances=instances, requester_user=requester_user,
- offset=offset, eop=eop, **from_model_kwargs)
+ result = super(
+ BaseResourceIsolationControllerMixin, self
+ ).resources_model_filter(
+ model=model,
+ instances=instances,
+ requester_user=requester_user,
+ offset=offset,
+ eop=eop,
+ **from_model_kwargs,
+ )
return result
result = []
for instance in instances[offset:eop]:
- item = self.resource_model_filter(model=model, instance=instance,
- requester_user=requester_user, **from_model_kwargs)
+ item = self.resource_model_filter(
+ model=model,
+ instance=instance,
+ requester_user=requester_user,
+ **from_model_kwargs,
+ )
if not item:
continue
@@ -478,18 +572,25 @@ def resources_model_filter(self, model, instances, requester_user=None, offset=0
return result
- def resource_model_filter(self, model, instance, requester_user=None, **from_model_kwargs):
+ def resource_model_filter(
+ self, model, instance, requester_user=None, **from_model_kwargs
+ ):
# RBAC or permission isolation is disabled, bail out
if not (cfg.CONF.rbac.enable and cfg.CONF.rbac.permission_isolation):
- result = super(BaseResourceIsolationControllerMixin, self).resource_model_filter(
- model=model, instance=instance, requester_user=requester_user,
- **from_model_kwargs)
+ result = super(
+ BaseResourceIsolationControllerMixin, self
+ ).resource_model_filter(
+ model=model,
+ instance=instance,
+ requester_user=requester_user,
+ **from_model_kwargs,
+ )
return result
rbac_utils = get_rbac_backend().get_utils_class()
user_is_admin = rbac_utils.user_is_admin(user_db=requester_user)
- user_is_system_user = (requester_user.name == cfg.CONF.system_user.user)
+ user_is_system_user = requester_user.name == cfg.CONF.system_user.user
item = model.from_model(instance, **from_model_kwargs)
@@ -497,7 +598,7 @@ def resource_model_filter(self, model, instance, requester_user=None, **from_mod
if user_is_admin or user_is_system_user:
return item
- user = item.context.get('user', None)
+ user = item.context.get("user", None)
if user and (user == requester_user.name):
return item
@@ -506,21 +607,31 @@ def resource_model_filter(self, model, instance, requester_user=None, **from_mod
class ContentPackResourceController(ResourceController):
    # name and pack are mandatory because they comprise the primary key - reference (.)
- mandatory_include_fields_retrieve = ['pack', 'name']
+ mandatory_include_fields_retrieve = ["pack", "name"]
# A list of fields which are always included in the response. Those are things such as primary
# keys and similar
- mandatory_include_fields_response = ['id', 'ref']
+ mandatory_include_fields_response = ["id", "ref"]
def __init__(self):
super(ContentPackResourceController, self).__init__()
self.get_one_db_method = self._get_by_ref_or_id
- def _get_one(self, ref_or_id, requester_user, permission_type, exclude_fields=None,
- include_fields=None, from_model_kwargs=None):
+ def _get_one(
+ self,
+ ref_or_id,
+ requester_user,
+ permission_type,
+ exclude_fields=None,
+ include_fields=None,
+ from_model_kwargs=None,
+ ):
try:
- instance = self._get_by_ref_or_id(ref_or_id=ref_or_id, exclude_fields=exclude_fields,
- include_fields=include_fields)
+ instance = self._get_by_ref_or_id(
+ ref_or_id=ref_or_id,
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ )
except Exception as e:
LOG.exception(six.text_type(e))
abort(http_client.NOT_FOUND, six.text_type(e))
@@ -528,40 +639,59 @@ def _get_one(self, ref_or_id, requester_user, permission_type, exclude_fields=No
if permission_type:
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=instance,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=instance,
+ permission_type=permission_type,
+ )
# Perform resource isolation check (if supported)
from_model_kwargs = from_model_kwargs or {}
from_model_kwargs.update(self.from_model_kwargs)
- result = self.resource_model_filter(model=self.model, instance=instance,
- requester_user=requester_user,
- **from_model_kwargs)
+ result = self.resource_model_filter(
+ model=self.model,
+ instance=instance,
+ requester_user=requester_user,
+ **from_model_kwargs,
+ )
if not result:
- LOG.debug('Not returning the result because RBAC resource isolation is enabled and '
- 'current user doesn\'t match the resource user')
- raise ResourceAccessDeniedPermissionIsolationError(user_db=requester_user,
- resource_api_or_db=instance,
- permission_type=permission_type)
+ LOG.debug(
+ "Not returning the result because RBAC resource isolation is enabled and "
+ "current user doesn't match the resource user"
+ )
+ raise ResourceAccessDeniedPermissionIsolationError(
+ user_db=requester_user,
+ resource_api_or_db=instance,
+ permission_type=permission_type,
+ )
return Response(json=result)
- def _get_all(self, exclude_fields=None, include_fields=None,
- sort=None, offset=0, limit=None, query_options=None,
- from_model_kwargs=None, raw_filters=None, requester_user=None):
- resp = super(ContentPackResourceController,
- self)._get_all(exclude_fields=exclude_fields,
- include_fields=include_fields,
- sort=sort,
- offset=offset,
- limit=limit,
- query_options=query_options,
- from_model_kwargs=from_model_kwargs,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ def _get_all(
+ self,
+ exclude_fields=None,
+ include_fields=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ query_options=None,
+ from_model_kwargs=None,
+ raw_filters=None,
+ requester_user=None,
+ ):
+ resp = super(ContentPackResourceController, self)._get_all(
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ query_options=query_options,
+ from_model_kwargs=from_model_kwargs,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
return resp
@@ -574,8 +704,10 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None)
"""
if exclude_fields and include_fields:
- msg = ('exclude_fields and include_fields arguments are mutually exclusive. '
- 'You need to provide either one or another, but not both.')
+ msg = (
+ "exclude_fields and include_fields arguments are mutually exclusive. "
+ "You need to provide either one or another, but not both."
+ )
raise ValueError(msg)
if ResourceReference.is_resource_reference(ref_or_id):
@@ -585,11 +717,17 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None)
is_reference = False
if is_reference:
- resource_db = self._get_by_ref(resource_ref=ref_or_id, exclude_fields=exclude_fields,
- include_fields=include_fields)
+ resource_db = self._get_by_ref(
+ resource_ref=ref_or_id,
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ )
else:
- resource_db = self._get_by_id(resource_id=ref_or_id, exclude_fields=exclude_fields,
- include_fields=include_fields)
+ resource_db = self._get_by_id(
+ resource_id=ref_or_id,
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ )
if not resource_db:
msg = 'Resource with a reference or id "%s" not found' % (ref_or_id)
@@ -599,8 +737,10 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None)
def _get_by_ref(self, resource_ref, exclude_fields=None, include_fields=None):
if exclude_fields and include_fields:
- msg = ('exclude_fields and include_fields arguments are mutually exclusive. '
- 'You need to provide either one or another, but not both.')
+ msg = (
+ "exclude_fields and include_fields arguments are mutually exclusive. "
+ "You need to provide either one or another, but not both."
+ )
raise ValueError(msg)
try:
@@ -608,9 +748,12 @@ def _get_by_ref(self, resource_ref, exclude_fields=None, include_fields=None):
except Exception:
return None
- resource_db = self.access.query(name=ref.name, pack=ref.pack,
- exclude_fields=exclude_fields,
- only_fields=include_fields).first()
+ resource_db = self.access.query(
+ name=ref.name,
+ pack=ref.pack,
+ exclude_fields=exclude_fields,
+ only_fields=include_fields,
+ ).first()
return resource_db
@@ -629,25 +772,29 @@ def validate_limit_query_param(limit, requester_user=None):
if int(limit) == -1:
if not user_is_admin:
# Only admins can specify limit -1
- message = ('Administrator access required to be able to specify limit=-1 and '
- 'retrieve all the records')
- raise AccessDeniedError(message=message,
- user_db=requester_user)
+ message = (
+ "Administrator access required to be able to specify limit=-1 and "
+ "retrieve all the records"
+ )
+ raise AccessDeniedError(message=message, user_db=requester_user)
return 0
elif int(limit) <= -2:
msg = 'Limit, "%s" specified, must be a positive number.' % (limit)
raise ValueError(msg)
elif int(limit) > cfg.CONF.api.max_page_size and not user_is_admin:
- msg = ('Limit "%s" specified, maximum value is "%s"' % (limit,
- cfg.CONF.api.max_page_size))
+ msg = 'Limit "%s" specified, maximum value is "%s"' % (
+ limit,
+ cfg.CONF.api.max_page_size,
+ )
- raise AccessDeniedError(message=msg,
- user_db=requester_user)
+ raise AccessDeniedError(message=msg, user_db=requester_user)
# Disable n = 0
elif limit == 0:
- msg = ('Limit, "%s" specified, must be a positive number or -1 for full result set.' %
- (limit))
+ msg = (
+ 'Limit, "%s" specified, must be a positive number or -1 for full result set.'
+ % (limit)
+ )
raise ValueError(msg)
return limit
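
Since the validate_limit_query_param hunk above only reflows pre-existing branching, a condensed, dependency-free sketch of the rules it enforces may help verify the reformat is behavior-preserving. MAX_PAGE_SIZE and the built-in PermissionError below are stand-ins for cfg.CONF.api.max_page_size and st2common's AccessDeniedError, so this is an approximation, not the controller code itself.

    # Sketch of the limit validation rules reformatted above (stand-in exceptions).
    MAX_PAGE_SIZE = 100  # stand-in for cfg.CONF.api.max_page_size

    def validate_limit(limit, user_is_admin=False):
        if int(limit) == -1:
            if not user_is_admin:
                # stand-in for st2's AccessDeniedError
                raise PermissionError(
                    "Administrator access required to be able to specify limit=-1 and "
                    "retrieve all the records"
                )
            return 0  # 0 means "no limit" for the underlying query
        elif int(limit) <= -2:
            raise ValueError('Limit, "%s" specified, must be a positive number.' % limit)
        elif int(limit) > MAX_PAGE_SIZE and not user_is_admin:
            raise PermissionError(
                'Limit "%s" specified, maximum value is "%s"' % (limit, MAX_PAGE_SIZE)
            )
        elif limit == 0:
            raise ValueError(
                'Limit, "%s" specified, must be a positive number or -1 for full result set.'
                % limit
            )
        return limit
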
diff --git a/st2api/st2api/controllers/root.py b/st2api/st2api/controllers/root.py
index c2db487b02..2d5d953afa 100644
--- a/st2api/st2api/controllers/root.py
+++ b/st2api/st2api/controllers/root.py
@@ -15,23 +15,21 @@
from st2common import __version__
-__all__ = [
- 'RootController'
-]
+__all__ = ["RootController"]
class RootController(object):
def index(self):
data = {}
- if 'dev' in __version__:
- docs_url = 'http://docs.stackstorm.com/latest'
+ if "dev" in __version__:
+ docs_url = "http://docs.stackstorm.com/latest"
else:
- docs_version = '.'.join(__version__.split('.')[:2])
- docs_url = 'http://docs.stackstorm.com/%s' % docs_version
+ docs_version = ".".join(__version__.split(".")[:2])
+ docs_url = "http://docs.stackstorm.com/%s" % docs_version
- data['version'] = __version__
- data['docs_url'] = docs_url
+ data["version"] = __version__
+ data["docs_url"] = docs_url
return data
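
For quick reference, the docs_url derivation in RootController.index() reduces to the pure function below; the sample version strings are illustrative.

    # Standalone restatement of the docs_url logic in RootController.index().
    def docs_url_for(version):
        if "dev" in version:
            return "http://docs.stackstorm.com/latest"
        docs_version = ".".join(version.split(".")[:2])  # "3.4.0" -> "3.4"
        return "http://docs.stackstorm.com/%s" % docs_version

    assert docs_url_for("3.5dev") == "http://docs.stackstorm.com/latest"
    assert docs_url_for("3.4.0") == "http://docs.stackstorm.com/3.4"
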
diff --git a/st2api/st2api/controllers/v1/action_views.py b/st2api/st2api/controllers/v1/action_views.py
index d1701ebfbf..2e528b5b13 100644
--- a/st2api/st2api/controllers/v1/action_views.py
+++ b/st2api/st2api/controllers/v1/action_views.py
@@ -33,11 +33,7 @@
from st2common.router import abort
from st2common.router import Response
-__all__ = [
- 'OverviewController',
- 'ParametersViewController',
- 'EntryPointController'
-]
+__all__ = ["OverviewController", "ParametersViewController", "EntryPointController"]
http_client = six.moves.http_client
@@ -45,7 +41,6 @@
class LookupUtils(object):
-
@staticmethod
def _get_action_by_id(id):
try:
@@ -75,31 +70,33 @@ def _get_runner_by_name(name):
class ParametersViewController(object):
-
def get_one(self, action_id, requester_user):
return self._get_one(action_id, requester_user=requester_user)
@staticmethod
def _get_one(action_id, requester_user):
"""
- List merged action & runner parameters by action id.
+ List merged action & runner parameters by action id.
- Handle:
- GET /actions/views/parameters/1
+ Handle:
+ GET /actions/views/parameters/1
"""
action_db = LookupUtils._get_action_by_id(action_id)
permission_type = PermissionType.ACTION_VIEW
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_db,
+ permission_type=permission_type,
+ )
- runner_db = LookupUtils._get_runner_by_name(action_db.runner_type['name'])
+ runner_db = LookupUtils._get_runner_by_name(action_db.runner_type["name"])
all_params = action_param_utils.get_params_view(
- action_db=action_db, runner_db=runner_db, merged_only=True)
+ action_db=action_db, runner_db=runner_db, merged_only=True
+ )
- return {'parameters': all_params}
+ return {"parameters": all_params}
class OverviewController(resource.ContentPackResourceController):
@@ -107,47 +104,54 @@ class OverviewController(resource.ContentPackResourceController):
access = Action
supported_filters = {}
- query_options = {
- 'sort': ['pack', 'name']
- }
+ query_options = {"sort": ["pack", "name"]}
- mandatory_include_fields_retrieve = [
- 'pack',
- 'name',
- 'parameters',
- 'runner_type'
- ]
+ mandatory_include_fields_retrieve = ["pack", "name", "parameters", "runner_type"]
def get_one(self, ref_or_id, requester_user):
"""
- List action by id.
+ List action by id.
- Handle:
- GET /actions/views/overview/1
+ Handle:
+ GET /actions/views/overview/1
"""
- resp = super(OverviewController, self)._get_one(ref_or_id,
- requester_user=requester_user,
- permission_type=PermissionType.ACTION_VIEW)
+ resp = super(OverviewController, self)._get_one(
+ ref_or_id,
+ requester_user=requester_user,
+ permission_type=PermissionType.ACTION_VIEW,
+ )
action_api = ActionAPI(**resp.json)
- result = self._transform_action_api(action_api=action_api, requester_user=requester_user)
+ result = self._transform_action_api(
+ action_api=action_api, requester_user=requester_user
+ )
resp.json = result
return resp
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
"""
- List all actions.
+ List all actions.
- Handles requests:
- GET /actions/views/overview
+ Handles requests:
+ GET /actions/views/overview
"""
- resp = super(OverviewController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ resp = super(OverviewController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
runner_type_names = set([])
action_ids = []
@@ -164,9 +168,12 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o
# N * 2 additional queries
# 1. Retrieve all the respective runner objects - we only need parameters
- runner_type_dbs = RunnerType.query(name__in=runner_type_names,
- only_fields=['name', 'runner_parameters'])
- runner_type_dbs = dict([(runner_db.name, runner_db) for runner_db in runner_type_dbs])
+ runner_type_dbs = RunnerType.query(
+ name__in=runner_type_names, only_fields=["name", "runner_parameters"]
+ )
+ runner_type_dbs = dict(
+ [(runner_db.name, runner_db) for runner_db in runner_type_dbs]
+ )
# 2. Retrieve all the respective action objects - we only need parameters
action_dbs = dict([(action_db.id, action_db) for action_db in result])
@@ -174,9 +181,9 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o
for action_api in result:
action_db = action_dbs.get(action_api.id, None)
runner_db = runner_type_dbs.get(action_api.runner_type, None)
- all_params = action_param_utils.get_params_view(action_db=action_db,
- runner_db=runner_db,
- merged_only=True)
+ all_params = action_param_utils.get_params_view(
+ action_db=action_db, runner_db=runner_db, merged_only=True
+ )
action_api.parameters = all_params
resp.json = result
@@ -185,9 +192,10 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o
@staticmethod
def _transform_action_api(action_api, requester_user):
action_id = action_api.id
- result = ParametersViewController._get_one(action_id=action_id,
- requester_user=requester_user)
- action_api.parameters = result.get('parameters', {})
+ result = ParametersViewController._get_one(
+ action_id=action_id, requester_user=requester_user
+ )
+ action_api.parameters = result.get("parameters", {})
return action_api
@@ -202,35 +210,38 @@ def get_all(self):
def get_one(self, ref_or_id, requester_user):
"""
- Outputs the file associated with action entry_point
+ Outputs the file associated with action entry_point
- Handles requests:
- GET /actions/views/entry_point/1
+ Handles requests:
+ GET /actions/views/entry_point/1
"""
- LOG.info('GET /actions/views/entry_point with ref_or_id=%s', ref_or_id)
+ LOG.info("GET /actions/views/entry_point with ref_or_id=%s", ref_or_id)
action_db = self._get_by_ref_or_id(ref_or_id=ref_or_id)
permission_type = PermissionType.ACTION_VIEW
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_db,
+ permission_type=permission_type,
+ )
- pack = getattr(action_db, 'pack', None)
- entry_point = getattr(action_db, 'entry_point', None)
+ pack = getattr(action_db, "pack", None)
+ entry_point = getattr(action_db, "entry_point", None)
abs_path = utils.get_entry_point_abs_path(pack, entry_point)
if not abs_path:
- raise StackStormDBObjectNotFoundError('Action ref_or_id=%s has no entry_point to output'
- % ref_or_id)
+ raise StackStormDBObjectNotFoundError(
+ "Action ref_or_id=%s has no entry_point to output" % ref_or_id
+ )
- with codecs.open(abs_path, 'r') as fp:
+ with codecs.open(abs_path, "r") as fp:
content = fp.read()
# Ensure content is utf-8
if isinstance(content, six.binary_type):
- content = content.decode('utf-8')
+ content = content.decode("utf-8")
try:
content_type = mimetypes.guess_type(abs_path)[0]
@@ -240,15 +251,15 @@ def get_one(self, ref_or_id, requester_user):
# Special case if /etc/mime.types doesn't contain entry for yaml, py
if not content_type:
_, extension = os.path.splitext(abs_path)
- if extension in ['.yaml', '.yml']:
- content_type = 'application/x-yaml'
- elif extension in ['.py']:
- content_type = 'application/x-python'
+ if extension in [".yaml", ".yml"]:
+ content_type = "application/x-yaml"
+ elif extension in [".py"]:
+ content_type = "application/x-python"
else:
- content_type = 'text/plain'
+ content_type = "text/plain"
response = Response()
- response.headers['Content-Type'] = content_type
+ response.headers["Content-Type"] = content_type
response.text = content
return response
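
The content-type fallback at the end of EntryPointController.get_one is easy to lose in the reflow, so here it is as a condensed standalone sketch (the try/except around the mimetypes lookup is omitted):

    # Condensed sketch of the entry point content-type fallback above.
    import mimetypes
    import os

    def guess_entry_point_content_type(abs_path):
        content_type = mimetypes.guess_type(abs_path)[0]
        # Special case if /etc/mime.types doesn't contain an entry for yaml, py
        if not content_type:
            _, extension = os.path.splitext(abs_path)
            if extension in [".yaml", ".yml"]:
                content_type = "application/x-yaml"
            elif extension in [".py"]:
                content_type = "application/x-python"
            else:
                content_type = "text/plain"
        return content_type
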
diff --git a/st2api/st2api/controllers/v1/actionalias.py b/st2api/st2api/controllers/v1/actionalias.py
index 00e58675f9..5488300d6e 100644
--- a/st2api/st2api/controllers/v1/actionalias.py
+++ b/st2api/st2api/controllers/v1/actionalias.py
@@ -37,175 +37,219 @@
class ActionAliasController(resource.ContentPackResourceController):
"""
- Implements the RESTful interface for ActionAliases.
+ Implements the RESTful interface for ActionAliases.
"""
+
model = ActionAliasAPI
access = ActionAlias
- supported_filters = {
- 'name': 'name',
- 'pack': 'pack'
- }
-
- query_options = {
- 'sort': ['pack', 'name']
- }
-
- _custom_actions = {
- 'match': ['POST'],
- 'help': ['POST']
- }
-
- def get_all(self, exclude_attributes=None, include_attributes=None,
- sort=None, offset=0, limit=None, requester_user=None, **raw_filters):
- return super(ActionAliasController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ supported_filters = {"name": "name", "pack": "pack"}
+
+ query_options = {"sort": ["pack", "name"]}
+
+ _custom_actions = {"match": ["POST"], "help": ["POST"]}
+
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return super(ActionAliasController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, ref_or_id, requester_user):
permission_type = PermissionType.ACTION_ALIAS_VIEW
- return super(ActionAliasController, self)._get_one(ref_or_id,
- requester_user=requester_user,
- permission_type=permission_type)
+ return super(ActionAliasController, self)._get_one(
+ ref_or_id, requester_user=requester_user, permission_type=permission_type
+ )
def match(self, action_alias_match_api):
"""
- Find a matching action alias.
+ Find a matching action alias.
- Handles requests:
- POST /actionalias/match
+ Handles requests:
+ POST /actionalias/match
"""
command = action_alias_match_api.command
try:
format_ = get_matching_alias(command=command)
except ActionAliasAmbiguityException as e:
- LOG.exception('Command "%s" matched (%s) patterns.', e.command, len(e.matches))
+ LOG.exception(
+ 'Command "%s" matched (%s) patterns.', e.command, len(e.matches)
+ )
return abort(http_client.BAD_REQUEST, six.text_type(e))
# Convert ActionAliasDB to API
- action_alias_api = ActionAliasAPI.from_model(format_['alias'])
+ action_alias_api = ActionAliasAPI.from_model(format_["alias"])
return {
- 'actionalias': action_alias_api,
- 'display': format_['display'],
- 'representation': format_['representation'],
+ "actionalias": action_alias_api,
+ "display": format_["display"],
+ "representation": format_["representation"],
}
def help(self, filter, pack, limit, offset, **kwargs):
"""
- Get available help strings for action aliases.
+ Get available help strings for action aliases.
- Handles requests:
- GET /actionalias/help
+ Handles requests:
+ GET /actionalias/help
"""
try:
aliases_resp = super(ActionAliasController, self)._get_all(**kwargs)
aliases = [ActionAliasAPI(**alias) for alias in aliases_resp.json]
- return generate_helpstring_result(aliases, filter, pack, int(limit), int(offset))
+ return generate_helpstring_result(
+ aliases, filter, pack, int(limit), int(offset)
+ )
except (TypeError) as e:
- LOG.exception('Helpstring request contains an invalid data type: %s.', six.text_type(e))
+ LOG.exception(
+ "Helpstring request contains an invalid data type: %s.",
+ six.text_type(e),
+ )
return abort(http_client.BAD_REQUEST, six.text_type(e))
def post(self, action_alias, requester_user):
"""
- Create a new ActionAlias.
+ Create a new ActionAlias.
- Handles requests:
- POST /actionalias/
+ Handles requests:
+ POST /actionalias/
"""
permission_type = PermissionType.ACTION_ALIAS_CREATE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user,
- resource_api=action_alias,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_api_permission(
+ user_db=requester_user,
+ resource_api=action_alias,
+ permission_type=permission_type,
+ )
try:
action_alias_db = ActionAliasAPI.to_model(action_alias)
- LOG.debug('/actionalias/ POST verified ActionAliasAPI and formulated ActionAliasDB=%s',
- action_alias_db)
+ LOG.debug(
+ "/actionalias/ POST verified ActionAliasAPI and formulated ActionAliasDB=%s",
+ action_alias_db,
+ )
action_alias_db = ActionAlias.add_or_update(action_alias_db)
except (ValidationError, ValueError, ValueValidationException) as e:
- LOG.exception('Validation failed for action alias data=%s.', action_alias)
+ LOG.exception("Validation failed for action alias data=%s.", action_alias)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- extra = {'action_alias_db': action_alias_db}
- LOG.audit('Action alias created. ActionAlias.id=%s' % (action_alias_db.id), extra=extra)
+ extra = {"action_alias_db": action_alias_db}
+ LOG.audit(
+ "Action alias created. ActionAlias.id=%s" % (action_alias_db.id),
+ extra=extra,
+ )
action_alias_api = ActionAliasAPI.from_model(action_alias_db)
return Response(json=action_alias_api, status=http_client.CREATED)
def put(self, action_alias, ref_or_id, requester_user):
"""
- Update an action alias.
+ Update an action alias.
- Handles requests:
- PUT /actionalias/1
+ Handles requests:
+ PUT /actionalias/1
"""
action_alias_db = self._get_by_ref_or_id(ref_or_id=ref_or_id)
- LOG.debug('PUT /actionalias/ lookup with id=%s found object: %s', ref_or_id,
- action_alias_db)
+ LOG.debug(
+ "PUT /actionalias/ lookup with id=%s found object: %s",
+ ref_or_id,
+ action_alias_db,
+ )
permission_type = PermissionType.ACTION_ALIAS_MODIFY
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_alias_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_alias_db,
+ permission_type=permission_type,
+ )
- if not hasattr(action_alias, 'id'):
+ if not hasattr(action_alias, "id"):
action_alias.id = None
try:
- if action_alias.id is not None and action_alias.id != '' and \
- action_alias.id != ref_or_id:
- LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.',
- action_alias.id, ref_or_id)
+ if (
+ action_alias.id is not None
+ and action_alias.id != ""
+ and action_alias.id != ref_or_id
+ ):
+ LOG.warning(
+ "Discarding mismatched id=%s found in payload and using uri_id=%s.",
+ action_alias.id,
+ ref_or_id,
+ )
old_action_alias_db = action_alias_db
action_alias_db = ActionAliasAPI.to_model(action_alias)
action_alias_db.id = ref_or_id
action_alias_db = ActionAlias.add_or_update(action_alias_db)
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for action alias data=%s', action_alias)
+ LOG.exception("Validation failed for action alias data=%s", action_alias)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- extra = {'old_action_alias_db': old_action_alias_db, 'new_action_alias_db': action_alias_db}
- LOG.audit('Action alias updated. ActionAlias.id=%s.' % (action_alias_db.id), extra=extra)
+ extra = {
+ "old_action_alias_db": old_action_alias_db,
+ "new_action_alias_db": action_alias_db,
+ }
+ LOG.audit(
+ "Action alias updated. ActionAlias.id=%s." % (action_alias_db.id),
+ extra=extra,
+ )
action_alias_api = ActionAliasAPI.from_model(action_alias_db)
return action_alias_api
def delete(self, ref_or_id, requester_user):
"""
- Delete an action alias.
+ Delete an action alias.
- Handles requests:
- DELETE /actionalias/1
+ Handles requests:
+ DELETE /actionalias/1
"""
action_alias_db = self._get_by_ref_or_id(ref_or_id=ref_or_id)
- LOG.debug('DELETE /actionalias/ lookup with id=%s found object: %s', ref_or_id,
- action_alias_db)
+ LOG.debug(
+ "DELETE /actionalias/ lookup with id=%s found object: %s",
+ ref_or_id,
+ action_alias_db,
+ )
permission_type = PermissionType.ACTION_ALIAS_DELETE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_alias_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_alias_db,
+ permission_type=permission_type,
+ )
try:
ActionAlias.delete(action_alias_db)
except Exception as e:
- LOG.exception('Database delete encountered exception during delete of id="%s".',
- ref_or_id)
+ LOG.exception(
+ 'Database delete encountered exception during delete of id="%s".',
+ ref_or_id,
+ )
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
return
- extra = {'action_alias_db': action_alias_db}
- LOG.audit('Action alias deleted. ActionAlias.id=%s.' % (action_alias_db.id), extra=extra)
+ extra = {"action_alias_db": action_alias_db}
+ LOG.audit(
+ "Action alias deleted. ActionAlias.id=%s." % (action_alias_db.id),
+ extra=extra,
+ )
return Response(status=http_client.NO_CONTENT)
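
As a reading aid for the reformatted match() handler above, this is the request/response shape it implements; the key names come from the hunk, while the values are illustrative placeholders rather than real StackStorm data.

    # Shape of the POST /actionalias/match exchange (illustrative values only).
    request_body = {"command": "run remote ls on web01"}

    response_body = {
        "actionalias": {"name": "remote_shell_cmd", "pack": "examples"},  # ActionAliasAPI
        "display": "run remote {{cmd}} on {{hosts}}",
        "representation": "run remote {{cmd}} on {{hosts}}",
    }
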
diff --git a/st2api/st2api/controllers/v1/actionexecutions.py b/st2api/st2api/controllers/v1/actionexecutions.py
index b0aa4e9e1d..3cc7741b2d 100644
--- a/st2api/st2api/controllers/v1/actionexecutions.py
+++ b/st2api/st2api/controllers/v1/actionexecutions.py
@@ -54,18 +54,15 @@
from st2common.rbac.types import PermissionType
from st2common.rbac.backends import get_rbac_backend
-__all__ = [
- 'ActionExecutionsController'
-]
+__all__ = ["ActionExecutionsController"]
LOG = logging.getLogger(__name__)
# Note: We initialize filters here and not in the constructor
SUPPORTED_EXECUTIONS_FILTERS = copy.deepcopy(SUPPORTED_FILTERS)
-SUPPORTED_EXECUTIONS_FILTERS.update({
- 'timestamp_gt': 'start_timestamp.gt',
- 'timestamp_lt': 'start_timestamp.lt'
-})
+SUPPORTED_EXECUTIONS_FILTERS.update(
+ {"timestamp_gt": "start_timestamp.gt", "timestamp_lt": "start_timestamp.lt"}
+)
MONITOR_THREAD_EMPTY_Q_SLEEP_TIME = 5
MONITOR_THREAD_NO_WORKERS_SLEEP_TIME = 1
@@ -82,29 +79,24 @@ class ActionExecutionsControllerMixin(BaseRestControllerMixin):
# Those two attributes are mandatory so we can correctly determine and mask secret execution
# parameters
mandatory_include_fields_retrieve = [
- 'action.parameters',
- 'runner.runner_parameters',
- 'parameters',
-
+ "action.parameters",
+ "runner.runner_parameters",
+ "parameters",
# Attributes below are mandatory for RBAC installations
- 'action.pack',
- 'action.uid',
-
+ "action.pack",
+ "action.uid",
# Required when rbac.permission_isolation is enabled
- 'context'
+ "context",
]
# A list of attributes which can be specified using ?exclude_attributes filter
# NOTE: Allowing user to exclude attribute such as action and runner would break secrets
# masking
- valid_exclude_attributes = [
- 'result',
- 'trigger_instance',
- 'status'
- ]
+ valid_exclude_attributes = ["result", "trigger_instance", "status"]
- def _handle_schedule_execution(self, liveaction_api, requester_user, context_string=None,
- show_secrets=False):
+ def _handle_schedule_execution(
+ self, liveaction_api, requester_user, context_string=None, show_secrets=False
+ ):
"""
:param liveaction: LiveActionAPI object.
:type liveaction: :class:`LiveActionAPI`
@@ -124,101 +116,129 @@ def _handle_schedule_execution(self, liveaction_api, requester_user, context_str
# Assert the permissions
permission_type = PermissionType.ACTION_EXECUTE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_db,
+ permission_type=permission_type,
+ )
# Validate that the authenticated user is admin if user query param is provided
user = liveaction_api.user or requester_user.name
- rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user,
- user=user)
+ rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(
+ user_db=requester_user, user=user
+ )
try:
- return self._schedule_execution(liveaction=liveaction_api,
- requester_user=requester_user,
- user=user,
- context_string=context_string,
- show_secrets=show_secrets,
- action_db=action_db)
+ return self._schedule_execution(
+ liveaction=liveaction_api,
+ requester_user=requester_user,
+ user=user,
+ context_string=context_string,
+ show_secrets=show_secrets,
+ action_db=action_db,
+ )
except ValueError as e:
- LOG.exception('Unable to execute action.')
+ LOG.exception("Unable to execute action.")
abort(http_client.BAD_REQUEST, six.text_type(e))
except jsonschema.ValidationError as e:
- LOG.exception('Unable to execute action. Parameter validation failed.')
- abort(http_client.BAD_REQUEST, re.sub("u'([^']*)'", r"'\1'",
- getattr(e, 'message', six.text_type(e))))
+ LOG.exception("Unable to execute action. Parameter validation failed.")
+ abort(
+ http_client.BAD_REQUEST,
+ re.sub("u'([^']*)'", r"'\1'", getattr(e, "message", six.text_type(e))),
+ )
except trace_exc.TraceNotFoundException as e:
abort(http_client.BAD_REQUEST, six.text_type(e))
except validation_exc.ValueValidationException as e:
raise e
except Exception as e:
- LOG.exception('Unable to execute action. Unexpected error encountered.')
+ LOG.exception("Unable to execute action. Unexpected error encountered.")
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
- def _schedule_execution(self, liveaction, requester_user, action_db, user=None,
- context_string=None, show_secrets=False):
+ def _schedule_execution(
+ self,
+ liveaction,
+ requester_user,
+ action_db,
+ user=None,
+ context_string=None,
+ show_secrets=False,
+ ):
# Initialize execution context if it does not exist.
- if not hasattr(liveaction, 'context'):
+ if not hasattr(liveaction, "context"):
liveaction.context = dict()
- liveaction.context['user'] = user
- liveaction.context['pack'] = action_db.pack
+ liveaction.context["user"] = user
+ liveaction.context["pack"] = action_db.pack
- LOG.debug('User is: %s' % liveaction.context['user'])
+ LOG.debug("User is: %s" % liveaction.context["user"])
# Retrieve other st2 context from request header.
if context_string:
context = try_loads(context_string)
if not isinstance(context, dict):
- raise ValueError('Unable to convert st2-context from the headers into JSON.')
+ raise ValueError(
+ "Unable to convert st2-context from the headers into JSON."
+ )
liveaction.context.update(context)
# Include RBAC context (if RBAC is available and enabled)
if cfg.CONF.rbac.enable:
user_db = UserDB(name=user)
rbac_service = get_rbac_backend().get_service_class()
- role_dbs = rbac_service.get_roles_for_user(user_db=user_db, include_remote=True)
+ role_dbs = rbac_service.get_roles_for_user(
+ user_db=user_db, include_remote=True
+ )
roles = [role_db.name for role_db in role_dbs]
- liveaction.context['rbac'] = {
- 'user': user,
- 'roles': roles
- }
+ liveaction.context["rbac"] = {"user": user, "roles": roles}
# Schedule the action execution.
liveaction_db = LiveActionAPI.to_model(liveaction)
- runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name'])
+ runnertype_db = action_utils.get_runnertype_by_name(
+ action_db.runner_type["name"]
+ )
try:
liveaction_db.parameters = param_utils.render_live_params(
- runnertype_db.runner_parameters, action_db.parameters, liveaction_db.parameters,
- liveaction_db.context)
+ runnertype_db.runner_parameters,
+ action_db.parameters,
+ liveaction_db.parameters,
+ liveaction_db.context,
+ )
except param_exc.ParamException:
# We still need to create a request, so liveaction_db is assigned an ID
liveaction_db, actionexecution_db = action_service.create_request(
liveaction=liveaction_db,
action_db=action_db,
- runnertype_db=runnertype_db)
+ runnertype_db=runnertype_db,
+ )
# By this point the execution is already in the DB, so we need to mark it failed.
_, e, tb = sys.exc_info()
action_service.update_status(
liveaction=liveaction_db,
new_status=action_constants.LIVEACTION_STATUS_FAILED,
- result={'error': six.text_type(e),
- 'traceback': ''.join(traceback.format_tb(tb, 20))})
+ result={
+ "error": six.text_type(e),
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ },
+ )
# Might be a good idea to return the actual ActionExecution rather than bubble up
# the exception.
raise validation_exc.ValueValidationException(six.text_type(e))
# The request should be created after the above call to render_live_params
# so any templates in live parameters have a chance to render.
- liveaction_db, actionexecution_db = action_service.create_request(liveaction=liveaction_db,
- action_db=action_db,
- runnertype_db=runnertype_db)
+ liveaction_db, actionexecution_db = action_service.create_request(
+ liveaction=liveaction_db, action_db=action_db, runnertype_db=runnertype_db
+ )
- _, actionexecution_db = action_service.publish_request(liveaction_db, actionexecution_db)
+ _, actionexecution_db = action_service.publish_request(
+ liveaction_db, actionexecution_db
+ )
mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets)
- execution_api = ActionExecutionAPI.from_model(actionexecution_db, mask_secrets=mask_secrets)
+ execution_api = ActionExecutionAPI.from_model(
+ actionexecution_db, mask_secrets=mask_secrets
+ )
return Response(json=execution_api, status=http_client.CREATED)
@@ -231,25 +251,33 @@ def _get_result_object(self, id):
:rtype: ``dict``
"""
- fields = ['result']
- action_exec_db = self.access.impl.model.objects.filter(id=id).only(*fields).get()
+ fields = ["result"]
+ action_exec_db = (
+ self.access.impl.model.objects.filter(id=id).only(*fields).get()
+ )
return action_exec_db.result
- def _get_children(self, id_, requester_user, depth=-1, result_fmt=None, show_secrets=False):
+ def _get_children(
+ self, id_, requester_user, depth=-1, result_fmt=None, show_secrets=False
+ ):
# make sure depth is int. URL encoding will make it a string and it needs to
# be converted back in that case.
depth = int(depth)
- LOG.debug('retrieving children for id: %s with depth: %s', id_, depth)
- descendants = execution_service.get_descendants(actionexecution_id=id_,
- descendant_depth=depth,
- result_fmt=result_fmt)
+ LOG.debug("retrieving children for id: %s with depth: %s", id_, depth)
+ descendants = execution_service.get_descendants(
+ actionexecution_id=id_, descendant_depth=depth, result_fmt=result_fmt
+ )
mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets)
- return [self.model.from_model(descendant, mask_secrets=mask_secrets) for
- descendant in descendants]
+ return [
+ self.model.from_model(descendant, mask_secrets=mask_secrets)
+ for descendant in descendants
+ ]
-class BaseActionExecutionNestedController(ActionExecutionsControllerMixin, ResourceController):
+class BaseActionExecutionNestedController(
+ ActionExecutionsControllerMixin, ResourceController
+):
# Note: We need to override "get_one" and "get_all" to return 404 since nested controllers
# don't implement those methods
@@ -265,24 +293,36 @@ def get_one(self, id):
class ActionExecutionChildrenController(BaseActionExecutionNestedController):
- def get_one(self, id, requester_user, depth=-1, result_fmt=None, show_secrets=False):
+ def get_one(
+ self, id, requester_user, depth=-1, result_fmt=None, show_secrets=False
+ ):
"""
Retrieve children for the provided action execution.
:rtype: ``list``
"""
- execution_db = self._get_one_by_id(id=id, requester_user=requester_user,
- permission_type=PermissionType.EXECUTION_VIEW)
+ execution_db = self._get_one_by_id(
+ id=id,
+ requester_user=requester_user,
+ permission_type=PermissionType.EXECUTION_VIEW,
+ )
id = str(execution_db.id)
- return self._get_children(id_=id, depth=depth, result_fmt=result_fmt,
- requester_user=requester_user, show_secrets=show_secrets)
+ return self._get_children(
+ id_=id,
+ depth=depth,
+ result_fmt=result_fmt,
+ requester_user=requester_user,
+ show_secrets=show_secrets,
+ )
class ActionExecutionAttributeController(BaseActionExecutionNestedController):
- valid_exclude_attributes = ['action__pack', 'action__uid'] + \
- ActionExecutionsControllerMixin.valid_exclude_attributes
+ valid_exclude_attributes = [
+ "action__pack",
+ "action__uid",
+ ] + ActionExecutionsControllerMixin.valid_exclude_attributes
def get(self, id, attribute, requester_user):
"""
@@ -294,76 +334,94 @@ def get(self, id, attribute, requester_user):
:rtype: ``dict``
"""
- fields = [attribute, 'action__pack', 'action__uid']
+ fields = [attribute, "action__pack", "action__uid"]
try:
fields = self._validate_exclude_fields(fields)
except ValueError:
- valid_attributes = ', '.join(ActionExecutionsControllerMixin.valid_exclude_attributes)
- msg = ('Invalid attribute "%s" specified. Valid attributes are: %s' %
- (attribute, valid_attributes))
+ valid_attributes = ", ".join(
+ ActionExecutionsControllerMixin.valid_exclude_attributes
+ )
+ msg = 'Invalid attribute "%s" specified. Valid attributes are: %s' % (
+ attribute,
+ valid_attributes,
+ )
raise ValueError(msg)
- action_exec_db = self.access.impl.model.objects.filter(id=id).only(*fields).get()
+ action_exec_db = (
+ self.access.impl.model.objects.filter(id=id).only(*fields).get()
+ )
permission_type = PermissionType.EXECUTION_VIEW
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_exec_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_exec_db,
+ permission_type=permission_type,
+ )
result = getattr(action_exec_db, attribute, None)
return Response(json=result, status=http_client.OK)
-class ActionExecutionOutputController(ActionExecutionsControllerMixin, ResourceController):
- supported_filters = {
- 'output_type': 'output_type'
- }
+class ActionExecutionOutputController(
+ ActionExecutionsControllerMixin, ResourceController
+):
+ supported_filters = {"output_type": "output_type"}
exclude_fields = []
- def get_one(self, id, output_type='all', output_format='raw', existing_only=False,
- requester_user=None):
+ def get_one(
+ self,
+ id,
+ output_type="all",
+ output_format="raw",
+ existing_only=False,
+ requester_user=None,
+ ):
# Special case for id == "last"
- if id == 'last':
- execution_db = ActionExecution.query().order_by('-id').limit(1).first()
+ if id == "last":
+ execution_db = ActionExecution.query().order_by("-id").limit(1).first()
if not execution_db:
- raise ValueError('No executions found in the database')
+ raise ValueError("No executions found in the database")
id = str(execution_db.id)
- execution_db = self._get_one_by_id(id=id, requester_user=requester_user,
- permission_type=PermissionType.EXECUTION_VIEW)
+ execution_db = self._get_one_by_id(
+ id=id,
+ requester_user=requester_user,
+ permission_type=PermissionType.EXECUTION_VIEW,
+ )
execution_id = str(execution_db.id)
query_filters = {}
- if output_type and output_type != 'all':
- query_filters['output_type'] = output_type
+ if output_type and output_type != "all":
+ query_filters["output_type"] = output_type
def existing_output_iter():
# Consume and return all of the existing lines
# pylint: disable=no-member
- output_dbs = ActionExecutionOutput.query(execution_id=execution_id, **query_filters)
+ output_dbs = ActionExecutionOutput.query(
+ execution_id=execution_id, **query_filters
+ )
- output = ''.join([output_db.data for output_db in output_dbs])
- yield six.binary_type(output.encode('utf-8'))
+ output = "".join([output_db.data for output_db in output_dbs])
+ yield six.binary_type(output.encode("utf-8"))
def make_response():
app_iter = existing_output_iter()
- res = Response(content_type='text/plain', app_iter=app_iter)
+ res = Response(content_type="text/plain", app_iter=app_iter)
return res
res = make_response()
return res
-class ActionExecutionReRunController(ActionExecutionsControllerMixin, ResourceController):
+class ActionExecutionReRunController(
+ ActionExecutionsControllerMixin, ResourceController
+):
supported_filters = {}
- exclude_fields = [
- 'result',
- 'trigger_instance'
- ]
+ exclude_fields = ["result", "trigger_instance"]
class ExecutionSpecificationAPI(object):
def __init__(self, parameters=None, tasks=None, reset=None, user=None):
@@ -374,8 +432,10 @@ def __init__(self, parameters=None, tasks=None, reset=None, user=None):
def validate(self):
if (self.tasks or self.reset) and self.parameters:
- raise ValueError('Parameters override is not supported when '
- 're-running task(s) for a workflow.')
+ raise ValueError(
+ "Parameters override is not supported when "
+ "re-running task(s) for a workflow."
+ )
if self.parameters:
assert isinstance(self.parameters, dict)
@@ -387,7 +447,9 @@ def validate(self):
assert isinstance(self.reset, list)
if list(set(self.reset) - set(self.tasks)):
- raise ValueError('List of tasks to reset does not match the tasks to rerun.')
+ raise ValueError(
+ "List of tasks to reset does not match the tasks to rerun."
+ )
return self
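
The rerun-spec validation above encodes two constraints: a parameters override cannot be combined with tasks/reset, and reset must be a subset of tasks. A standalone sketch, assuming plain dict/list inputs:

    # Standalone sketch of ExecutionSpecificationAPI.validate().
    def validate_rerun_spec(parameters=None, tasks=None, reset=None):
        if (tasks or reset) and parameters:
            raise ValueError(
                "Parameters override is not supported when "
                "re-running task(s) for a workflow."
            )
        if parameters:
            assert isinstance(parameters, dict)
        if tasks:
            assert isinstance(tasks, list)
        if reset:
            assert isinstance(reset, list)
        if list(set(reset or []) - set(tasks or [])):
            raise ValueError("List of tasks to reset does not match the tasks to rerun.")
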
@@ -401,8 +463,10 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False)
"""
if (spec_api.tasks or spec_api.reset) and spec_api.parameters:
- raise ValueError('Parameters override is not supported when '
- 're-running task(s) for a workflow.')
+ raise ValueError(
+ "Parameters override is not supported when "
+ "re-running task(s) for a workflow."
+ )
if spec_api.parameters:
assert isinstance(spec_api.parameters, dict)
@@ -414,7 +478,9 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False)
assert isinstance(spec_api.reset, list)
if list(set(spec_api.reset) - set(spec_api.tasks)):
- raise ValueError('List of tasks to reset does not match the tasks to rerun.')
+ raise ValueError(
+ "List of tasks to reset does not match the tasks to rerun."
+ )
delay = None
@@ -422,59 +488,69 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False)
delay = spec_api.delay
no_merge = cast_argument_value(value_type=bool, value=no_merge)
- existing_execution = self._get_one_by_id(id=id, exclude_fields=self.exclude_fields,
- requester_user=requester_user,
- permission_type=PermissionType.EXECUTION_VIEW)
+ existing_execution = self._get_one_by_id(
+ id=id,
+ exclude_fields=self.exclude_fields,
+ requester_user=requester_user,
+ permission_type=PermissionType.EXECUTION_VIEW,
+ )
- if spec_api.tasks and \
- existing_execution.runner['name'] != 'orquesta':
- raise ValueError('Task option is only supported for Orquesta workflows.')
+ if spec_api.tasks and existing_execution.runner["name"] != "orquesta":
+ raise ValueError("Task option is only supported for Orquesta workflows.")
# Merge in any parameters provided by the user
new_parameters = {}
if not no_merge:
- new_parameters.update(getattr(existing_execution, 'parameters', {}))
+ new_parameters.update(getattr(existing_execution, "parameters", {}))
new_parameters.update(spec_api.parameters)
# Create object for the new execution
- action_ref = existing_execution.action['ref']
+ action_ref = existing_execution.action["ref"]
# Include additional option(s) for the execution
context = {
- 're-run': {
- 'ref': id,
+ "re-run": {
+ "ref": id,
}
}
if spec_api.tasks:
- context['re-run']['tasks'] = spec_api.tasks
+ context["re-run"]["tasks"] = spec_api.tasks
if spec_api.reset:
- context['re-run']['reset'] = spec_api.reset
+ context["re-run"]["reset"] = spec_api.reset
# Add trace to the new execution
trace = trace_service.get_trace_db_by_action_execution(
- action_execution_id=existing_execution.id)
+ action_execution_id=existing_execution.id
+ )
if trace:
- context['trace_context'] = {'id_': str(trace.id)}
-
- new_liveaction_api = LiveActionCreateAPI(action=action_ref,
- context=context,
- parameters=new_parameters,
- user=spec_api.user,
- delay=delay)
-
- return self._handle_schedule_execution(liveaction_api=new_liveaction_api,
- requester_user=requester_user,
- show_secrets=show_secrets)
-
-
-class ActionExecutionsController(BaseResourceIsolationControllerMixin,
- ActionExecutionsControllerMixin, ResourceController):
+ context["trace_context"] = {"id_": str(trace.id)}
+
+ new_liveaction_api = LiveActionCreateAPI(
+ action=action_ref,
+ context=context,
+ parameters=new_parameters,
+ user=spec_api.user,
+ delay=delay,
+ )
+
+ return self._handle_schedule_execution(
+ liveaction_api=new_liveaction_api,
+ requester_user=requester_user,
+ show_secrets=show_secrets,
+ )
+
+
+class ActionExecutionsController(
+ BaseResourceIsolationControllerMixin,
+ ActionExecutionsControllerMixin,
+ ResourceController,
+):
"""
- Implements the RESTful web endpoint that handles
- the lifecycle of ActionExecutions in the system.
+ Implements the RESTful web endpoint that handles
+ the lifecycle of ActionExecutions in the system.
"""
# Nested controllers
@@ -485,17 +561,25 @@ class ActionExecutionsController(BaseResourceIsolationControllerMixin,
re_run = ActionExecutionReRunController()
# ResourceController attributes
- query_options = {
- 'sort': ['-start_timestamp', 'action.ref']
- }
+ query_options = {"sort": ["-start_timestamp", "action.ref"]}
supported_filters = SUPPORTED_EXECUTIONS_FILTERS
filter_transform_functions = {
- 'timestamp_gt': lambda value: isotime.parse(value=value),
- 'timestamp_lt': lambda value: isotime.parse(value=value)
+ "timestamp_gt": lambda value: isotime.parse(value=value),
+ "timestamp_lt": lambda value: isotime.parse(value=value),
}
- def get_all(self, requester_user, exclude_attributes=None, sort=None, offset=0, limit=None,
- show_secrets=False, include_attributes=None, advanced_filters=None, **raw_filters):
+ def get_all(
+ self,
+ requester_user,
+ exclude_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ show_secrets=False,
+ include_attributes=None,
+ advanced_filters=None,
+ **raw_filters,
+ ):
"""
List all executions.
@@ -508,27 +592,37 @@ def get_all(self, requester_user, exclude_attributes=None, sort=None, offset=0,
# Use a custom sort order when filtering on a timestamp so we return a correct result as
# expected by the user
query_options = None
- if raw_filters.get('timestamp_lt', None) or raw_filters.get('sort_desc', None):
- query_options = {'sort': ['-start_timestamp', 'action.ref']}
- elif raw_filters.get('timestamp_gt', None) or raw_filters.get('sort_asc', None):
- query_options = {'sort': ['+start_timestamp', 'action.ref']}
+ if raw_filters.get("timestamp_lt", None) or raw_filters.get("sort_desc", None):
+ query_options = {"sort": ["-start_timestamp", "action.ref"]}
+ elif raw_filters.get("timestamp_gt", None) or raw_filters.get("sort_asc", None):
+ query_options = {"sort": ["+start_timestamp", "action.ref"]}
from_model_kwargs = {
- 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets)
+ "mask_secrets": self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ )
}
- return self._get_action_executions(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- from_model_kwargs=from_model_kwargs,
- sort=sort,
- offset=offset,
- limit=limit,
- query_options=query_options,
- raw_filters=raw_filters,
- advanced_filters=advanced_filters,
- requester_user=requester_user)
-
- def get_one(self, id, requester_user, exclude_attributes=None, include_attributes=None,
- show_secrets=False):
+ return self._get_action_executions(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ from_model_kwargs=from_model_kwargs,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ query_options=query_options,
+ raw_filters=raw_filters,
+ advanced_filters=advanced_filters,
+ requester_user=requester_user,
+ )
+
+ def get_one(
+ self,
+ id,
+ requester_user,
+ exclude_attributes=None,
+ include_attributes=None,
+ show_secrets=False,
+ ):
"""
Retrieve a single execution.
@@ -538,33 +632,48 @@ def get_one(self, id, requester_user, exclude_attributes=None, include_attribute
:param exclude_attributes: List of attributes to exclude from the object.
:type exclude_attributes: ``list``
"""
- exclude_fields = self._validate_exclude_fields(exclude_fields=exclude_attributes)
- include_fields = self._validate_include_fields(include_fields=include_attributes)
+ exclude_fields = self._validate_exclude_fields(
+ exclude_fields=exclude_attributes
+ )
+ include_fields = self._validate_include_fields(
+ include_fields=include_attributes
+ )
from_model_kwargs = {
- 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets)
+ "mask_secrets": self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ )
}
# Special case for id == "last"
- if id == 'last':
- execution_db = ActionExecution.query().order_by('-id').limit(1).only('id').first()
+ if id == "last":
+ execution_db = (
+ ActionExecution.query().order_by("-id").limit(1).only("id").first()
+ )
if not execution_db:
- raise ValueError('No executions found in the database')
+ raise ValueError("No executions found in the database")
id = str(execution_db.id)
- return self._get_one_by_id(id=id, exclude_fields=exclude_fields,
- include_fields=include_fields,
- requester_user=requester_user,
- from_model_kwargs=from_model_kwargs,
- permission_type=PermissionType.EXECUTION_VIEW)
-
- def post(self, liveaction_api, requester_user, context_string=None, show_secrets=False):
- return self._handle_schedule_execution(liveaction_api=liveaction_api,
- requester_user=requester_user,
- context_string=context_string,
- show_secrets=show_secrets)
+ return self._get_one_by_id(
+ id=id,
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ requester_user=requester_user,
+ from_model_kwargs=from_model_kwargs,
+ permission_type=PermissionType.EXECUTION_VIEW,
+ )
+
+ def post(
+ self, liveaction_api, requester_user, context_string=None, show_secrets=False
+ ):
+ return self._handle_schedule_execution(
+ liveaction_api=liveaction_api,
+ requester_user=requester_user,
+ context_string=context_string,
+ show_secrets=show_secrets,
+ )
def put(self, id, liveaction_api, requester_user, show_secrets=False):
"""
@@ -578,76 +687,118 @@ def put(self, id, liveaction_api, requester_user, show_secrets=False):
requester_user = UserDB(cfg.CONF.system_user.user)
from_model_kwargs = {
- 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets)
+ "mask_secrets": self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ )
}
- execution_api = self._get_one_by_id(id=id, requester_user=requester_user,
- from_model_kwargs=from_model_kwargs,
- permission_type=PermissionType.EXECUTION_STOP)
+ execution_api = self._get_one_by_id(
+ id=id,
+ requester_user=requester_user,
+ from_model_kwargs=from_model_kwargs,
+ permission_type=PermissionType.EXECUTION_STOP,
+ )
if not execution_api:
- abort(http_client.NOT_FOUND, 'Execution with id %s not found.' % id)
+ abort(http_client.NOT_FOUND, "Execution with id %s not found." % id)
- liveaction_id = execution_api.liveaction['id']
+ liveaction_id = execution_api.liveaction["id"]
if not liveaction_id:
- abort(http_client.INTERNAL_SERVER_ERROR,
- 'Execution object missing link to liveaction %s.' % liveaction_id)
+ abort(
+ http_client.INTERNAL_SERVER_ERROR,
+ "Execution object missing link to liveaction %s." % liveaction_id,
+ )
try:
liveaction_db = LiveAction.get_by_id(liveaction_id)
except:
- abort(http_client.INTERNAL_SERVER_ERROR,
- 'Execution object missing link to liveaction %s.' % liveaction_id)
+ abort(
+ http_client.INTERNAL_SERVER_ERROR,
+ "Execution object missing link to liveaction %s." % liveaction_id,
+ )
if liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES:
- abort(http_client.BAD_REQUEST, 'Execution is already in completed state.')
+ abort(http_client.BAD_REQUEST, "Execution is already in completed state.")
def update_status(liveaction_api, liveaction_db):
status = liveaction_api.status
- result = getattr(liveaction_api, 'result', None)
+ result = getattr(liveaction_api, "result", None)
liveaction_db = action_service.update_status(liveaction_db, status, result)
- actionexecution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id))
+ actionexecution_db = ActionExecution.get(
+ liveaction__id=str(liveaction_db.id)
+ )
return (liveaction_db, actionexecution_db)
try:
- if (liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELING and
- liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED):
+ if (
+ liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELING
+ and liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED
+ ):
if action_service.is_children_active(liveaction_id):
liveaction_api.status = action_constants.LIVEACTION_STATUS_CANCELING
- liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db)
- elif (liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELING or
- liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED):
+ liveaction_db, actionexecution_db = update_status(
+ liveaction_api, liveaction_db
+ )
+ elif (
+ liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELING
+ or liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED
+ ):
liveaction_db, actionexecution_db = action_service.request_cancellation(
- liveaction_db, requester_user.name or cfg.CONF.system_user.user)
- elif (liveaction_db.status == action_constants.LIVEACTION_STATUS_PAUSING and
- liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED):
+ liveaction_db, requester_user.name or cfg.CONF.system_user.user
+ )
+ elif (
+ liveaction_db.status == action_constants.LIVEACTION_STATUS_PAUSING
+ and liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED
+ ):
if action_service.is_children_active(liveaction_id):
liveaction_api.status = action_constants.LIVEACTION_STATUS_PAUSING
- liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db)
- elif (liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSING or
- liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED):
+ liveaction_db, actionexecution_db = update_status(
+ liveaction_api, liveaction_db
+ )
+ elif (
+ liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSING
+ or liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED
+ ):
liveaction_db, actionexecution_db = action_service.request_pause(
- liveaction_db, requester_user.name or cfg.CONF.system_user.user)
+ liveaction_db, requester_user.name or cfg.CONF.system_user.user
+ )
elif liveaction_api.status == action_constants.LIVEACTION_STATUS_RESUMING:
liveaction_db, actionexecution_db = action_service.request_resume(
- liveaction_db, requester_user.name or cfg.CONF.system_user.user)
+ liveaction_db, requester_user.name or cfg.CONF.system_user.user
+ )
else:
- liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db)
+ liveaction_db, actionexecution_db = update_status(
+ liveaction_api, liveaction_db
+ )
except runner_exc.InvalidActionRunnerOperationError as e:
- LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e))
- abort(http_client.BAD_REQUEST, 'Failed updating execution. %s' % six.text_type(e))
+ LOG.exception(
+ "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e)
+ )
+ abort(
+ http_client.BAD_REQUEST,
+ "Failed updating execution. %s" % six.text_type(e),
+ )
except runner_exc.UnexpectedActionExecutionStatusError as e:
- LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e))
- abort(http_client.BAD_REQUEST, 'Failed updating execution. %s' % six.text_type(e))
+ LOG.exception(
+ "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e)
+ )
+ abort(
+ http_client.BAD_REQUEST,
+ "Failed updating execution. %s" % six.text_type(e),
+ )
except Exception as e:
- LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e))
+ LOG.exception(
+ "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e)
+ )
abort(
http_client.INTERNAL_SERVER_ERROR,
- 'Failed updating execution due to unexpected error.'
+ "Failed updating execution due to unexpected error.",
)
mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets)
- execution_api = ActionExecutionAPI.from_model(actionexecution_db, mask_secrets=mask_secrets)
+ execution_api = ActionExecutionAPI.from_model(
+ actionexecution_db, mask_secrets=mask_secrets
+ )
return execution_api
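
The long if/elif chain reformatted in the put() hunk above is effectively a dispatch on (current status, requested status, children still active). A condensed sketch of that dispatch, with the LIVEACTION_STATUS_* constants abbreviated and the service calls returned by name rather than invoked:

    # Condensed dispatch table for ActionExecutionsController.put() above.
    CANCELING, CANCELED = "canceling", "canceled"
    PAUSING, PAUSED, RESUMING = "pausing", "paused", "resuming"

    def dispatch(current, requested, children_active):
        if current == CANCELING and requested == CANCELED:
            # finish cancellation, but stay in "canceling" while children run
            return "update_status(%s)" % (CANCELING if children_active else CANCELED)
        elif requested in (CANCELING, CANCELED):
            return "request_cancellation()"
        elif current == PAUSING and requested == PAUSED:
            return "update_status(%s)" % (PAUSING if children_active else PAUSED)
        elif requested in (PAUSING, PAUSED):
            return "request_pause()"
        elif requested == RESUMING:
            return "request_resume()"
        return "update_status(%s)" % requested
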
@@ -663,50 +814,76 @@ def delete(self, id, requester_user, show_secrets=False):
requester_user = UserDB(cfg.CONF.system_user.user)
from_model_kwargs = {
- 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets)
+ "mask_secrets": self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ )
}
- execution_api = self._get_one_by_id(id=id, requester_user=requester_user,
- from_model_kwargs=from_model_kwargs,
- permission_type=PermissionType.EXECUTION_STOP)
+ execution_api = self._get_one_by_id(
+ id=id,
+ requester_user=requester_user,
+ from_model_kwargs=from_model_kwargs,
+ permission_type=PermissionType.EXECUTION_STOP,
+ )
if not execution_api:
- abort(http_client.NOT_FOUND, 'Execution with id %s not found.' % id)
+ abort(http_client.NOT_FOUND, "Execution with id %s not found." % id)
- liveaction_id = execution_api.liveaction['id']
+ liveaction_id = execution_api.liveaction["id"]
if not liveaction_id:
- abort(http_client.INTERNAL_SERVER_ERROR,
- 'Execution object missing link to liveaction %s.' % liveaction_id)
+ abort(
+ http_client.INTERNAL_SERVER_ERROR,
+ "Execution object missing link to liveaction %s." % liveaction_id,
+ )
try:
liveaction_db = LiveAction.get_by_id(liveaction_id)
except:
- abort(http_client.INTERNAL_SERVER_ERROR,
- 'Execution object missing link to liveaction %s.' % liveaction_id)
+ abort(
+ http_client.INTERNAL_SERVER_ERROR,
+ "Execution object missing link to liveaction %s." % liveaction_id,
+ )
if liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELED:
LOG.info(
'Action %s already in "canceled" state; \
- returning execution object.' % liveaction_db.id
+ returning execution object.'
+ % liveaction_db.id
)
return execution_api
if liveaction_db.status not in action_constants.LIVEACTION_CANCELABLE_STATES:
- abort(http_client.OK, 'Action cannot be canceled. State = %s.' % liveaction_db.status)
+ abort(
+ http_client.OK,
+ "Action cannot be canceled. State = %s." % liveaction_db.status,
+ )
try:
(liveaction_db, execution_db) = action_service.request_cancellation(
- liveaction_db, requester_user.name or cfg.CONF.system_user.user)
+ liveaction_db, requester_user.name or cfg.CONF.system_user.user
+ )
except:
- LOG.exception('Failed requesting cancellation for liveaction %s.', liveaction_db.id)
- abort(http_client.INTERNAL_SERVER_ERROR, 'Failed canceling execution.')
-
- return ActionExecutionAPI.from_model(execution_db,
- mask_secrets=from_model_kwargs['mask_secrets'])
-
- def _get_action_executions(self, exclude_fields=None, include_fields=None,
- sort=None, offset=0, limit=None, advanced_filters=None,
- query_options=None, raw_filters=None, from_model_kwargs=None,
- requester_user=None):
+ LOG.exception(
+ "Failed requesting cancellation for liveaction %s.", liveaction_db.id
+ )
+ abort(http_client.INTERNAL_SERVER_ERROR, "Failed canceling execution.")
+
+ return ActionExecutionAPI.from_model(
+ execution_db, mask_secrets=from_model_kwargs["mask_secrets"]
+ )
+
+ def _get_action_executions(
+ self,
+ exclude_fields=None,
+ include_fields=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ advanced_filters=None,
+ query_options=None,
+ raw_filters=None,
+ from_model_kwargs=None,
+ requester_user=None,
+ ):
"""
:param exclude_fields: A list of object fields to exclude.
:type exclude_fields: ``list``
@@ -717,18 +894,25 @@ def _get_action_executions(self, exclude_fields=None, include_fields=None,
limit = int(limit)
- LOG.debug('Retrieving all action executions with filters=%s,exclude_fields=%s,'
- 'include_fields=%s', raw_filters, exclude_fields, include_fields)
- return super(ActionExecutionsController, self)._get_all(exclude_fields=exclude_fields,
- include_fields=include_fields,
- from_model_kwargs=from_model_kwargs,
- sort=sort,
- offset=offset,
- limit=limit,
- query_options=query_options,
- raw_filters=raw_filters,
- advanced_filters=advanced_filters,
- requester_user=requester_user)
+ LOG.debug(
+ "Retrieving all action executions with filters=%s,exclude_fields=%s,"
+ "include_fields=%s",
+ raw_filters,
+ exclude_fields,
+ include_fields,
+ )
+ return super(ActionExecutionsController, self)._get_all(
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ from_model_kwargs=from_model_kwargs,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ query_options=query_options,
+ raw_filters=raw_filters,
+ advanced_filters=advanced_filters,
+ requester_user=requester_user,
+ )
action_executions_controller = ActionExecutionsController()
diff --git a/st2api/st2api/controllers/v1/actions.py b/st2api/st2api/controllers/v1/actions.py
index 1746e84b83..c78667076d 100644
--- a/st2api/st2api/controllers/v1/actions.py
+++ b/st2api/st2api/controllers/v1/actions.py
@@ -53,91 +53,102 @@
class ActionsController(resource.ContentPackResourceController):
"""
- Implements the RESTful web endpoint that handles
- the lifecycle of Actions in the system.
+ Implements the RESTful web endpoint that handles
+ the lifecycle of Actions in the system.
"""
+
views = ActionViewsController()
model = ActionAPI
access = Action
- supported_filters = {
- 'name': 'name',
- 'pack': 'pack',
- 'tags': 'tags.name'
- }
+ supported_filters = {"name": "name", "pack": "pack", "tags": "tags.name"}
- query_options = {
- 'sort': ['pack', 'name']
- }
+ query_options = {"sort": ["pack", "name"]}
- valid_exclude_attributes = [
- 'parameters',
- 'notify'
- ]
+ valid_exclude_attributes = ["parameters", "notify"]
def __init__(self, *args, **kwargs):
super(ActionsController, self).__init__(*args, **kwargs)
self._trigger_dispatcher = TriggerDispatcher(LOG)
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- return super(ActionsController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return super(ActionsController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, ref_or_id, requester_user):
- return super(ActionsController, self)._get_one(ref_or_id, requester_user=requester_user,
- permission_type=PermissionType.ACTION_VIEW)
+ return super(ActionsController, self)._get_one(
+ ref_or_id,
+ requester_user=requester_user,
+ permission_type=PermissionType.ACTION_VIEW,
+ )
def post(self, action, requester_user):
"""
- Create a new action.
+ Create a new action.
- Handles requests:
- POST /actions/
+ Handles requests:
+ POST /actions/
"""
permission_type = PermissionType.ACTION_CREATE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user,
- resource_api=action,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_api_permission(
+ user_db=requester_user, resource_api=action, permission_type=permission_type
+ )
try:
# Perform validation
validate_not_part_of_system_pack(action)
action_validator.validate_action(action)
- except (ValidationError, ValueError,
- ValueValidationException, InvalidActionParameterException) as e:
- LOG.exception('Unable to create action data=%s', action)
+ except (
+ ValidationError,
+ ValueError,
+ ValueValidationException,
+ InvalidActionParameterException,
+ ) as e:
+ LOG.exception("Unable to create action data=%s", action)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
# Write pack data files to disk (if any are provided)
- data_files = getattr(action, 'data_files', [])
+ data_files = getattr(action, "data_files", [])
written_data_files = []
if data_files:
- written_data_files = self._handle_data_files(pack_ref=action.pack,
- data_files=data_files)
+ written_data_files = self._handle_data_files(
+ pack_ref=action.pack, data_files=data_files
+ )
action_model = ActionAPI.to_model(action)
- LOG.debug('/actions/ POST verified ActionAPI object=%s', action)
+ LOG.debug("/actions/ POST verified ActionAPI object=%s", action)
action_db = Action.add_or_update(action_model)
- LOG.debug('/actions/ POST saved ActionDB object=%s', action_db)
+ LOG.debug("/actions/ POST saved ActionDB object=%s", action_db)
# Dispatch an internal trigger for each written data file. This way users can
# automate committing these files to git using a StackStorm rule
if written_data_files:
- self._dispatch_trigger_for_written_data_files(action_db=action_db,
- written_data_files=written_data_files)
+ self._dispatch_trigger_for_written_data_files(
+ action_db=action_db, written_data_files=written_data_files
+ )
- extra = {'acion_db': action_db}
- LOG.audit('Action created. Action.id=%s' % (action_db.id), extra=extra)
+ extra = {"acion_db": action_db}
+ LOG.audit("Action created. Action.id=%s" % (action_db.id), extra=extra)
action_api = ActionAPI.from_model(action_db)
return Response(json=action_api, status=http_client.CREATED)
@@ -148,13 +159,15 @@ def put(self, action, ref_or_id, requester_user):
# Assert permissions
permission_type = PermissionType.ACTION_MODIFY
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_db,
+ permission_type=permission_type,
+ )
action_id = action_db.id
- if not getattr(action, 'pack', None):
+ if not getattr(action, "pack", None):
action.pack = action_db.pack
# Perform validation
@@ -162,70 +175,81 @@ def put(self, action, ref_or_id, requester_user):
action_validator.validate_action(action)
# Write pack data files to disk (if any are provided)
- data_files = getattr(action, 'data_files', [])
+ data_files = getattr(action, "data_files", [])
written_data_files = []
if data_files:
- written_data_files = self._handle_data_files(pack_ref=action.pack,
- data_files=data_files)
+ written_data_files = self._handle_data_files(
+ pack_ref=action.pack, data_files=data_files
+ )
try:
action_db = ActionAPI.to_model(action)
- LOG.debug('/actions/ PUT incoming action: %s', action_db)
+ LOG.debug("/actions/ PUT incoming action: %s", action_db)
action_db.id = action_id
action_db = Action.add_or_update(action_db)
- LOG.debug('/actions/ PUT after add_or_update: %s', action_db)
+ LOG.debug("/actions/ PUT after add_or_update: %s", action_db)
except (ValidationError, ValueError) as e:
- LOG.exception('Unable to update action data=%s', action)
+ LOG.exception("Unable to update action data=%s", action)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
# Dispatch an internal trigger for each written data file. This way users can
# automate committing these files to git using a StackStorm rule
if written_data_files:
- self._dispatch_trigger_for_written_data_files(action_db=action_db,
- written_data_files=written_data_files)
+ self._dispatch_trigger_for_written_data_files(
+ action_db=action_db, written_data_files=written_data_files
+ )
action_api = ActionAPI.from_model(action_db)
- LOG.debug('PUT /actions/ client_result=%s', action_api)
+ LOG.debug("PUT /actions/ client_result=%s", action_api)
return action_api
def delete(self, ref_or_id, requester_user):
"""
- Delete an action.
+ Delete an action.
- Handles requests:
- POST /actions/1?_method=delete
- DELETE /actions/1
- DELETE /actions/mypack.myaction
+ Handles requests:
+ POST /actions/1?_method=delete
+ DELETE /actions/1
+ DELETE /actions/mypack.myaction
"""
action_db = self._get_by_ref_or_id(ref_or_id=ref_or_id)
action_id = action_db.id
permission_type = PermissionType.ACTION_DELETE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_db,
+ permission_type=permission_type,
+ )
try:
validate_not_part_of_system_pack(action_db)
except ValueValidationException as e:
abort(http_client.BAD_REQUEST, six.text_type(e))
- LOG.debug('DELETE /actions/ lookup with ref_or_id=%s found object: %s',
- ref_or_id, action_db)
+ LOG.debug(
+ "DELETE /actions/ lookup with ref_or_id=%s found object: %s",
+ ref_or_id,
+ action_db,
+ )
try:
Action.delete(action_db)
except Exception as e:
- LOG.error('Database delete encountered exception during delete of id="%s". '
- 'Exception was %s', action_id, e)
+ LOG.error(
+ 'Database delete encountered exception during delete of id="%s". '
+ "Exception was %s",
+ action_id,
+ e,
+ )
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
return
- extra = {'action_db': action_db}
- LOG.audit('Action deleted. Action.id=%s' % (action_db.id), extra=extra)
+ extra = {"action_db": action_db}
+ LOG.audit("Action deleted. Action.id=%s" % (action_db.id), extra=extra)
return Response(status=http_client.NO_CONTENT)
def _handle_data_files(self, pack_ref, data_files):
@@ -238,13 +262,17 @@ def _handle_data_files(self, pack_ref, data_files):
2. Updates affected PackDB model
"""
# Write files to disk
- written_file_paths = self._write_data_files_to_disk(pack_ref=pack_ref,
- data_files=data_files)
+ written_file_paths = self._write_data_files_to_disk(
+ pack_ref=pack_ref, data_files=data_files
+ )
# Update affected PackDB model (update a list of files)
# Update PackDB
- self._update_pack_model(pack_ref=pack_ref, data_files=data_files,
- written_file_paths=written_file_paths)
+ self._update_pack_model(
+ pack_ref=pack_ref,
+ data_files=data_files,
+ written_file_paths=written_file_paths,
+ )
return written_file_paths
@@ -255,23 +283,27 @@ def _write_data_files_to_disk(self, pack_ref, data_files):
written_file_paths = []
for data_file in data_files:
- file_path = data_file['file_path']
- content = data_file['content']
+ file_path = data_file["file_path"]
+ content = data_file["content"]
- file_path = get_pack_resource_file_abs_path(pack_ref=pack_ref,
- resource_type='action',
- file_path=file_path)
+ file_path = get_pack_resource_file_abs_path(
+ pack_ref=pack_ref, resource_type="action", file_path=file_path
+ )
LOG.debug('Writing data file "%s" to "%s"' % (str(data_file), file_path))
try:
- self._write_data_file(pack_ref=pack_ref, file_path=file_path, content=content)
+ self._write_data_file(
+ pack_ref=pack_ref, file_path=file_path, content=content
+ )
except (OSError, IOError) as e:
# Throw a more user-friendly exception on Permission denied error
if e.errno == errno.EACCES:
- msg = ('Unable to write data to "%s" (permission denied). Make sure '
- 'permissions for that pack directory are configured correctly so '
- 'st2api can write to it.' % (file_path))
+ msg = (
+ 'Unable to write data to "%s" (permission denied). Make sure '
+ "permissions for that pack directory are configured correctly so "
+ "st2api can write to it." % (file_path)
+ )
raise ValueError(msg)
raise e
@@ -285,7 +317,9 @@ def _update_pack_model(self, pack_ref, data_files, written_file_paths):
"""
file_paths = [] # A list of paths relative to the pack directory for new files
for file_path in written_file_paths:
- file_path = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path)
+ file_path = get_relative_path_to_pack_file(
+ pack_ref=pack_ref, file_path=file_path
+ )
file_paths.append(file_path)
pack_db = Pack.get_by_ref(pack_ref)
@@ -314,18 +348,18 @@ def _write_data_file(self, pack_ref, file_path, content):
mode = stat.S_IRWXU | stat.S_IRWXG | stat.S_IROTH | stat.S_IXOTH
os.makedirs(directory, mode)
- with open(file_path, 'w') as fp:
+ with open(file_path, "w") as fp:
fp.write(content)
def _dispatch_trigger_for_written_data_files(self, action_db, written_data_files):
- trigger = ACTION_FILE_WRITTEN_TRIGGER['name']
+ trigger = ACTION_FILE_WRITTEN_TRIGGER["name"]
host_info = get_host_info()
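+ # A separate trigger instance is dispatched for every written file,
+ # so a rule can react to each data file individually.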
for file_path in written_data_files:
payload = {
- 'ref': action_db.ref,
- 'file_path': file_path,
- 'host_info': host_info
+ "ref": action_db.ref,
+ "file_path": file_path,
+ "host_info": host_info,
}
self._trigger_dispatcher.dispatch(trigger=trigger, payload=payload)
diff --git a/st2api/st2api/controllers/v1/aliasexecution.py b/st2api/st2api/controllers/v1/aliasexecution.py
index 7ecc14d62e..ecbea2028e 100644
--- a/st2api/st2api/controllers/v1/aliasexecution.py
+++ b/st2api/st2api/controllers/v1/aliasexecution.py
@@ -30,7 +30,9 @@
from st2common.models.db.liveaction import LiveActionDB
from st2common.models.db.notification import NotificationSchema, NotificationSubSchema
from st2common.models.utils import action_param_utils
-from st2common.models.utils.action_alias_utils import extract_parameters_for_action_alias_db
+from st2common.models.utils.action_alias_utils import (
+ extract_parameters_for_action_alias_db,
+)
from st2common.models.utils.action_alias_utils import inject_immutable_parameters
from st2common.persistence.actionalias import ActionAlias
from st2common.services import action as action_service
@@ -53,57 +55,60 @@ def cast_array(value):
# Already a list, no casting needed nor wanted.
return value
- return [v.strip() for v in value.split(',')]
+ return [v.strip() for v in value.split(",")]
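+# e.g. cast_array("a, b,c") returns ["a", "b", "c"], while a value that is
+# already a list is passed through unchanged.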
CAST_OVERRIDES = {
- 'array': cast_array,
+ "array": cast_array,
}
class ActionAliasExecutionController(BaseRestControllerMixin):
def match_and_execute(self, input_api, requester_user, show_secrets=False):
"""
- Try to find a matching alias and if one is found, schedule a new
- execution by parsing parameters from the provided command against
- the matched alias.
+ Try to find a matching alias and if one is found, schedule a new
+ execution by parsing parameters from the provided command against
+ the matched alias.
- Handles requests:
- POST /aliasexecution/match_and_execute
+ Handles requests:
+ POST /aliasexecution/match_and_execute
"""
command = input_api.command
try:
format_ = get_matching_alias(command=command)
except ActionAliasAmbiguityException as e:
- LOG.exception('Command "%s" matched (%s) patterns.', e.command, len(e.matches))
+ LOG.exception(
+ 'Command "%s" matched (%s) patterns.', e.command, len(e.matches)
+ )
return abort(http_client.BAD_REQUEST, six.text_type(e))
- action_alias_db = format_['alias']
- representation = format_['representation']
+ action_alias_db = format_["alias"]
+ representation = format_["representation"]
params = {
- 'name': action_alias_db.name,
- 'format': representation,
- 'command': command,
- 'user': input_api.user,
- 'source_channel': input_api.source_channel,
+ "name": action_alias_db.name,
+ "format": representation,
+ "command": command,
+ "user": input_api.user,
+ "source_channel": input_api.source_channel,
}
# Add in any additional parameters provided by the user
if input_api.notification_channel:
- params['notification_channel'] = input_api.notification_channel
+ params["notification_channel"] = input_api.notification_channel
if input_api.notification_route:
- params['notification_route'] = input_api.notification_route
+ params["notification_route"] = input_api.notification_route
alias_execution_api = AliasMatchAndExecuteInputAPI(**params)
results = self._post(
payload=alias_execution_api,
requester_user=requester_user,
show_secrets=show_secrets,
- match_multiple=format_['match_multiple'])
- return Response(json={'results': results}, status=http_client.CREATED)
+ match_multiple=format_["match_multiple"],
+ )
+ return Response(json={"results": results}, status=http_client.CREATED)
def _post(self, payload, requester_user, show_secrets=False, match_multiple=False):
action_alias_name = payload.name if payload else None
@@ -115,8 +120,8 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals
if not requester_user:
requester_user = UserDB(cfg.CONF.system_user.user)
- format_str = payload.format or ''
- command = payload.command or ''
+ format_str = payload.format or ""
+ command = payload.command or ""
try:
action_alias_db = ActionAlias.get_by_name(action_alias_name)
@@ -124,7 +129,9 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals
action_alias_db = None
if not action_alias_db:
- msg = 'Unable to identify action alias with name "%s".' % (action_alias_name)
+ msg = 'Unable to identify action alias with name "%s".' % (
+ action_alias_name
+ )
abort(http_client.NOT_FOUND, msg)
return
@@ -138,132 +145,163 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals
action_alias_db=action_alias_db,
format_str=format_str,
param_stream=command,
- match_multiple=match_multiple)
+ match_multiple=match_multiple,
+ )
else:
multiple_execution_parameters = [
extract_parameters_for_action_alias_db(
action_alias_db=action_alias_db,
format_str=format_str,
param_stream=command,
- match_multiple=match_multiple)
+ match_multiple=match_multiple,
+ )
]
notify = self._get_notify_field(payload)
context = {
- 'action_alias_ref': reference.get_ref_from_model(action_alias_db),
- 'api_user': payload.user,
- 'user': requester_user.name,
- 'source_channel': payload.source_channel,
+ "action_alias_ref": reference.get_ref_from_model(action_alias_db),
+ "api_user": payload.user,
+ "user": requester_user.name,
+ "source_channel": payload.source_channel,
}
inject_immutable_parameters(
action_alias_db=action_alias_db,
multiple_execution_parameters=multiple_execution_parameters,
- action_context=context)
+ action_context=context,
+ )
results = []
for execution_parameters in multiple_execution_parameters:
- execution = self._schedule_execution(action_alias_db=action_alias_db,
- params=execution_parameters,
- notify=notify,
- context=context,
- show_secrets=show_secrets,
- requester_user=requester_user)
+ execution = self._schedule_execution(
+ action_alias_db=action_alias_db,
+ params=execution_parameters,
+ notify=notify,
+ context=context,
+ show_secrets=show_secrets,
+ requester_user=requester_user,
+ )
result = {
- 'execution': execution,
- 'actionalias': ActionAliasAPI.from_model(action_alias_db)
+ "execution": execution,
+ "actionalias": ActionAliasAPI.from_model(action_alias_db),
}
if action_alias_db.ack:
try:
- if 'format' in action_alias_db.ack:
- message = render({'alias': action_alias_db.ack['format']}, result)['alias']
+ if "format" in action_alias_db.ack:
+ message = render(
+ {"alias": action_alias_db.ack["format"]}, result
+ )["alias"]
- result.update({
- 'message': message
- })
+ result.update({"message": message})
except UndefinedError as e:
- result.update({
- 'message': ('Cannot render "format" in field "ack" for alias. ' +
- six.text_type(e))
- })
+ result.update(
+ {
+ "message": (
+ 'Cannot render "format" in field "ack" for alias. '
+ + six.text_type(e)
+ )
+ }
+ )
try:
- if 'extra' in action_alias_db.ack:
- result.update({
- 'extra': render(action_alias_db.ack['extra'], result)
- })
+ if "extra" in action_alias_db.ack:
+ result.update(
+ {"extra": render(action_alias_db.ack["extra"], result)}
+ )
except UndefinedError as e:
- result.update({
- 'extra': ('Cannot render "extra" in field "ack" for alias. ' +
- six.text_type(e))
- })
+ result.update(
+ {
+ "extra": (
+ 'Cannot render "extra" in field "ack" for alias. '
+ + six.text_type(e)
+ )
+ }
+ )
results.append(result)
return results
def post(self, payload, requester_user, show_secrets=False):
- results = self._post(payload, requester_user, show_secrets, match_multiple=False)
+ results = self._post(
+ payload, requester_user, show_secrets, match_multiple=False
+ )
return Response(json=results[0], status=http_client.CREATED)
def _tokenize_alias_execution(self, alias_execution):
- tokens = alias_execution.strip().split(' ', 1)
+ tokens = alias_execution.strip().split(" ", 1)
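+ # e.g. "pack install libs" -> ("pack", "install libs"); a single-token
+ # command like "pack" yields ("pack", None).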
return (tokens[0], tokens[1] if len(tokens) > 1 else None)
def _get_notify_field(self, payload):
on_complete = NotificationSubSchema()
- route = (getattr(payload, 'notification_route', None) or
- getattr(payload, 'notification_channel', None))
+ route = getattr(payload, "notification_route", None) or getattr(
+ payload, "notification_channel", None
+ )
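+ # If both are provided, notification_route takes precedence over
+ # notification_channel because of the "or" above.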
on_complete.routes = [route]
on_complete.data = {
- 'user': payload.user,
- 'source_channel': payload.source_channel,
- 'source_context': getattr(payload, 'source_context', None),
+ "user": payload.user,
+ "source_channel": payload.source_channel,
+ "source_context": getattr(payload, "source_context", None),
}
notify = NotificationSchema()
notify.on_complete = on_complete
return notify
- def _schedule_execution(self, action_alias_db, params, notify, context, requester_user,
- show_secrets):
+ def _schedule_execution(
+ self, action_alias_db, params, notify, context, requester_user, show_secrets
+ ):
action_ref = action_alias_db.action_ref
action_db = action_utils.get_action_by_ref(action_ref)
if not action_db:
- raise StackStormDBObjectNotFoundError('Action with ref "%s" not found ' % (action_ref))
+ raise StackStormDBObjectNotFoundError(
+ 'Action with ref "%s" not found ' % (action_ref)
+ )
rbac_utils = get_rbac_backend().get_utils_class()
permission_type = PermissionType.ACTION_EXECUTE
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=action_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=action_db,
+ permission_type=permission_type,
+ )
try:
# Prior to shipping off the params, cast them to the right type.
- params = action_param_utils.cast_params(action_ref=action_alias_db.action_ref,
- params=params,
- cast_overrides=CAST_OVERRIDES)
+ params = action_param_utils.cast_params(
+ action_ref=action_alias_db.action_ref,
+ params=params,
+ cast_overrides=CAST_OVERRIDES,
+ )
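+ # e.g. a parameter whose schema type is "array", supplied as the
+ # string "a, b", is cast to ["a", "b"] by cast_array via the
+ # CAST_OVERRIDES entry above.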
if not context:
context = {
- 'action_alias_ref': reference.get_ref_from_model(action_alias_db),
- 'user': get_system_username()
+ "action_alias_ref": reference.get_ref_from_model(action_alias_db),
+ "user": get_system_username(),
}
- liveaction = LiveActionDB(action=action_alias_db.action_ref, context=context,
- parameters=params, notify=notify)
+ liveaction = LiveActionDB(
+ action=action_alias_db.action_ref,
+ context=context,
+ parameters=params,
+ notify=notify,
+ )
_, action_execution_db = action_service.request(liveaction)
- mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets)
- return ActionExecutionAPI.from_model(action_execution_db, mask_secrets=mask_secrets)
+ mask_secrets = self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ )
+ return ActionExecutionAPI.from_model(
+ action_execution_db, mask_secrets=mask_secrets
+ )
except ValueError as e:
- LOG.exception('Unable to execute action.')
+ LOG.exception("Unable to execute action.")
abort(http_client.BAD_REQUEST, six.text_type(e))
except jsonschema.ValidationError as e:
- LOG.exception('Unable to execute action. Parameter validation failed.')
+ LOG.exception("Unable to execute action. Parameter validation failed.")
abort(http_client.BAD_REQUEST, six.text_type(e))
except Exception as e:
- LOG.exception('Unable to execute action. Unexpected error encountered.')
+ LOG.exception("Unable to execute action. Unexpected error encountered.")
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
diff --git a/st2api/st2api/controllers/v1/auth.py b/st2api/st2api/controllers/v1/auth.py
index 909c8ff4fe..d4741c4bf1 100644
--- a/st2api/st2api/controllers/v1/auth.py
+++ b/st2api/st2api/controllers/v1/auth.py
@@ -37,9 +37,7 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'ApiKeyController'
-]
+__all__ = ["ApiKeyController"]
# See st2common.rbac.resolvers.ApiKeyPermissionResolver#user_has_resource_db_permission for the reason
@@ -49,13 +47,9 @@ class ApiKeyController(BaseRestControllerMixin):
Implements the REST endpoint for managing the key value store.
"""
- supported_filters = {
- 'user': 'user'
- }
+ supported_filters = {"user": "user"}
- query_options = {
- 'sort': ['user']
- }
+ query_options = {"sort": ["user"]}
def __init__(self):
super(ApiKeyController, self).__init__()
@@ -63,31 +57,36 @@ def __init__(self):
def get_one(self, api_key_id_or_key, requester_user, show_secrets=None):
"""
- List api keys.
+ List api keys.
- Handle:
- GET /apikeys/1
+ Handle:
+ GET /apikeys/1
"""
api_key_db = None
try:
api_key_db = ApiKey.get_by_key_or_id(api_key_id_or_key)
except ApiKeyNotFoundError:
- msg = ('ApiKey matching %s for reference and id not found.' % (api_key_id_or_key))
+ msg = "ApiKey matching %s for reference and id not found." % (
+ api_key_id_or_key
+ )
LOG.exception(msg)
abort(http_client.NOT_FOUND, msg)
permission_type = PermissionType.API_KEY_VIEW
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=api_key_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=api_key_db,
+ permission_type=permission_type,
+ )
try:
- mask_secrets = self._get_mask_secrets(show_secrets=show_secrets,
- requester_user=requester_user)
+ mask_secrets = self._get_mask_secrets(
+ show_secrets=show_secrets, requester_user=requester_user
+ )
return ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets)
except (ValidationError, ValueError) as e:
- LOG.exception('Failed to serialize API key.')
+ LOG.exception("Failed to serialize API key.")
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
@property
@@ -96,29 +95,34 @@ def max_limit(self):
def get_all(self, requester_user, show_secrets=None, limit=None, offset=0):
"""
- List all keys.
+ List all keys.
- Handles requests:
- GET /apikeys/
+ Handles requests:
+ GET /apikeys/
"""
- mask_secrets = self._get_mask_secrets(show_secrets=show_secrets,
- requester_user=requester_user)
+ mask_secrets = self._get_mask_secrets(
+ show_secrets=show_secrets, requester_user=requester_user
+ )
- limit = resource.validate_limit_query_param(limit, requester_user=requester_user)
+ limit = resource.validate_limit_query_param(
+ limit, requester_user=requester_user
+ )
try:
api_key_dbs = ApiKey.get_all(limit=limit, offset=offset)
- api_keys = [ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets)
- for api_key_db in api_key_dbs]
+ api_keys = [
+ ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets)
+ for api_key_db in api_key_dbs
+ ]
except OverflowError:
msg = 'Offset "%s" specified is more than 32 bit int' % (offset)
raise ValueError(msg)
resp = Response(json=api_keys)
- resp.headers['X-Total-Count'] = str(api_key_dbs.count())
+ resp.headers["X-Total-Count"] = str(api_key_dbs.count())
if limit:
- resp.headers['X-Limit'] = str(limit)
+ resp.headers["X-Limit"] = str(limit)
return resp
@@ -129,14 +133,16 @@ def post(self, api_key_api, requester_user):
permission_type = PermissionType.API_KEY_CREATE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user,
- resource_api=api_key_api,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_api_permission(
+ user_db=requester_user,
+ resource_api=api_key_api,
+ permission_type=permission_type,
+ )
api_key_db = None
api_key = None
try:
- if not getattr(api_key_api, 'user', None):
+ if not getattr(api_key_api, "user", None):
if requester_user:
api_key_api.user = requester_user.name
else:
@@ -148,22 +154,22 @@ def post(self, api_key_api, requester_user):
user_db = UserDB(name=api_key_api.user)
User.add_or_update(user_db)
- extra = {'username': api_key_api.user, 'user': user_db}
+ extra = {"username": api_key_api.user, "user": user_db}
LOG.audit('Registered new user "%s".' % (api_key_api.user), extra=extra)
# If key_hash is provided, use that and do not create a new key. The assumption
# is that the user already has the original api-key
- if not getattr(api_key_api, 'key_hash', None):
+ if not getattr(api_key_api, "key_hash", None):
api_key, api_key_hash = auth_util.generate_api_key_and_hash()
# store key_hash in DB
api_key_api.key_hash = api_key_hash
api_key_db = ApiKey.add_or_update(ApiKeyAPI.to_model(api_key_api))
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for api_key data=%s.', api_key_api)
+ LOG.exception("Validation failed for api_key data=%s.", api_key_api)
abort(http_client.BAD_REQUEST, six.text_type(e))
- extra = {'api_key_db': api_key_db}
- LOG.audit('ApiKey created. ApiKey.id=%s' % (api_key_db.id), extra=extra)
+ extra = {"api_key_db": api_key_db}
+ LOG.audit("ApiKey created. ApiKey.id=%s" % (api_key_db.id), extra=extra)
api_key_create_response_api = ApiKeyCreateResponseAPI.from_model(api_key_db)
# Return real api_key back to user. A one-way hash of the api_key is stored in the DB
@@ -178,9 +184,11 @@ def put(self, api_key_api, api_key_id_or_key, requester_user):
permission_type = PermissionType.API_KEY_MODIFY
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=api_key_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=api_key_db,
+ permission_type=permission_type,
+ )
old_api_key_db = api_key_db
api_key_db = ApiKeyAPI.to_model(api_key_api)
@@ -191,7 +199,7 @@ def put(self, api_key_api, api_key_id_or_key, requester_user):
user_db = UserDB(name=api_key_api.user)
User.add_or_update(user_db)
- extra = {'username': api_key_api.user, 'user': user_db}
+ extra = {"username": api_key_api.user, "user": user_db}
LOG.audit('Registered new user "%s".' % (api_key_api.user), extra=extra)
# Passing in key_hash as MASKED_ATTRIBUTE_VALUE is expected since we do not
@@ -203,36 +211,38 @@ def put(self, api_key_api, api_key_id_or_key, requester_user):
# Rather than silently ignoring any update to key_hash, it is better to
# explicitly disallow it and notify the user.
if old_api_key_db.key_hash != api_key_db.key_hash:
- raise ValueError('Update of key_hash is not allowed.')
+ raise ValueError("Update of key_hash is not allowed.")
api_key_db.id = old_api_key_db.id
api_key_db = ApiKey.add_or_update(api_key_db)
- extra = {'old_api_key_db': old_api_key_db, 'new_api_key_db': api_key_db}
- LOG.audit('API Key updated. ApiKey.id=%s.' % (api_key_db.id), extra=extra)
+ extra = {"old_api_key_db": old_api_key_db, "new_api_key_db": api_key_db}
+ LOG.audit("API Key updated. ApiKey.id=%s." % (api_key_db.id), extra=extra)
api_key_api = ApiKeyAPI.from_model(api_key_db)
return api_key_api
def delete(self, api_key_id_or_key, requester_user):
"""
- Delete the key value pair.
+ Delete the key value pair.
- Handles requests:
- DELETE /apikeys/1
+ Handles requests:
+ DELETE /apikeys/1
"""
api_key_db = ApiKey.get_by_key_or_id(api_key_id_or_key)
permission_type = PermissionType.API_KEY_DELETE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=api_key_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=api_key_db,
+ permission_type=permission_type,
+ )
ApiKey.delete(api_key_db)
- extra = {'api_key_db': api_key_db}
- LOG.audit('ApiKey deleted. ApiKey.id=%s' % (api_key_db.id), extra=extra)
+ extra = {"api_key_db": api_key_db}
+ LOG.audit("ApiKey deleted. ApiKey.id=%s" % (api_key_db.id), extra=extra)
return Response(status=http_client.NO_CONTENT)
diff --git a/st2api/st2api/controllers/v1/execution_views.py b/st2api/st2api/controllers/v1/execution_views.py
index 9a61bdc321..f4240b94ab 100644
--- a/st2api/st2api/controllers/v1/execution_views.py
+++ b/st2api/st2api/controllers/v1/execution_views.py
@@ -29,51 +29,51 @@
# response. Failure to do so will eventually result in Chrome hanging while opening the History
# tab of st2web.
SUPPORTED_FILTERS = {
- 'action': 'action.ref',
- 'status': 'status',
- 'liveaction': 'liveaction.id',
- 'parent': 'parent',
- 'rule': 'rule.name',
- 'runner': 'runner.name',
- 'timestamp': 'start_timestamp',
- 'trigger': 'trigger.name',
- 'trigger_type': 'trigger_type.name',
- 'trigger_instance': 'trigger_instance.id',
- 'user': 'context.user'
+ "action": "action.ref",
+ "status": "status",
+ "liveaction": "liveaction.id",
+ "parent": "parent",
+ "rule": "rule.name",
+ "runner": "runner.name",
+ "timestamp": "start_timestamp",
+ "trigger": "trigger.name",
+ "trigger_type": "trigger_type.name",
+ "trigger_instance": "trigger_instance.id",
+ "user": "context.user",
}
# A list of fields for which null (None) is a valid value which we include in the list of valid
# filters.
FILTERS_WITH_VALID_NULL_VALUES = [
- 'parent',
- 'rule',
- 'trigger',
- 'trigger_type',
- 'trigger_instance'
+ "parent",
+ "rule",
+ "trigger",
+ "trigger_type",
+ "trigger_instance",
]
# List of filters that are too broad to be used as distinct filters and are very likely to represent a 1 to 1
# relation between filter and particular history record.
-IGNORE_FILTERS = ['parent', 'timestamp', 'liveaction', 'trigger_instance']
+IGNORE_FILTERS = ["parent", "timestamp", "liveaction", "trigger_instance"]
class FiltersController(object):
def get_all(self, types=None):
"""
- List all distinct filters.
+ List all distinct filters.
- Handles requests:
- GET /executions/views/filters[?types=action,rule]
+ Handles requests:
+ GET /executions/views/filters[?types=action,rule]
- :param types: Comma delimited string of filter types to output.
- :type types: ``str``
+ :param types: Comma delimited string of filter types to output.
+ :type types: ``str``
"""
filters = {}
for name, field in six.iteritems(SUPPORTED_FILTERS):
if name not in IGNORE_FILTERS and (not types or name in types):
if name not in FILTERS_WITH_VALID_NULL_VALUES:
- query = {field.replace('.', '__'): {'$ne': None}}
+ query = {field.replace(".", "__"): {"$ne": None}}
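+ # e.g. "action.ref" becomes {"action__ref": {"$ne": None}},
+ # which excludes records where that field is null.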
else:
query = {}
diff --git a/st2api/st2api/controllers/v1/inquiries.py b/st2api/st2api/controllers/v1/inquiries.py
index a892076917..fb3bf2e3f0 100644
--- a/st2api/st2api/controllers/v1/inquiries.py
+++ b/st2api/st2api/controllers/v1/inquiries.py
@@ -34,13 +34,11 @@
from st2common.services import inquiry as inquiry_service
-__all__ = [
- 'InquiriesController'
-]
+__all__ = ["InquiriesController"]
LOG = logging.getLogger(__name__)
-INQUIRY_RUNNER = 'inquirer'
+INQUIRY_RUNNER = "inquirer"
class InquiriesController(ResourceController):
@@ -55,12 +53,18 @@ class InquiriesController(ResourceController):
model = inqy_api_models.InquiryAPI
access = ex_db_access.ActionExecution
- def get_all(self, exclude_attributes=None, include_attributes=None, requester_user=None,
- limit=None, **raw_filters):
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ requester_user=None,
+ limit=None,
+ **raw_filters,
+ ):
"""Retrieve multiple Inquiries
- Handles requests:
- GET /inquiries/
+ Handles requests:
+ GET /inquiries/
"""
# NOTE: This controller retrieves execution objects and returns a new model composed of
@@ -70,13 +74,13 @@ def get_all(self, exclude_attributes=None, include_attributes=None, requester_us
# filtering before returning the response.
raw_inquiries = super(InquiriesController, self)._get_all(
exclude_fields=[],
- include_fields=['id', 'result'],
+ include_fields=["id", "result"],
limit=limit,
raw_filters={
- 'status': action_constants.LIVEACTION_STATUS_PENDING,
- 'runner': INQUIRY_RUNNER
+ "status": action_constants.LIVEACTION_STATUS_PENDING,
+ "runner": INQUIRY_RUNNER,
},
- requester_user=requester_user
+ requester_user=requester_user,
)
# Since "model" is set to InquiryAPI (for good reasons), _get_all returns a list of
@@ -90,18 +94,18 @@ def get_all(self, exclude_attributes=None, include_attributes=None, requester_us
# Repackage into Response with correct headers
resp = api_router.Response(json=inquiries)
- resp.headers['X-Total-Count'] = raw_inquiries.headers['X-Total-Count']
+ resp.headers["X-Total-Count"] = raw_inquiries.headers["X-Total-Count"]
if limit:
- resp.headers['X-Limit'] = str(limit)
+ resp.headers["X-Limit"] = str(limit)
return resp
def get_one(self, inquiry_id, requester_user=None):
"""Retrieve a single Inquiry
- Handles requests:
- GET /inquiries/
+ Handles requests:
+ GET /inquiries/
"""
# Retrieve the inquiry by id.
@@ -110,7 +114,7 @@ def get_one(self, inquiry_id, requester_user=None):
inquiry = self._get_one_by_id(
id=inquiry_id,
requester_user=requester_user,
- permission_type=rbac_types.PermissionType.INQUIRY_VIEW
+ permission_type=rbac_types.PermissionType.INQUIRY_VIEW,
)
except db_exceptions.StackStormDBObjectNotFoundError as e:
LOG.exception('Unable to identify inquiry with id "%s".' % inquiry_id)
@@ -132,15 +136,18 @@ def get_one(self, inquiry_id, requester_user=None):
def put(self, inquiry_id, response_data, requester_user):
"""Provide response data to an Inquiry
- In general, provided the response data validates against the provided
- schema, and the user has the appropriate permissions to respond,
- this will set the Inquiry execution to a successful status, and resume
- the parent workflow.
+ In general, provided the response data validates against the provided
+ schema, and the user has the appropriate permissions to respond,
+ this will set the Inquiry execution to a successful status, and resume
+ the parent workflow.
- Handles requests:
- PUT /inquiries/
+ Handles requests:
+ PUT /inquiries/
"""
- LOG.debug("Inquiry %s received response payload: %s" % (inquiry_id, response_data.response))
+ LOG.debug(
+ "Inquiry %s received response payload: %s"
+ % (inquiry_id, response_data.response)
+ )
# Set requester to system user if not provided.
if not requester_user:
@@ -151,7 +158,7 @@ def put(self, inquiry_id, response_data, requester_user):
inquiry = self._get_one_by_id(
id=inquiry_id,
requester_user=requester_user,
- permission_type=rbac_types.PermissionType.INQUIRY_RESPOND
+ permission_type=rbac_types.PermissionType.INQUIRY_RESPOND,
)
except db_exceptions.StackStormDBObjectNotFoundError as e:
LOG.exception('Unable to identify inquiry with id "%s".' % inquiry_id)
@@ -186,18 +193,23 @@ def put(self, inquiry_id, response_data, requester_user):
# Respond to inquiry and update if there is a partial response.
try:
- inquiry_service.respond(inquiry, response_data.response, requester=requester_user)
+ inquiry_service.respond(
+ inquiry, response_data.response, requester=requester_user
+ )
except Exception as e:
LOG.exception('Failed to update response for inquiry "%s".' % inquiry_id)
api_router.abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
- return {
- 'id': inquiry_id,
- 'response': response_data.response
- }
+ return {"id": inquiry_id, "response": response_data.response}
- def _get_one_by_id(self, id, requester_user, permission_type,
- exclude_fields=None, from_model_kwargs=None):
+ def _get_one_by_id(
+ self,
+ id,
+ requester_user,
+ permission_type,
+ exclude_fields=None,
+ from_model_kwargs=None,
+ ):
"""Override ResourceController._get_one_by_id to contain scope of Inquiries UID hack
:param exclude_fields: A list of object fields to exclude.
:type exclude_fields: ``list``
@@ -215,8 +227,11 @@ def _get_one_by_id(self, id, requester_user, permission_type,
# "inquiry:".
#
# TODO (mierdin): All of this should be removed once Inquiries get their own DB model.
- if (execution_db and getattr(execution_db, 'runner', None) and
- execution_db.runner.get('runner_module') == INQUIRY_RUNNER):
+ if (
+ execution_db
+ and getattr(execution_db, "runner", None)
+ and execution_db.runner.get("runner_module") == INQUIRY_RUNNER
+ ):
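+ # i.e. only executions produced by the "inquirer" runner are
+ # treated as inquiries and receive the UID override.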
execution_db.get_uid = get_uid
LOG.debug('Checking permission on inquiry "%s".' % id)
@@ -226,7 +241,7 @@ def _get_one_by_id(self, id, requester_user, permission_type,
rbac_utils.assert_user_has_resource_db_permission(
user_db=requester_user,
resource_db=execution_db,
- permission_type=permission_type
+ permission_type=permission_type,
)
from_model_kwargs = from_model_kwargs or {}
@@ -237,9 +252,8 @@ def _get_one_by_id(self, id, requester_user, permission_type,
def get_uid():
- """Inquiry UID hack for RBAC
- """
- return 'inquiry'
+ """Inquiry UID hack for RBAC"""
+ return "inquiry"
inquiries_controller = InquiriesController()
diff --git a/st2api/st2api/controllers/v1/keyvalue.py b/st2api/st2api/controllers/v1/keyvalue.py
index eab8cb025a..2bd8449e24 100644
--- a/st2api/st2api/controllers/v1/keyvalue.py
+++ b/st2api/st2api/controllers/v1/keyvalue.py
@@ -24,7 +24,10 @@
from st2common.constants.keyvalue import ALL_SCOPE, FULL_SYSTEM_SCOPE, SYSTEM_SCOPE
from st2common.constants.keyvalue import FULL_USER_SCOPE, USER_SCOPE, ALLOWED_SCOPES
from st2common.exceptions.db import StackStormDBObjectNotFoundError
-from st2common.exceptions.keyvalue import CryptoKeyNotSetupException, InvalidScopeException
+from st2common.exceptions.keyvalue import (
+ CryptoKeyNotSetupException,
+ InvalidScopeException,
+)
from st2common.models.api.keyvalue import KeyValuePairAPI
from st2common.models.db.auth import UserDB
from st2common.persistence.keyvalue import KeyValuePair
@@ -40,9 +43,7 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'KeyValuePairController'
-]
+__all__ = ["KeyValuePairController"]
class KeyValuePairController(ResourceController):
@@ -52,22 +53,21 @@ class KeyValuePairController(ResourceController):
model = KeyValuePairAPI
access = KeyValuePair
- supported_filters = {
- 'prefix': 'name__startswith',
- 'scope': 'scope'
- }
+ supported_filters = {"prefix": "name__startswith", "scope": "scope"}
def __init__(self):
super(KeyValuePairController, self).__init__()
self._coordinator = coordination.get_coordinator()
self.get_one_db_method = self._get_by_name
- def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decrypt=False):
+ def get_one(
+ self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decrypt=False
+ ):
"""
- List key by name.
+ List key by name.
- Handle:
- GET /keys/key1
+ Handle:
+ GET /keys/key1
"""
if not scope:
# Default to system scope
@@ -84,8 +84,9 @@ def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decr
self._validate_scope(scope=scope)
# User needs to be either admin or requesting item for itself
- self._validate_decrypt_query_parameter(decrypt=decrypt, scope=scope,
- requester_user=requester_user)
+ self._validate_decrypt_query_parameter(
+ decrypt=decrypt, scope=scope, requester_user=requester_user
+ )
user_query_param_filter = bool(user)
@@ -95,45 +96,56 @@ def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decr
rbac_utils = get_rbac_backend().get_utils_class()
# Validate that the authenticated user is admin if user query param is provided
- rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user,
- user=user,
- require_rbac=True)
+ rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(
+ user_db=requester_user, user=user, require_rbac=True
+ )
# Additional guard to ensure there is no information leakage across users
is_admin = rbac_utils.user_is_admin(user_db=requester_user)
if is_admin and user_query_param_filter:
# Retrieve values scoped to the provided user
- user_scope_prefix = get_key_reference(name=name, scope=USER_SCOPE, user=user)
+ user_scope_prefix = get_key_reference(
+ name=name, scope=USER_SCOPE, user=user
+ )
else:
# RBAC not enabled or user is not an admin, retrieve user scoped values for the
# current user
- user_scope_prefix = get_key_reference(name=name, scope=USER_SCOPE,
- user=current_user)
+ user_scope_prefix = get_key_reference(
+ name=name, scope=USER_SCOPE, user=current_user
+ )
if scope == FULL_USER_SCOPE:
key_ref = user_scope_prefix
elif scope == FULL_SYSTEM_SCOPE:
key_ref = get_key_reference(scope=FULL_SYSTEM_SCOPE, name=name, user=user)
else:
- raise ValueError('Invalid scope: %s' % (scope))
+ raise ValueError("Invalid scope: %s" % (scope))
- from_model_kwargs = {'mask_secrets': not decrypt}
+ from_model_kwargs = {"mask_secrets": not decrypt}
kvp_api = self._get_one_by_scope_and_name(
- name=key_ref,
- scope=scope,
- from_model_kwargs=from_model_kwargs
+ name=key_ref, scope=scope, from_model_kwargs=from_model_kwargs
)
return kvp_api
- def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=None,
- decrypt=False, sort=None, offset=0, limit=None, **raw_filters):
+ def get_all(
+ self,
+ requester_user,
+ prefix=None,
+ scope=FULL_SYSTEM_SCOPE,
+ user=None,
+ decrypt=False,
+ sort=None,
+ offset=0,
+ limit=None,
+ **raw_filters,
+ ):
"""
- List all keys.
+ List all keys.
- Handles requests:
- GET /keys/
+ Handles requests:
+ GET /keys/
"""
if not scope:
# Default to system scope
@@ -152,8 +164,9 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non
self._validate_all_scope(scope=scope, requester_user=requester_user)
# User needs to be either admin or requesting items for themselves
- self._validate_decrypt_query_parameter(decrypt=decrypt, scope=scope,
- requester_user=requester_user)
+ self._validate_decrypt_query_parameter(
+ decrypt=decrypt, scope=scope, requester_user=requester_user
+ )
user_query_param_filter = bool(user)
@@ -163,15 +176,15 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non
rbac_utils = get_rbac_backend().get_utils_class()
# Validate that the authenticated user is admin if user query param is provided
- rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user,
- user=user,
- require_rbac=True)
+ rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(
+ user_db=requester_user, user=user, require_rbac=True
+ )
- from_model_kwargs = {'mask_secrets': not decrypt}
+ from_model_kwargs = {"mask_secrets": not decrypt}
if scope and scope not in ALL_SCOPE:
self._validate_scope(scope=scope)
- raw_filters['scope'] = scope
+ raw_filters["scope"] = scope
# Set prefix which will be used for user-scoped items.
# NOTE: It's very important raw_filters['prefix'] is set when requesting user scoped items
@@ -180,47 +193,52 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non
if is_admin and user_query_param_filter:
# Retrieve values scoped to the provided user
- user_scope_prefix = get_key_reference(name=prefix or '', scope=USER_SCOPE, user=user)
+ user_scope_prefix = get_key_reference(
+ name=prefix or "", scope=USER_SCOPE, user=user
+ )
else:
# RBAC not enabled or user is not an admin, retrieve user scoped values for the
# current user
- user_scope_prefix = get_key_reference(name=prefix or '', scope=USER_SCOPE,
- user=current_user)
+ user_scope_prefix = get_key_reference(
+ name=prefix or "", scope=USER_SCOPE, user=current_user
+ )
if scope == ALL_SCOPE:
# Special case for ALL_SCOPE
# 1. Retrieve system scoped values
- raw_filters['scope'] = FULL_SYSTEM_SCOPE
- raw_filters['prefix'] = prefix
+ raw_filters["scope"] = FULL_SYSTEM_SCOPE
+ raw_filters["prefix"] = prefix
- assert 'scope' in raw_filters
+ assert "scope" in raw_filters
kvp_apis_system = super(KeyValuePairController, self)._get_all(
from_model_kwargs=from_model_kwargs,
sort=sort,
offset=offset,
limit=limit,
raw_filters=raw_filters,
- requester_user=requester_user)
+ requester_user=requester_user,
+ )
# 2. Retrieve user scoped items for current user or for all the users (depending if the
# authenticated user is admin and if ?user is provided)
- raw_filters['scope'] = FULL_USER_SCOPE
+ raw_filters["scope"] = FULL_USER_SCOPE
if cfg.CONF.rbac.enable and is_admin and not user_query_param_filter:
# Admin user retrieving user-scoped items for all the users
- raw_filters['prefix'] = prefix or ''
+ raw_filters["prefix"] = prefix or ""
else:
- raw_filters['prefix'] = user_scope_prefix
+ raw_filters["prefix"] = user_scope_prefix
- assert 'scope' in raw_filters
- assert 'prefix' in raw_filters
+ assert "scope" in raw_filters
+ assert "prefix" in raw_filters
kvp_apis_user = super(KeyValuePairController, self)._get_all(
from_model_kwargs=from_model_kwargs,
sort=sort,
offset=offset,
limit=limit,
raw_filters=raw_filters,
- requester_user=requester_user)
+ requester_user=requester_user,
+ )
# Combine the result
kvp_apis = []
@@ -228,31 +246,33 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non
kvp_apis.extend(kvp_apis_user.json or [])
elif scope in [USER_SCOPE, FULL_USER_SCOPE]:
# Make sure we only returned values scoped to current user
- prefix = get_key_reference(name=prefix or '', scope=scope, user=user)
- raw_filters['prefix'] = user_scope_prefix
+ prefix = get_key_reference(name=prefix or "", scope=scope, user=user)
+ raw_filters["prefix"] = user_scope_prefix
- assert 'scope' in raw_filters
- assert 'prefix' in raw_filters
+ assert "scope" in raw_filters
+ assert "prefix" in raw_filters
kvp_apis = super(KeyValuePairController, self)._get_all(
from_model_kwargs=from_model_kwargs,
sort=sort,
offset=offset,
limit=limit,
raw_filters=raw_filters,
- requester_user=requester_user)
+ requester_user=requester_user,
+ )
elif scope in [SYSTEM_SCOPE, FULL_SYSTEM_SCOPE]:
- raw_filters['prefix'] = prefix
+ raw_filters["prefix"] = prefix
- assert 'scope' in raw_filters
+ assert "scope" in raw_filters
kvp_apis = super(KeyValuePairController, self)._get_all(
from_model_kwargs=from_model_kwargs,
sort=sort,
offset=offset,
limit=limit,
raw_filters=raw_filters,
- requester_user=requester_user)
+ requester_user=requester_user,
+ )
else:
- raise ValueError('Invalid scope: %s' % (scope))
+ raise ValueError("Invalid scope: %s" % (scope))
return kvp_apis
@@ -266,42 +286,42 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE):
if not requester_user:
requester_user = UserDB(cfg.CONF.system_user.user)
- scope = getattr(kvp, 'scope', scope)
+ scope = getattr(kvp, "scope", scope)
scope = get_datastore_full_scope(scope)
self._validate_scope(scope=scope)
- user = getattr(kvp, 'user', requester_user.name) or requester_user.name
+ user = getattr(kvp, "user", requester_user.name) or requester_user.name
# Validate that the authenticated user is admin if user query param is provided
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user,
- user=user,
- require_rbac=True)
+ rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(
+ user_db=requester_user, user=user, require_rbac=True
+ )
# Validate that encrypted option can only be used by admins
- encrypted = getattr(kvp, 'encrypted', False)
- self._validate_encrypted_query_parameter(encrypted=encrypted, scope=scope,
- requester_user=requester_user)
+ encrypted = getattr(kvp, "encrypted", False)
+ self._validate_encrypted_query_parameter(
+ encrypted=encrypted, scope=scope, requester_user=requester_user
+ )
key_ref = get_key_reference(scope=scope, name=name, user=user)
lock_name = self._get_lock_name_for_key(name=key_ref, scope=scope)
- LOG.debug('PUT scope: %s, name: %s', scope, name)
+ LOG.debug("PUT scope: %s, name: %s", scope, name)
# TODO: Custom permission check since the key doesn't need to exist here
# Note: We use lock to avoid a race
with self._coordinator.get_lock(lock_name):
try:
existing_kvp_api = self._get_one_by_scope_and_name(
- scope=scope,
- name=key_ref
+ scope=scope, name=key_ref
)
except StackStormDBObjectNotFoundError:
existing_kvp_api = None
# st2client sends an invalid id when initially setting a key, so we ignore those
- id_ = kvp.__dict__.get('id', None)
+ id_ = kvp.__dict__.get("id", None)
if not existing_kvp_api and id_ and not bson.ObjectId.is_valid(id_):
- del kvp.__dict__['id']
+ del kvp.__dict__["id"]
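+ # e.g. the placeholder id st2client sends on an initial set is not a
+ # valid bson ObjectId and is removed here before the model is saved.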
kvp.name = key_ref
kvp.scope = scope
@@ -314,7 +334,7 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE):
kvp_db = KeyValuePair.add_or_update(kvp_db)
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for key value data=%s', kvp)
+ LOG.exception("Validation failed for key value data=%s", kvp)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
except CryptoKeyNotSetupException as e:
@@ -325,18 +345,18 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE):
LOG.exception(six.text_type(e))
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- extra = {'kvp_db': kvp_db}
- LOG.audit('KeyValuePair updated. KeyValuePair.id=%s' % (kvp_db.id), extra=extra)
+ extra = {"kvp_db": kvp_db}
+ LOG.audit("KeyValuePair updated. KeyValuePair.id=%s" % (kvp_db.id), extra=extra)
kvp_api = KeyValuePairAPI.from_model(kvp_db)
return kvp_api
def delete(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None):
"""
- Delete the key value pair.
+ Delete the key value pair.
- Handles requests:
- DELETE /keys/1
+ Handles requests:
+ DELETE /keys/1
"""
if not scope:
scope = FULL_SYSTEM_SCOPE
@@ -351,37 +371,42 @@ def delete(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None):
# Validate that the authenticated user is admin if user query param is provided
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user,
- user=user,
- require_rbac=True)
+ rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(
+ user_db=requester_user, user=user, require_rbac=True
+ )
key_ref = get_key_reference(scope=scope, name=name, user=user)
lock_name = self._get_lock_name_for_key(name=key_ref, scope=scope)
# Note: We use lock to avoid a race
with self._coordinator.get_lock(lock_name):
- from_model_kwargs = {'mask_secrets': True}
+ from_model_kwargs = {"mask_secrets": True}
kvp_api = self._get_one_by_scope_and_name(
- name=key_ref,
- scope=scope,
- from_model_kwargs=from_model_kwargs
+ name=key_ref, scope=scope, from_model_kwargs=from_model_kwargs
)
kvp_db = KeyValuePairAPI.to_model(kvp_api)
- LOG.debug('DELETE /keys/ lookup with scope=%s name=%s found object: %s',
- scope, name, kvp_db)
+ LOG.debug(
+ "DELETE /keys/ lookup with scope=%s name=%s found object: %s",
+ scope,
+ name,
+ kvp_db,
+ )
try:
KeyValuePair.delete(kvp_db)
except Exception as e:
- LOG.exception('Database delete encountered exception during '
- 'delete of name="%s". ', name)
+ LOG.exception(
+ "Database delete encountered exception during "
+ 'delete of name="%s". ',
+ name,
+ )
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
return
- extra = {'kvp_db': kvp_db}
- LOG.audit('KeyValuePair deleted. KeyValuePair.id=%s' % (kvp_db.id), extra=extra)
+ extra = {"kvp_db": kvp_db}
+ LOG.audit("KeyValuePair deleted. KeyValuePair.id=%s" % (kvp_db.id), extra=extra)
return Response(status=http_client.NO_CONTENT)
@@ -392,7 +417,7 @@ def _get_lock_name_for_key(self, name, scope=FULL_SYSTEM_SCOPE):
:param name: Datastore item name (PK).
:type name: ``str``
"""
- lock_name = six.b('kvp-crud-%s.%s' % (scope, name))
+ lock_name = six.b("kvp-crud-%s.%s" % (scope, name))
return lock_name
def _validate_all_scope(self, scope, requester_user):
@@ -400,7 +425,7 @@ def _validate_all_scope(self, scope, requester_user):
Validate that "all" scope can only be provided by admins on RBAC installations.
"""
scope = get_datastore_full_scope(scope)
- is_all_scope = (scope == ALL_SCOPE)
+ is_all_scope = scope == ALL_SCOPE
rbac_utils = get_rbac_backend().get_utils_class()
is_admin = rbac_utils.user_is_admin(user_db=requester_user)
@@ -415,22 +440,25 @@ def _validate_decrypt_query_parameter(self, decrypt, scope, requester_user):
"""
rbac_utils = get_rbac_backend().get_utils_class()
is_admin = rbac_utils.user_is_admin(user_db=requester_user)
- is_user_scope = (scope == USER_SCOPE or scope == FULL_USER_SCOPE)
+ is_user_scope = scope == USER_SCOPE or scope == FULL_USER_SCOPE
if decrypt and (not is_user_scope and not is_admin):
- msg = 'Decrypt option requires administrator access'
+ msg = "Decrypt option requires administrator access"
raise AccessDeniedError(message=msg, user_db=requester_user)
def _validate_encrypted_query_parameter(self, encrypted, scope, requester_user):
rbac_utils = get_rbac_backend().get_utils_class()
is_admin = rbac_utils.user_is_admin(user_db=requester_user)
if encrypted and not is_admin:
- msg = 'Pre-encrypted option requires administrator access'
+ msg = "Pre-encrypted option requires administrator access"
raise AccessDeniedError(message=msg, user_db=requester_user)
def _validate_scope(self, scope):
if scope not in ALLOWED_SCOPES:
- msg = 'Scope %s is not in allowed scopes list: %s.' % (scope, ALLOWED_SCOPES)
+ msg = "Scope %s is not in allowed scopes list: %s." % (
+ scope,
+ ALLOWED_SCOPES,
+ )
raise ValueError(msg)
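For context on the lock name helper above: six.b() simply encodes the interpolated string to bytes, which the coordination backend expects for lock names. A minimal standalone sketch with illustrative values (the scope and key below are not from this patch):

    import six

    scope = "st2kv.system"
    name = "mypack.api_key"
    # Mirrors _get_lock_name_for_key(): "kvp-crud-<scope>.<name>" as bytes
    lock_name = six.b("kvp-crud-%s.%s" % (scope, name))
    assert lock_name == b"kvp-crud-st2kv.system.mypack.api_key"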
diff --git a/st2api/st2api/controllers/v1/pack_config_schemas.py b/st2api/st2api/controllers/v1/pack_config_schemas.py
index 551573e12e..933a7ab500 100644
--- a/st2api/st2api/controllers/v1/pack_config_schemas.py
+++ b/st2api/st2api/controllers/v1/pack_config_schemas.py
@@ -23,9 +23,7 @@
http_client = six.moves.http_client
-__all__ = [
- 'PackConfigSchemasController'
-]
+__all__ = ["PackConfigSchemasController"]
class PackConfigSchemasController(ResourceController):
@@ -40,7 +38,9 @@ def __init__(self):
# this case, RBAC is checked on the parent PackDB object
self.get_one_db_method = packs_service.get_pack_by_ref
- def get_all(self, sort=None, offset=0, limit=None, requester_user=None, **raw_filters):
+ def get_all(
+ self, sort=None, offset=0, limit=None, requester_user=None, **raw_filters
+ ):
"""
Retrieve config schema for all the packs.
@@ -48,11 +48,13 @@ def get_all(self, sort=None, offset=0, limit=None, requester_user=None, **raw_fi
GET /config_schema/
"""
- return super(PackConfigSchemasController, self)._get_all(sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ return super(PackConfigSchemasController, self)._get_all(
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, pack_ref, requester_user):
"""
@@ -61,7 +63,9 @@ def get_one(self, pack_ref, requester_user):
Handles requests:
GET /config_schema/<pack_ref>
"""
- packs_controller._get_one_by_ref_or_id(ref_or_id=pack_ref, requester_user=requester_user)
+ packs_controller._get_one_by_ref_or_id(
+ ref_or_id=pack_ref, requester_user=requester_user
+ )
return self._get_one_by_pack_ref(pack_ref=pack_ref)
diff --git a/st2api/st2api/controllers/v1/pack_configs.py b/st2api/st2api/controllers/v1/pack_configs.py
index 6eb18c7a34..4123a3cb22 100644
--- a/st2api/st2api/controllers/v1/pack_configs.py
+++ b/st2api/st2api/controllers/v1/pack_configs.py
@@ -35,9 +35,7 @@
http_client = six.moves.http_client
-__all__ = [
- 'PackConfigsController'
-]
+__all__ = ["PackConfigsController"]
LOG = logging.getLogger(__name__)
@@ -54,8 +52,15 @@ def __init__(self):
# this case, RBAC is checked on the parent PackDB object
self.get_one_db_method = packs_service.get_pack_by_ref
- def get_all(self, requester_user, sort=None, offset=0, limit=None, show_secrets=False,
- **raw_filters):
+ def get_all(
+ self,
+ requester_user,
+ sort=None,
+ offset=0,
+ limit=None,
+ show_secrets=False,
+ **raw_filters,
+ ):
"""
Retrieve configs for all the packs.
@@ -63,14 +68,18 @@ def get_all(self, requester_user, sort=None, offset=0, limit=None, show_secrets=
GET /configs/
"""
from_model_kwargs = {
- 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets)
+ "mask_secrets": self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ )
}
- return super(PackConfigsController, self)._get_all(sort=sort,
- offset=offset,
- limit=limit,
- from_model_kwargs=from_model_kwargs,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ return super(PackConfigsController, self)._get_all(
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ from_model_kwargs=from_model_kwargs,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, pack_ref, requester_user, show_secrets=False):
"""
@@ -80,7 +89,9 @@ def get_one(self, pack_ref, requester_user, show_secrets=False):
GET /configs/<pack_ref>
"""
from_model_kwargs = {
- 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets)
+ "mask_secrets": self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ )
}
try:
instance = packs_service.get_pack_by_ref(pack_ref=pack_ref)
@@ -89,18 +100,22 @@ def get_one(self, pack_ref, requester_user, show_secrets=False):
abort(http_client.NOT_FOUND, msg)
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=instance,
- permission_type=PermissionType.PACK_VIEW)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=instance,
+ permission_type=PermissionType.PACK_VIEW,
+ )
- return self._get_one_by_pack_ref(pack_ref=pack_ref, from_model_kwargs=from_model_kwargs)
+ return self._get_one_by_pack_ref(
+ pack_ref=pack_ref, from_model_kwargs=from_model_kwargs
+ )
def put(self, pack_config_content, pack_ref, requester_user, show_secrets=False):
"""
- Create a new config for a pack.
+ Create a new config for a pack.
- Handles requests:
- POST /configs/<pack_ref>
+ Handles requests:
+ POST /configs/<pack_ref>
"""
try:
@@ -121,9 +136,9 @@ def put(self, pack_config_content, pack_ref, requester_user, show_secrets=False)
def _dump_config_to_disk(self, config_api):
config_content = yaml.safe_dump(config_api.values, default_flow_style=False)
- configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/')
- config_path = os.path.join(configs_path, '%s.yaml' % config_api.pack)
- with open(config_path, 'w') as f:
+ configs_path = os.path.join(cfg.CONF.system.base_path, "configs/")
+ config_path = os.path.join(configs_path, "%s.yaml" % config_api.pack)
+ with open(config_path, "w") as f:
f.write(config_content)
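_dump_config_to_disk above relies on default_flow_style=False to emit block-style YAML rather than inline mappings. A small sketch with made-up config values:

    import yaml

    values = {"api_key": "abc123", "regions": ["us-east-1", "us-west-2"]}
    print(yaml.safe_dump(values, default_flow_style=False))
    # api_key: abc123
    # regions:
    # - us-east-1
    # - us-west-2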
diff --git a/st2api/st2api/controllers/v1/pack_views.py b/st2api/st2api/controllers/v1/pack_views.py
index 5e8b310c33..4fd6f9dd3a 100644
--- a/st2api/st2api/controllers/v1/pack_views.py
+++ b/st2api/st2api/controllers/v1/pack_views.py
@@ -33,10 +33,7 @@
http_client = six.moves.http_client
-__all__ = [
- 'FilesController',
- 'FileController'
-]
+__all__ = ["FilesController", "FileController"]
http_client = six.moves.http_client
@@ -46,12 +43,10 @@
# Maximum file size in bytes. If the file on disk is larger than this value, we don't include it
# in the response. This prevents DDoS / exhaustion attacks.
-MAX_FILE_SIZE = (500 * 1000)
+MAX_FILE_SIZE = 500 * 1000
# File paths in the file controller for which RBAC checks are not performed
-WHITELISTED_FILE_PATHS = [
- 'icon.png'
-]
+WHITELISTED_FILE_PATHS = ["icon.png"]
class BaseFileController(BasePacksController):
@@ -76,7 +71,7 @@ def _get_file_stats(self, file_path):
return file_stats.st_size, file_stats.st_mtime
def _get_file_content(self, file_path):
- with codecs.open(file_path, 'rb') as fp:
+ with codecs.open(file_path, "rb") as fp:
content = fp.read()
return content
@@ -105,17 +100,19 @@ def __init__(self):
def get_one(self, ref_or_id, requester_user):
"""
- Outputs the content of all the files inside the pack.
+ Outputs the content of all the files inside the pack.
- Handles requests:
- GET /packs/views/files/<pack_ref_or_id>
+ Handles requests:
+ GET /packs/views/files/<pack_ref_or_id>
"""
pack_db = self._get_by_ref_or_id(ref_or_id=ref_or_id)
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=pack_db,
- permission_type=PermissionType.PACK_VIEW)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=pack_db,
+ permission_type=PermissionType.PACK_VIEW,
+ )
if not pack_db:
msg = 'Pack with ref_or_id "%s" does not exist' % (ref_or_id)
@@ -126,15 +123,19 @@ def get_one(self, ref_or_id, requester_user):
result = []
for file_path in pack_files:
- normalized_file_path = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path)
+ normalized_file_path = get_pack_file_abs_path(
+ pack_ref=pack_ref, file_path=file_path
+ )
if not normalized_file_path or not os.path.isfile(normalized_file_path):
# Ignore references to files which don't exist on disk
continue
file_size = self._get_file_size(file_path=normalized_file_path)
if file_size is not None and file_size > MAX_FILE_SIZE:
- LOG.debug('Skipping file "%s" which size exceeds max file size (%s bytes)' %
- (normalized_file_path, MAX_FILE_SIZE))
+ LOG.debug(
+ 'Skipping file "%s" which size exceeds max file size (%s bytes)'
+ % (normalized_file_path, MAX_FILE_SIZE)
+ )
continue
content = self._get_file_content(file_path=normalized_file_path)
@@ -144,10 +145,7 @@ def get_one(self, ref_or_id, requester_user):
LOG.debug('Skipping binary file "%s"' % (normalized_file_path))
continue
- item = {
- 'file_path': file_path,
- 'content': content
- }
+ item = {"file_path": file_path, "content": content}
result.append(item)
return result
@@ -173,13 +171,19 @@ class FileController(BaseFileController):
Controller which allows user to retrieve content of a specific file in a pack.
"""
- def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None,
- if_modified_since=None):
+ def get_one(
+ self,
+ ref_or_id,
+ file_path,
+ requester_user,
+ if_none_match=None,
+ if_modified_since=None,
+ ):
"""
- Outputs the content of a specific file in a pack.
+ Outputs the content of a specific file in a pack.
- Handles requests:
- GET /packs/views/file/<pack_ref_or_id>/<file path>
+ Handles requests:
+ GET /packs/views/file/<pack_ref_or_id>/<file path>
"""
pack_db = self._get_by_ref_or_id(ref_or_id=ref_or_id)
@@ -188,7 +192,7 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None,
raise StackStormDBObjectNotFoundError(msg)
if not file_path:
- raise ValueError('Missing file path')
+ raise ValueError("Missing file path")
pack_ref = pack_db.ref
@@ -196,11 +200,15 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None,
permission_type = PermissionType.PACK_VIEW
if file_path not in WHITELISTED_FILE_PATHS:
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=pack_db,
- permission_type=permission_type)
-
- normalized_file_path = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=pack_db,
+ permission_type=permission_type,
+ )
+
+ normalized_file_path = get_pack_file_abs_path(
+ pack_ref=pack_ref, file_path=file_path
+ )
if not normalized_file_path or not os.path.isfile(normalized_file_path):
# Error out on references to files which don't exist on disk
raise StackStormDBObjectNotFoundError('File "%s" not found' % (file_path))
@@ -209,24 +217,28 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None,
response = Response()
- if not self._is_file_changed(file_mtime,
- if_none_match=if_none_match,
- if_modified_since=if_modified_since):
+ if not self._is_file_changed(
+ file_mtime, if_none_match=if_none_match, if_modified_since=if_modified_since
+ ):
response.status = http_client.NOT_MODIFIED
else:
if file_size is not None and file_size > MAX_FILE_SIZE:
- msg = ('File %s exceeds maximum allowed file size (%s bytes)' %
- (file_path, MAX_FILE_SIZE))
+ msg = "File %s exceeds maximum allowed file size (%s bytes)" % (
+ file_path,
+ MAX_FILE_SIZE,
+ )
raise ValueError(msg)
- content_type = mimetypes.guess_type(normalized_file_path)[0] or \
- 'application/octet-stream'
+ content_type = (
+ mimetypes.guess_type(normalized_file_path)[0]
+ or "application/octet-stream"
+ )
- response.headers['Content-Type'] = content_type
+ response.headers["Content-Type"] = content_type
response.body = self._get_file_content(file_path=normalized_file_path)
- response.headers['Last-Modified'] = format_date_time(file_mtime)
- response.headers['ETag'] = repr(file_mtime)
+ response.headers["Last-Modified"] = format_date_time(file_mtime)
+ response.headers["ETag"] = repr(file_mtime)
return response
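The content-type logic above falls back to application/octet-stream whenever mimetypes cannot guess a type from the file extension. A standalone illustration (file names are made up):

    import mimetypes

    assert mimetypes.guess_type("icon.png")[0] == "image/png"

    # Unknown extensions guess to (None, None), hence the fallback:
    content_type = (
        mimetypes.guess_type("pack.unknown_ext")[0] or "application/octet-stream"
    )
    assert content_type == "application/octet-stream"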
diff --git a/st2api/st2api/controllers/v1/packs.py b/st2api/st2api/controllers/v1/packs.py
index 6193a3f01f..75da16e5e5 100644
--- a/st2api/st2api/controllers/v1/packs.py
+++ b/st2api/st2api/controllers/v1/packs.py
@@ -52,115 +52,119 @@
http_client = six.moves.http_client
-__all__ = [
- 'PacksController',
- 'BasePacksController',
- 'ENTITIES'
-]
+__all__ = ["PacksController", "BasePacksController", "ENTITIES"]
LOG = logging.getLogger(__name__)
# Note: The order in which these are defined is important so they are registered
# in the same order as in st2-register-content.
# We also need to use a list of tuples to preserve the order.
-ENTITIES = OrderedDict([
- ('trigger', (TriggersRegistrar, 'triggers')),
- ('sensor', (SensorsRegistrar, 'sensors')),
- ('action', (ActionsRegistrar, 'actions')),
- ('rule', (RulesRegistrar, 'rules')),
- ('alias', (AliasesRegistrar, 'aliases')),
- ('policy', (PolicyRegistrar, 'policies')),
- ('config', (ConfigsRegistrar, 'configs'))
-])
+ENTITIES = OrderedDict(
+ [
+ ("trigger", (TriggersRegistrar, "triggers")),
+ ("sensor", (SensorsRegistrar, "sensors")),
+ ("action", (ActionsRegistrar, "actions")),
+ ("rule", (RulesRegistrar, "rules")),
+ ("alias", (AliasesRegistrar, "aliases")),
+ ("policy", (PolicyRegistrar, "policies")),
+ ("config", (ConfigsRegistrar, "configs")),
+ ]
+)
def _get_proxy_config():
- LOG.debug('Loading proxy configuration from env variables %s.', os.environ)
- http_proxy = os.environ.get('http_proxy', None)
- https_proxy = os.environ.get('https_proxy', None)
- no_proxy = os.environ.get('no_proxy', None)
- proxy_ca_bundle_path = os.environ.get('proxy_ca_bundle_path', None)
+ LOG.debug("Loading proxy configuration from env variables %s.", os.environ)
+ http_proxy = os.environ.get("http_proxy", None)
+ https_proxy = os.environ.get("https_proxy", None)
+ no_proxy = os.environ.get("no_proxy", None)
+ proxy_ca_bundle_path = os.environ.get("proxy_ca_bundle_path", None)
proxy_config = {
- 'http_proxy': http_proxy,
- 'https_proxy': https_proxy,
- 'proxy_ca_bundle_path': proxy_ca_bundle_path,
- 'no_proxy': no_proxy
+ "http_proxy": http_proxy,
+ "https_proxy": https_proxy,
+ "proxy_ca_bundle_path": proxy_ca_bundle_path,
+ "no_proxy": no_proxy,
}
- LOG.debug('Proxy configuration: %s', proxy_config)
+ LOG.debug("Proxy configuration: %s", proxy_config)
return proxy_config
class PackInstallController(ActionExecutionsControllerMixin):
-
def post(self, pack_install_request, requester_user=None):
parameters = {
- 'packs': pack_install_request.packs,
+ "packs": pack_install_request.packs,
}
if pack_install_request.force:
- parameters['force'] = True
+ parameters["force"] = True
if pack_install_request.skip_dependencies:
- parameters['skip_dependencies'] = True
+ parameters["skip_dependencies"] = True
if not requester_user:
requester_user = UserDB(cfg.CONF.system_user.user)
- new_liveaction_api = LiveActionCreateAPI(action='packs.install',
- parameters=parameters,
- user=requester_user.name)
+ new_liveaction_api = LiveActionCreateAPI(
+ action="packs.install", parameters=parameters, user=requester_user.name
+ )
- execution_resp = self._handle_schedule_execution(liveaction_api=new_liveaction_api,
- requester_user=requester_user)
+ execution_resp = self._handle_schedule_execution(
+ liveaction_api=new_liveaction_api, requester_user=requester_user
+ )
- exec_id = PackAsyncAPI(execution_id=execution_resp.json['id'])
+ exec_id = PackAsyncAPI(execution_id=execution_resp.json["id"])
return Response(json=exec_id, status=http_client.ACCEPTED)
class PackUninstallController(ActionExecutionsControllerMixin):
-
def post(self, pack_uninstall_request, ref_or_id=None, requester_user=None):
if ref_or_id:
- parameters = {
- 'packs': [ref_or_id]
- }
+ parameters = {"packs": [ref_or_id]}
else:
- parameters = {
- 'packs': pack_uninstall_request.packs
- }
+ parameters = {"packs": pack_uninstall_request.packs}
if not requester_user:
requester_user = UserDB(cfg.CONF.system_user.user)
- new_liveaction_api = LiveActionCreateAPI(action='packs.uninstall',
- parameters=parameters,
- user=requester_user.name)
+ new_liveaction_api = LiveActionCreateAPI(
+ action="packs.uninstall", parameters=parameters, user=requester_user.name
+ )
- execution_resp = self._handle_schedule_execution(liveaction_api=new_liveaction_api,
- requester_user=requester_user)
+ execution_resp = self._handle_schedule_execution(
+ liveaction_api=new_liveaction_api, requester_user=requester_user
+ )
- exec_id = PackAsyncAPI(execution_id=execution_resp.json['id'])
+ exec_id = PackAsyncAPI(execution_id=execution_resp.json["id"])
return Response(json=exec_id, status=http_client.ACCEPTED)
class PackRegisterController(object):
- CONTENT_TYPES = ['runner', 'action', 'trigger', 'sensor', 'rule',
- 'rule_type', 'alias', 'policy_type', 'policy', 'config']
+ CONTENT_TYPES = [
+ "runner",
+ "action",
+ "trigger",
+ "sensor",
+ "rule",
+ "rule_type",
+ "alias",
+ "policy_type",
+ "policy",
+ "config",
+ ]
def post(self, pack_register_request):
- if pack_register_request and hasattr(pack_register_request, 'types'):
+ if pack_register_request and hasattr(pack_register_request, "types"):
types = pack_register_request.types
- if 'all' in types:
+ if "all" in types:
types = PackRegisterController.CONTENT_TYPES
else:
types = PackRegisterController.CONTENT_TYPES
- if pack_register_request and hasattr(pack_register_request, 'packs'):
+ if pack_register_request and hasattr(pack_register_request, "packs"):
packs = list(set(pack_register_request.packs))
else:
packs = None
@@ -168,64 +172,80 @@ def post(self, pack_register_request):
result = defaultdict(int)
# Register dependent resources (actions depend on runners, rules depend on rule types, etc.)
- if ('runner' in types or 'runners' in types) or ('action' in types or 'actions' in types):
- result['runners'] = runners_registrar.register_runners(experimental=True)
- if ('rule_type' in types or 'rule_types' in types) or \
- ('rule' in types or 'rules' in types):
- result['rule_types'] = rule_types_registrar.register_rule_types()
- if ('policy_type' in types or 'policy_types' in types) or \
- ('policy' in types or 'policies' in types):
- result['policy_types'] = policies_registrar.register_policy_types(st2common)
+ if ("runner" in types or "runners" in types) or (
+ "action" in types or "actions" in types
+ ):
+ result["runners"] = runners_registrar.register_runners(experimental=True)
+ if ("rule_type" in types or "rule_types" in types) or (
+ "rule" in types or "rules" in types
+ ):
+ result["rule_types"] = rule_types_registrar.register_rule_types()
+ if ("policy_type" in types or "policy_types" in types) or (
+ "policy" in types or "policies" in types
+ ):
+ result["policy_types"] = policies_registrar.register_policy_types(st2common)
use_pack_cache = False
- fail_on_failure = getattr(pack_register_request, 'fail_on_failure', True)
+ fail_on_failure = getattr(pack_register_request, "fail_on_failure", True)
for type, (Registrar, name) in six.iteritems(ENTITIES):
if type in types or name in types:
- registrar = Registrar(use_pack_cache=use_pack_cache,
- use_runners_cache=True,
- fail_on_failure=fail_on_failure)
+ registrar = Registrar(
+ use_pack_cache=use_pack_cache,
+ use_runners_cache=True,
+ fail_on_failure=fail_on_failure,
+ )
if packs:
for pack in packs:
pack_path = content_utils.get_pack_base_path(pack)
try:
- registered_count = registrar.register_from_pack(pack_dir=pack_path)
+ registered_count = registrar.register_from_pack(
+ pack_dir=pack_path
+ )
result[name] += registered_count
except ValueError as e:
# Throw a more user-friendly exception if the requested pack doesn't exist
- if re.match('Directory ".*?" doesn\'t exist', six.text_type(e)):
- msg = 'Pack "%s" not found on disk: %s' % (pack, six.text_type(e))
+ if re.match(
+ 'Directory ".*?" doesn\'t exist', six.text_type(e)
+ ):
+ msg = 'Pack "%s" not found on disk: %s' % (
+ pack,
+ six.text_type(e),
+ )
raise ValueError(msg)
raise e
else:
packs_base_paths = content_utils.get_packs_base_paths()
- registered_count = registrar.register_from_packs(base_dirs=packs_base_paths)
+ registered_count = registrar.register_from_packs(
+ base_dirs=packs_base_paths
+ )
result[name] += registered_count
return result
class PackSearchController(object):
-
def post(self, pack_search_request):
proxy_config = _get_proxy_config()
- if hasattr(pack_search_request, 'query'):
- packs = packs_service.search_pack_index(pack_search_request.query,
- case_sensitive=False,
- proxy_config=proxy_config)
+ if hasattr(pack_search_request, "query"):
+ packs = packs_service.search_pack_index(
+ pack_search_request.query,
+ case_sensitive=False,
+ proxy_config=proxy_config,
+ )
return [PackAPI(**pack) for pack in packs]
else:
- pack = packs_service.get_pack_from_index(pack_search_request.pack,
- proxy_config=proxy_config)
+ pack = packs_service.get_pack_from_index(
+ pack_search_request.pack, proxy_config=proxy_config
+ )
return PackAPI(**pack) if pack else []
class IndexHealthController(object):
-
def get(self):
"""
Check if all listed indexes are healthy: they should be reachable,
@@ -233,7 +253,9 @@ def get(self):
"""
proxy_config = _get_proxy_config()
- _, status = packs_service.fetch_pack_index(allow_empty=True, proxy_config=proxy_config)
+ _, status = packs_service.fetch_pack_index(
+ allow_empty=True, proxy_config=proxy_config
+ )
health = {
"indexes": {
@@ -249,13 +271,13 @@ def get(self):
}
for index in status:
- if index['error']:
- error_count = health['indexes']['errors'].get(index['error'], 0) + 1
- health['indexes']['invalid'] += 1
- health['indexes']['errors'][index['error']] = error_count
+ if index["error"]:
+ error_count = health["indexes"]["errors"].get(index["error"], 0) + 1
+ health["indexes"]["invalid"] += 1
+ health["indexes"]["errors"][index["error"]] = error_count
else:
- health['indexes']['valid'] += 1
- health['packs']['count'] += index['packs']
+ health["indexes"]["valid"] += 1
+ health["packs"]["count"] += index["packs"]
return health
@@ -265,12 +287,16 @@ class BasePacksController(ResourceController):
access = Pack
def _get_one_by_ref_or_id(self, ref_or_id, requester_user, exclude_fields=None):
- instance = self._get_by_ref_or_id(ref_or_id=ref_or_id, exclude_fields=exclude_fields)
+ instance = self._get_by_ref_or_id(
+ ref_or_id=ref_or_id, exclude_fields=exclude_fields
+ )
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=instance,
- permission_type=PermissionType.PACK_VIEW)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=instance,
+ permission_type=PermissionType.PACK_VIEW,
+ )
if not instance:
msg = 'Unable to identify resource with ref_or_id "%s".' % (ref_or_id)
@@ -282,7 +308,9 @@ def _get_one_by_ref_or_id(self, ref_or_id, requester_user, exclude_fields=None):
return result
def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None):
- resource_db = self._get_by_id(resource_id=ref_or_id, exclude_fields=exclude_fields)
+ resource_db = self._get_by_id(
+ resource_id=ref_or_id, exclude_fields=exclude_fields
+ )
if not resource_db:
# Try ref
@@ -302,7 +330,7 @@ def _get_by_ref(self, ref, exclude_fields=None):
return resource_db
-class PacksIndexController():
+class PacksIndexController:
search = PackSearchController()
health = IndexHealthController()
@@ -311,10 +339,7 @@ def get_all(self):
index, status = packs_service.fetch_pack_index(proxy_config=proxy_config)
- return {
- 'status': status,
- 'index': index
- }
+ return {"status": status, "index": index}
class PacksController(BasePacksController):
@@ -322,14 +347,9 @@ class PacksController(BasePacksController):
model = PackAPI
access = Pack
- supported_filters = {
- 'name': 'name',
- 'ref': 'ref'
- }
+ supported_filters = {"name": "name", "ref": "ref"}
- query_options = {
- 'sort': ['ref']
- }
+ query_options = {"sort": ["ref"]}
# Nested controllers
install = PackInstallController()
@@ -342,18 +362,30 @@ def __init__(self):
super(PacksController, self).__init__()
self.get_one_db_method = self._get_by_ref_or_id
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- return super(PacksController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return super(PacksController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, ref_or_id, requester_user):
- return self._get_one_by_ref_or_id(ref_or_id=ref_or_id, requester_user=requester_user)
+ return self._get_one_by_ref_or_id(
+ ref_or_id=ref_or_id, requester_user=requester_user
+ )
packs_controller = PacksController()
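A note on PackRegisterController.post above: the requested types list is normalized so that "all" (or an absent list) expands to every known content type, and both singular and plural spellings are accepted further down. A simplified standalone sketch (the type list is abbreviated, not the full CONTENT_TYPES):

    CONTENT_TYPES = ["runner", "action", "trigger", "sensor", "rule"]  # abbreviated

    def normalize_types(requested=None):
        # Mirrors the expansion logic in PackRegisterController.post
        if not requested or "all" in requested:
            return CONTENT_TYPES
        return requested

    assert normalize_types(["all"]) == CONTENT_TYPES
    assert normalize_types(None) == CONTENT_TYPES
    assert normalize_types(["action"]) == ["action"]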
diff --git a/st2api/st2api/controllers/v1/policies.py b/st2api/st2api/controllers/v1/policies.py
index 3fc488708b..aa57b7cf3d 100644
--- a/st2api/st2api/controllers/v1/policies.py
+++ b/st2api/st2api/controllers/v1/policies.py
@@ -37,54 +37,73 @@ class PolicyTypeController(resource.ResourceController):
model = PolicyTypeAPI
access = PolicyType
- mandatory_include_fields_retrieve = ['id', 'name', 'resource_type']
+ mandatory_include_fields_retrieve = ["id", "name", "resource_type"]
- supported_filters = {
- 'resource_type': 'resource_type'
- }
+ supported_filters = {"resource_type": "resource_type"}
- query_options = {
- 'sort': ['resource_type', 'name']
- }
+ query_options = {"sort": ["resource_type", "name"]}
def get_one(self, ref_or_id, requester_user):
return self._get_one(ref_or_id, requester_user=requester_user)
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- return self._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return self._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def _get_one(self, ref_or_id, requester_user):
instance = self._get_by_ref_or_id(ref_or_id=ref_or_id)
permission_type = PermissionType.POLICY_TYPE_VIEW
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=instance,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=instance,
+ permission_type=permission_type,
+ )
result = self.model.from_model(instance)
return result
- def _get_all(self, exclude_fields=None, include_fields=None, sort=None, offset=0, limit=None,
- query_options=None, from_model_kwargs=None, raw_filters=None,
- requester_user=None):
-
- resp = super(PolicyTypeController, self)._get_all(exclude_fields=exclude_fields,
- include_fields=include_fields,
- sort=sort,
- offset=offset,
- limit=limit,
- query_options=query_options,
- from_model_kwargs=from_model_kwargs,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ def _get_all(
+ self,
+ exclude_fields=None,
+ include_fields=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ query_options=None,
+ from_model_kwargs=None,
+ raw_filters=None,
+ requester_user=None,
+ ):
+
+ resp = super(PolicyTypeController, self)._get_all(
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ query_options=query_options,
+ from_model_kwargs=from_model_kwargs,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
return resp
@@ -114,7 +133,9 @@ def _get_by_ref(self, resource_ref):
except Exception:
return None
- resource_db = self.access.query(name=ref.name, resource_type=ref.resource_type).first()
+ resource_db = self.access.query(
+ name=ref.name, resource_type=ref.resource_type
+ ).first()
return resource_db
@@ -123,77 +144,93 @@ class PolicyController(resource.ContentPackResourceController):
access = Policy
supported_filters = {
- 'pack': 'pack',
- 'resource_ref': 'resource_ref',
- 'policy_type': 'policy_type'
- }
-
- query_options = {
- 'sort': ['pack', 'name']
+ "pack": "pack",
+ "resource_ref": "resource_ref",
+ "policy_type": "policy_type",
}
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- return self._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ query_options = {"sort": ["pack", "name"]}
+
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return self._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, ref_or_id, requester_user):
permission_type = PermissionType.POLICY_VIEW
- return self._get_one(ref_or_id, permission_type=permission_type,
- requester_user=requester_user)
+ return self._get_one(
+ ref_or_id, permission_type=permission_type, requester_user=requester_user
+ )
def post(self, instance, requester_user):
"""
- Create a new policy.
- Handles requests:
- POST /policies/
+ Create a new policy.
+ Handles requests:
+ POST /policies/
"""
permission_type = PermissionType.POLICY_CREATE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user,
- resource_api=instance,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_api_permission(
+ user_db=requester_user,
+ resource_api=instance,
+ permission_type=permission_type,
+ )
- op = 'POST /policies/'
+ op = "POST /policies/"
db_model = self.model.to_model(instance)
- LOG.debug('%s verified object: %s', op, db_model)
+ LOG.debug("%s verified object: %s", op, db_model)
db_model = self.access.add_or_update(db_model)
- LOG.debug('%s created object: %s', op, db_model)
- LOG.audit('Policy created. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model})
+ LOG.debug("%s created object: %s", op, db_model)
+ LOG.audit(
+ "Policy created. Policy.id=%s" % (db_model.id),
+ extra={"policy_db": db_model},
+ )
exec_result = self.model.from_model(db_model)
return Response(json=exec_result, status=http_client.CREATED)
def put(self, instance, ref_or_id, requester_user):
- op = 'PUT /policies/%s/' % ref_or_id
+ op = "PUT /policies/%s/" % ref_or_id
db_model = self._get_by_ref_or_id(ref_or_id=ref_or_id)
- LOG.debug('%s found object: %s', op, db_model)
+ LOG.debug("%s found object: %s", op, db_model)
permission_type = PermissionType.POLICY_MODIFY
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=db_model,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=db_model,
+ permission_type=permission_type,
+ )
db_model_id = db_model.id
try:
validate_not_part_of_system_pack(db_model)
except ValueValidationException as e:
- LOG.exception('%s unable to update object from system pack.', op)
+ LOG.exception("%s unable to update object from system pack.", op)
abort(http_client.BAD_REQUEST, six.text_type(e))
- if not getattr(instance, 'pack', None):
+ if not getattr(instance, "pack", None):
instance.pack = db_model.pack
try:
@@ -201,12 +238,15 @@ def put(self, instance, ref_or_id, requester_user):
db_model.id = db_model_id
db_model = self.access.add_or_update(db_model)
except (ValidationError, ValueError) as e:
- LOG.exception('%s unable to update object: %s', op, db_model)
+ LOG.exception("%s unable to update object: %s", op, db_model)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- LOG.debug('%s updated object: %s', op, db_model)
- LOG.audit('Policy updated. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model})
+ LOG.debug("%s updated object: %s", op, db_model)
+ LOG.audit(
+ "Policy updated. Policy.id=%s" % (db_model.id),
+ extra={"policy_db": db_model},
+ )
exec_result = self.model.from_model(db_model)
@@ -214,38 +254,43 @@ def put(self, instance, ref_or_id, requester_user):
def delete(self, ref_or_id, requester_user):
"""
- Delete a policy.
- Handles requests:
- POST /policies/1?_method=delete
- DELETE /policies/1
- DELETE /policies/mypack.mypolicy
+ Delete a policy.
+ Handles requests:
+ POST /policies/1?_method=delete
+ DELETE /policies/1
+ DELETE /policies/mypack.mypolicy
"""
- op = 'DELETE /policies/%s/' % ref_or_id
+ op = "DELETE /policies/%s/" % ref_or_id
db_model = self._get_by_ref_or_id(ref_or_id=ref_or_id)
- LOG.debug('%s found object: %s', op, db_model)
+ LOG.debug("%s found object: %s", op, db_model)
permission_type = PermissionType.POLICY_DELETE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=db_model,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=db_model,
+ permission_type=permission_type,
+ )
try:
validate_not_part_of_system_pack(db_model)
except ValueValidationException as e:
- LOG.exception('%s unable to delete object from system pack.', op)
+ LOG.exception("%s unable to delete object from system pack.", op)
abort(http_client.BAD_REQUEST, six.text_type(e))
try:
self.access.delete(db_model)
except Exception as e:
- LOG.exception('%s unable to delete object: %s', op, db_model)
+ LOG.exception("%s unable to delete object: %s", op, db_model)
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
return
- LOG.debug('%s deleted object: %s', op, db_model)
- LOG.audit('Policy deleted. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model})
+ LOG.debug("%s deleted object: %s", op, db_model)
+ LOG.audit(
+ "Policy deleted. Policy.id=%s" % (db_model.id),
+ extra={"policy_db": db_model},
+ )
# return None
return Response(status=http_client.NO_CONTENT)
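One subtlety in PolicyController.put above: to_model(instance) builds a fresh document without the database id, so the id read from the existing document is restored before add_or_update(), turning the save into an in-place update instead of an insert. A sketch with stand-in objects (the Doc class and id value are illustrative):

    class Doc(object):
        def __init__(self, id=None, name=None):
            self.id = id
            self.name = name

    existing = Doc(id="60af52c0", name="old")   # fetched via _get_by_ref_or_id()
    db_model_id = existing.id

    db_model = Doc(name="new")                  # what to_model(instance) returns
    db_model.id = db_model_id                   # restore id so the save updates
    assert db_model.id == "60af52c0"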
diff --git a/st2api/st2api/controllers/v1/rbac.py b/st2api/st2api/controllers/v1/rbac.py
index 0e8c1d4179..49a552f7dc 100644
--- a/st2api/st2api/controllers/v1/rbac.py
+++ b/st2api/st2api/controllers/v1/rbac.py
@@ -23,78 +23,76 @@
from st2common.rbac.backends import get_rbac_backend
from st2common.router import exc
-__all__ = [
- 'RolesController',
- 'RoleAssignmentsController',
- 'PermissionTypesController'
-]
+__all__ = ["RolesController", "RoleAssignmentsController", "PermissionTypesController"]
class RolesController(ResourceController):
model = RoleAPI
access = Role
- supported_filters = {
- 'name': 'name',
- 'system': 'system'
- }
+ supported_filters = {"name": "name", "system": "system"}
- query_options = {
- 'sort': ['name']
- }
+ query_options = {"sort": ["name"]}
def get_one(self, name_or_id, requester_user):
rbac_utils = get_rbac_backend().get_utils_class()
rbac_utils.assert_user_is_admin(user_db=requester_user)
- return self._get_one_by_name_or_id(name_or_id=name_or_id,
- permission_type=None,
- requester_user=requester_user)
+ return self._get_one_by_name_or_id(
+ name_or_id=name_or_id, permission_type=None, requester_user=requester_user
+ )
def get_all(self, requester_user, sort=None, offset=0, limit=None, **raw_filters):
rbac_utils = get_rbac_backend().get_utils_class()
rbac_utils.assert_user_is_admin(user_db=requester_user)
- return self._get_all(sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ return self._get_all(
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
class RoleAssignmentsController(ResourceController):
"""
Meta controller for listing role assignments.
"""
+
model = UserRoleAssignmentAPI
access = UserRoleAssignment
supported_filters = {
- 'user': 'user',
- 'role': 'role',
- 'source': 'source',
- 'remote': 'is_remote'
+ "user": "user",
+ "role": "role",
+ "source": "source",
+ "remote": "is_remote",
}
def get_all(self, requester_user, sort=None, offset=0, limit=None, **raw_filters):
- user = raw_filters.get('user', None)
+ user = raw_filters.get("user", None)
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_is_admin_or_operating_on_own_resource(user_db=requester_user,
- user=user)
-
- return self._get_all(sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ rbac_utils.assert_user_is_admin_or_operating_on_own_resource(
+ user_db=requester_user, user=user
+ )
+
+ return self._get_all(
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, id, requester_user):
- result = self._get_one_by_id(id,
- requester_user=requester_user,
- permission_type=None)
- user = getattr(result, 'user', None)
+ result = self._get_one_by_id(
+ id, requester_user=requester_user, permission_type=None
+ )
+ user = getattr(result, "user", None)
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_is_admin_or_operating_on_own_resource(user_db=requester_user,
- user=user)
+ rbac_utils.assert_user_is_admin_or_operating_on_own_resource(
+ user_db=requester_user, user=user
+ )
return result
@@ -106,10 +104,10 @@ class PermissionTypesController(object):
def get_all(self, requester_user):
"""
- List all the available permission types.
+ List all the available permission types.
- Handles requests:
- GET /rbac/permission_types
+ Handles requests:
+ GET /rbac/permission_types
"""
rbac_utils = get_rbac_backend().get_utils_class()
rbac_utils.assert_user_is_admin(user_db=requester_user)
@@ -119,10 +117,10 @@ def get_all(self, requester_user):
def get_one(self, resource_type, requester_user):
"""
- List all the available permission types for a particular resource type.
+ List all the available permission types for a particular resource type.
- Handles requests:
- GET /rbac/permission_types/<resource_type>
+ Handles requests:
+ GET /rbac/permission_types/<resource_type>
"""
rbac_utils = get_rbac_backend().get_utils_class()
rbac_utils.assert_user_is_admin(user_db=requester_user)
@@ -131,7 +129,7 @@ def get_one(self, resource_type, requester_user):
permission_types = all_permission_types.get(resource_type, None)
if permission_types is None:
- raise exc.HTTPNotFound('Invalid resource type: %s' % (resource_type))
+ raise exc.HTTPNotFound("Invalid resource type: %s" % (resource_type))
return permission_types
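PermissionTypesController.get_one above is a plain dictionary lookup that 404s on unknown resource types. A standalone sketch with an abbreviated, illustrative map (the real one comes from the RBAC backend):

    all_permission_types = {
        "action": ["action_list", "action_view", "action_create"],
        "rule": ["rule_list", "rule_view"],
    }

    def get_permission_types(resource_type):
        permission_types = all_permission_types.get(resource_type, None)
        if permission_types is None:
            # the controller raises exc.HTTPNotFound here
            raise LookupError("Invalid resource type: %s" % resource_type)
        return permission_types

    assert get_permission_types("rule") == ["rule_list", "rule_view"]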
diff --git a/st2api/st2api/controllers/v1/rule_enforcement_views.py b/st2api/st2api/controllers/v1/rule_enforcement_views.py
index 3d23d027a9..75831a917b 100644
--- a/st2api/st2api/controllers/v1/rule_enforcement_views.py
+++ b/st2api/st2api/controllers/v1/rule_enforcement_views.py
@@ -26,9 +26,7 @@
from st2api.controllers.resource import ResourceController
-__all__ = [
- 'RuleEnforcementViewController'
-]
+__all__ = ["RuleEnforcementViewController"]
class RuleEnforcementViewController(ResourceController):
@@ -50,8 +48,16 @@ class RuleEnforcementViewController(ResourceController):
supported_filters = SUPPORTED_FILTERS
filter_transform_functions = FILTER_TRANSFORM_FUNCTIONS
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
rule_enforcement_apis = super(RuleEnforcementViewController, self)._get_all(
exclude_fields=exclude_attributes,
include_fields=include_attributes,
@@ -59,16 +65,25 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o
offset=offset,
limit=limit,
raw_filters=raw_filters,
- requester_user=requester_user)
+ requester_user=requester_user,
+ )
- rule_enforcement_apis.json = self._append_view_properties(rule_enforcement_apis.json)
+ rule_enforcement_apis.json = self._append_view_properties(
+ rule_enforcement_apis.json
+ )
return rule_enforcement_apis
def get_one(self, id, requester_user):
- rule_enforcement_api = super(RuleEnforcementViewController,
- self)._get_one_by_id(id, requester_user=requester_user,
- permission_type=PermissionType.RULE_ENFORCEMENT_VIEW)
- rule_enforcement_api = self._append_view_properties([rule_enforcement_api.__json__()])[0]
+ rule_enforcement_api = super(
+ RuleEnforcementViewController, self
+ )._get_one_by_id(
+ id,
+ requester_user=requester_user,
+ permission_type=PermissionType.RULE_ENFORCEMENT_VIEW,
+ )
+ rule_enforcement_api = self._append_view_properties(
+ [rule_enforcement_api.__json__()]
+ )[0]
return rule_enforcement_api
def _append_view_properties(self, rule_enforcement_apis):
@@ -80,29 +95,29 @@ def _append_view_properties(self, rule_enforcement_apis):
execution_ids = []
for rule_enforcement_api in rule_enforcement_apis:
- if rule_enforcement_api.get('trigger_instance_id', None):
- trigger_instance_ids.add(str(rule_enforcement_api['trigger_instance_id']))
+ if rule_enforcement_api.get("trigger_instance_id", None):
+ trigger_instance_ids.add(
+ str(rule_enforcement_api["trigger_instance_id"])
+ )
- if rule_enforcement_api.get('execution_id', None):
- execution_ids.append(rule_enforcement_api['execution_id'])
+ if rule_enforcement_api.get("execution_id", None):
+ execution_ids.append(rule_enforcement_api["execution_id"])
# 1. Retrieve corresponding execution objects
# NOTE: Executions contain a lot of fields and could contain a lot of data, so we
# only retrieve the fields we need
only_fields = [
- 'id',
-
- 'action.ref',
- 'action.parameters',
-
- 'runner.name',
- 'runner.runner_parameters',
-
- 'parameters',
- 'status'
+ "id",
+ "action.ref",
+ "action.parameters",
+ "runner.name",
+ "runner.runner_parameters",
+ "parameters",
+ "status",
]
- execution_dbs = ActionExecution.query(id__in=execution_ids,
- only_fields=only_fields)
+ execution_dbs = ActionExecution.query(
+ id__in=execution_ids, only_fields=only_fields
+ )
execution_dbs_by_id = {}
for execution_db in execution_dbs:
@@ -114,26 +129,32 @@ def _append_view_properties(self, rule_enforcement_apis):
trigger_instance_dbs_by_id = {}
for trigger_instance_db in trigger_instance_dbs:
- trigger_instance_dbs_by_id[str(trigger_instance_db.id)] = trigger_instance_db
+ trigger_instance_dbs_by_id[
+ str(trigger_instance_db.id)
+ ] = trigger_instance_db
# Amend rule enforcement objects with additional data
for rule_enforcement_api in rule_enforcement_apis:
- rule_enforcement_api['trigger_instance'] = {}
- rule_enforcement_api['execution'] = {}
+ rule_enforcement_api["trigger_instance"] = {}
+ rule_enforcement_api["execution"] = {}
- trigger_instance_id = rule_enforcement_api.get('trigger_instance_id', None)
- execution_id = rule_enforcement_api.get('execution_id', None)
+ trigger_instance_id = rule_enforcement_api.get("trigger_instance_id", None)
+ execution_id = rule_enforcement_api.get("execution_id", None)
- trigger_instance_db = trigger_instance_dbs_by_id.get(trigger_instance_id, None)
+ trigger_instance_db = trigger_instance_dbs_by_id.get(
+ trigger_instance_id, None
+ )
execution_db = execution_dbs_by_id.get(execution_id, None)
if trigger_instance_db:
- trigger_instance_api = TriggerInstanceAPI.from_model(trigger_instance_db)
- rule_enforcement_api['trigger_instance'] = trigger_instance_api
+ trigger_instance_api = TriggerInstanceAPI.from_model(
+ trigger_instance_db
+ )
+ rule_enforcement_api["trigger_instance"] = trigger_instance_api
if execution_db:
execution_api = ActionExecutionAPI.from_model(execution_db)
- rule_enforcement_api['execution'] = execution_api
+ rule_enforcement_api["execution"] = execution_api
return rule_enforcement_apis
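In _append_view_properties above, the lookup dicts are keyed by str(id) because database documents carry bson.ObjectId values while the enforcement API dicts reference them as plain strings. A minimal sketch with a stand-in document class (requires the bson package that pymongo ships):

    import bson

    class FakeExecution(object):
        def __init__(self):
            self.id = bson.ObjectId()

    execution_db = FakeExecution()
    execution_dbs_by_id = {str(execution_db.id): execution_db}

    # API payloads carry the execution id as a string
    execution_id = str(execution_db.id)
    assert execution_dbs_by_id.get(execution_id) is execution_db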
diff --git a/st2api/st2api/controllers/v1/rule_enforcements.py b/st2api/st2api/controllers/v1/rule_enforcements.py
index 1c117558ca..f1c1f4c5b7 100644
--- a/st2api/st2api/controllers/v1/rule_enforcements.py
+++ b/st2api/st2api/controllers/v1/rule_enforcements.py
@@ -24,11 +24,10 @@
from st2api.controllers.resource import ResourceController
__all__ = [
- 'RuleEnforcementController',
-
- 'SUPPORTED_FILTERS',
- 'QUERY_OPTIONS',
- 'FILTER_TRANSFORM_FUNCTIONS'
+ "RuleEnforcementController",
+ "SUPPORTED_FILTERS",
+ "QUERY_OPTIONS",
+ "FILTER_TRANSFORM_FUNCTIONS",
]
@@ -38,23 +37,21 @@
SUPPORTED_FILTERS = {
- 'rule_ref': 'rule.ref',
- 'rule_id': 'rule.id',
- 'execution': 'execution_id',
- 'trigger_instance': 'trigger_instance_id',
- 'enforced_at': 'enforced_at',
- 'enforced_at_gt': 'enforced_at.gt',
- 'enforced_at_lt': 'enforced_at.lt'
+ "rule_ref": "rule.ref",
+ "rule_id": "rule.id",
+ "execution": "execution_id",
+ "trigger_instance": "trigger_instance_id",
+ "enforced_at": "enforced_at",
+ "enforced_at_gt": "enforced_at.gt",
+ "enforced_at_lt": "enforced_at.lt",
}
-QUERY_OPTIONS = {
- 'sort': ['-enforced_at', 'rule.ref']
-}
+QUERY_OPTIONS = {"sort": ["-enforced_at", "rule.ref"]}
FILTER_TRANSFORM_FUNCTIONS = {
- 'enforced_at': lambda value: isotime.parse(value=value),
- 'enforced_at_gt': lambda value: isotime.parse(value=value),
- 'enforced_at_lt': lambda value: isotime.parse(value=value)
+ "enforced_at": lambda value: isotime.parse(value=value),
+ "enforced_at_gt": lambda value: isotime.parse(value=value),
+ "enforced_at_lt": lambda value: isotime.parse(value=value),
}
@@ -69,20 +66,32 @@ class RuleEnforcementController(ResourceController):
supported_filters = SUPPORTED_FILTERS
filter_transform_functions = FILTER_TRANSFORM_FUNCTIONS
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- return super(RuleEnforcementController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return super(RuleEnforcementController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, id, requester_user):
- return super(RuleEnforcementController,
- self)._get_one_by_id(id, requester_user=requester_user,
- permission_type=PermissionType.RULE_ENFORCEMENT_VIEW)
+ return super(RuleEnforcementController, self)._get_one_by_id(
+ id,
+ requester_user=requester_user,
+ permission_type=PermissionType.RULE_ENFORCEMENT_VIEW,
+ )
rule_enforcements_controller = RuleEnforcementController()
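The FILTER_TRANSFORM_FUNCTIONS above convert raw ?enforced_at_gt=... query-string values into datetimes before they reach the query layer. A functionally similar stdlib-only stand-in (st2 itself uses isotime.parse from st2common):

    from datetime import datetime, timezone

    def parse_isotime(value):
        # Rough stand-in for isotime.parse(); handles the trailing-"Z" form.
        return datetime.fromisoformat(value.replace("Z", "+00:00"))

    dt = parse_isotime("2021-03-02T10:00:00Z")
    assert dt == datetime(2021, 3, 2, 10, 0, tzinfo=timezone.utc)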
diff --git a/st2api/st2api/controllers/v1/rule_views.py b/st2api/st2api/controllers/v1/rule_views.py
index 70555149b7..39b4682c52 100644
--- a/st2api/st2api/controllers/v1/rule_views.py
+++ b/st2api/st2api/controllers/v1/rule_views.py
@@ -32,10 +32,12 @@
LOG = logging.getLogger(__name__)
-__all__ = ['RuleViewController']
+__all__ = ["RuleViewController"]
-class RuleViewController(BaseResourceIsolationControllerMixin, ContentPackResourceController):
+class RuleViewController(
+ BaseResourceIsolationControllerMixin, ContentPackResourceController
+):
"""
Add some extras to a Rule object to make it easier for the UI to render a rule. The additions
do not necessarily belong in the Rule itself but are still valuable augmentations.
@@ -74,64 +76,78 @@ class RuleViewController(BaseResourceIsolationControllerMixin, ContentPackResour
model = RuleViewAPI
access = Rule
- supported_filters = {
- 'name': 'name',
- 'pack': 'pack',
- 'user': 'context.user'
- }
-
- query_options = {
- 'sort': ['pack', 'name']
- }
-
- mandatory_include_fields_retrieve = ['pack', 'name', 'trigger']
-
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- rules = super(RuleViewController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ supported_filters = {"name": "name", "pack": "pack", "user": "context.user"}
+
+ query_options = {"sort": ["pack", "name"]}
+
+ mandatory_include_fields_retrieve = ["pack", "name", "trigger"]
+
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ rules = super(RuleViewController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
result = self._append_view_properties(rules.json)
rules.json = result
return rules
def get_one(self, ref_or_id, requester_user):
- from_model_kwargs = {'mask_secrets': True}
- rule = self._get_one(ref_or_id, permission_type=PermissionType.RULE_VIEW,
- requester_user=requester_user, from_model_kwargs=from_model_kwargs)
+ from_model_kwargs = {"mask_secrets": True}
+ rule = self._get_one(
+ ref_or_id,
+ permission_type=PermissionType.RULE_VIEW,
+ requester_user=requester_user,
+ from_model_kwargs=from_model_kwargs,
+ )
result = self._append_view_properties([rule.json])[0]
rule.json = result
return rule
def _append_view_properties(self, rules):
- action_by_refs, trigger_by_refs, trigger_type_by_refs = self._get_referenced_models(rules)
+ (
+ action_by_refs,
+ trigger_by_refs,
+ trigger_type_by_refs,
+ ) = self._get_referenced_models(rules)
for rule in rules:
- action_ref = rule.get('action', {}).get('ref', None)
- trigger_ref = rule.get('trigger', {}).get('ref', None)
- trigger_type_ref = rule.get('trigger', {}).get('type', None)
+ action_ref = rule.get("action", {}).get("ref", None)
+ trigger_ref = rule.get("trigger", {}).get("ref", None)
+ trigger_type_ref = rule.get("trigger", {}).get("type", None)
action_db = action_by_refs.get(action_ref, None)
- if 'action' in rule:
- rule['action']['description'] = action_db.description if action_db else ''
+ if "action" in rule:
+ rule["action"]["description"] = (
+ action_db.description if action_db else ""
+ )
- if 'trigger' in rule:
- rule['trigger']['description'] = ''
+ if "trigger" in rule:
+ rule["trigger"]["description"] = ""
trigger_db = trigger_by_refs.get(trigger_ref, None)
if trigger_db:
- rule['trigger']['description'] = trigger_db.description
+ rule["trigger"]["description"] = trigger_db.description
# If the description is not found in the trigger, get it from the TriggerType
- if 'trigger' in rule and not rule['trigger']['description']:
+ if "trigger" in rule and not rule["trigger"]["description"]:
trigger_type_db = trigger_type_by_refs.get(trigger_type_ref, None)
if trigger_type_db:
- rule['trigger']['description'] = trigger_type_db.description
+ rule["trigger"]["description"] = trigger_type_db.description
return rules
@@ -145,9 +161,9 @@ def _get_referenced_models(self, rules):
trigger_type_refs = set()
for rule in rules:
- action_ref = rule.get('action', {}).get('ref', None)
- trigger_ref = rule.get('trigger', {}).get('ref', None)
- trigger_type_ref = rule.get('trigger', {}).get('type', None)
+ action_ref = rule.get("action", {}).get("ref", None)
+ trigger_ref = rule.get("trigger", {}).get("ref", None)
+ trigger_type_ref = rule.get("trigger", {}).get("type", None)
if action_ref:
action_refs.add(action_ref)
@@ -164,27 +180,31 @@ def _get_referenced_models(self, rules):
# The functions that will return args that can be used to query.
def ref_query_args(ref):
- return {'ref': ref}
+ return {"ref": ref}
def name_pack_query_args(ref):
resource_ref = ResourceReference.from_string_reference(ref=ref)
- return {'name': resource_ref.name, 'pack': resource_ref.pack}
+ return {"name": resource_ref.name, "pack": resource_ref.pack}
- action_dbs = self._get_entities(model_persistence=Action,
- refs=action_refs,
- query_args=ref_query_args)
+ action_dbs = self._get_entities(
+ model_persistence=Action, refs=action_refs, query_args=ref_query_args
+ )
for action_db in action_dbs:
action_by_refs[action_db.ref] = action_db
- trigger_dbs = self._get_entities(model_persistence=Trigger,
- refs=trigger_refs,
- query_args=name_pack_query_args)
+ trigger_dbs = self._get_entities(
+ model_persistence=Trigger,
+ refs=trigger_refs,
+ query_args=name_pack_query_args,
+ )
for trigger_db in trigger_dbs:
trigger_by_refs[trigger_db.get_reference().ref] = trigger_db
- trigger_type_dbs = self._get_entities(model_persistence=TriggerType,
- refs=trigger_type_refs,
- query_args=name_pack_query_args)
+ trigger_type_dbs = self._get_entities(
+ model_persistence=TriggerType,
+ refs=trigger_type_refs,
+ query_args=name_pack_query_args,
+ )
for trigger_type_db in trigger_type_dbs:
trigger_type_by_refs[trigger_type_db.get_reference().ref] = trigger_type_db
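name_pack_query_args above leans on ResourceReference.from_string_reference to split a "pack.name" ref into its parts. A standalone equivalent (illustrative, not the st2common implementation) showing why only the first dot separates the pack:

    def name_pack_query_args(ref):
        # Names may themselves contain dots, so split on the first dot only.
        pack, name = ref.split(".", 1)
        return {"name": name, "pack": pack}

    assert name_pack_query_args("core.local") == {"pack": "core", "name": "local"}
    assert name_pack_query_args("core.st2.sensor.process_spawn") == {
        "pack": "core",
        "name": "st2.sensor.process_spawn",
    }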
diff --git a/st2api/st2api/controllers/v1/rules.py b/st2api/st2api/controllers/v1/rules.py
index 5904f9140e..89f9e63531 100644
--- a/st2api/st2api/controllers/v1/rules.py
+++ b/st2api/st2api/controllers/v1/rules.py
@@ -34,124 +34,149 @@
from st2common.router import exc
from st2common.router import abort
from st2common.router import Response
-from st2common.services.triggers import cleanup_trigger_db_for_rule, increment_trigger_ref_count
+from st2common.services.triggers import (
+ cleanup_trigger_db_for_rule,
+ increment_trigger_ref_count,
+)
http_client = six.moves.http_client
LOG = logging.getLogger(__name__)
-class RuleController(BaseRestControllerMixin, BaseResourceIsolationControllerMixin,
- ContentPackResourceController):
+class RuleController(
+ BaseRestControllerMixin,
+ BaseResourceIsolationControllerMixin,
+ ContentPackResourceController,
+):
"""
- Implements the RESTful web endpoint that handles
- the lifecycle of Rules in the system.
+ Implements the RESTful web endpoint that handles
+ the lifecycle of Rules in the system.
"""
+
views = RuleViewController()
model = RuleAPI
access = Rule
supported_filters = {
- 'name': 'name',
- 'pack': 'pack',
- 'action': 'action.ref',
- 'trigger': 'trigger',
- 'enabled': 'enabled',
- 'user': 'context.user'
+ "name": "name",
+ "pack": "pack",
+ "action": "action.ref",
+ "trigger": "trigger",
+ "enabled": "enabled",
+ "user": "context.user",
}
- filter_transform_functions = {
- 'enabled': transform_to_bool
- }
+ filter_transform_functions = {"enabled": transform_to_bool}
- query_options = {
- 'sort': ['pack', 'name']
- }
+ query_options = {"sort": ["pack", "name"]}
- mandatory_include_fields_retrieve = ['pack', 'name', 'trigger']
+ mandatory_include_fields_retrieve = ["pack", "name", "trigger"]
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, show_secrets=False, requester_user=None, **raw_filters):
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ show_secrets=False,
+ requester_user=None,
+ **raw_filters,
+ ):
from_model_kwargs = {
- 'ignore_missing_trigger': True,
- 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets)
+ "ignore_missing_trigger": True,
+ "mask_secrets": self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ ),
}
- return super(RuleController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- from_model_kwargs=from_model_kwargs,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ return super(RuleController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ from_model_kwargs=from_model_kwargs,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, ref_or_id, requester_user, show_secrets=False):
from_model_kwargs = {
- 'ignore_missing_trigger': True,
- 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets)
+ "ignore_missing_trigger": True,
+ "mask_secrets": self._get_mask_secrets(
+ requester_user, show_secrets=show_secrets
+ ),
}
- return super(RuleController, self)._get_one(ref_or_id, from_model_kwargs=from_model_kwargs,
- requester_user=requester_user,
- permission_type=PermissionType.RULE_VIEW)
+ return super(RuleController, self)._get_one(
+ ref_or_id,
+ from_model_kwargs=from_model_kwargs,
+ requester_user=requester_user,
+ permission_type=PermissionType.RULE_VIEW,
+ )
def post(self, rule, requester_user):
"""
- Create a new rule.
+ Create a new rule.
- Handles requests:
- POST /rules/
+ Handles requests:
+ POST /rules/
"""
rbac_utils = get_rbac_backend().get_utils_class()
permission_type = PermissionType.RULE_CREATE
- rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user,
- resource_api=rule,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_api_permission(
+ user_db=requester_user, resource_api=rule, permission_type=permission_type
+ )
if not requester_user:
requester_user = UserDB(cfg.CONF.system_user.user)
# Validate that the authenticated user is admin if user query param is provided
user = requester_user.name
- rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user,
- user=user)
+ rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(
+ user_db=requester_user, user=user
+ )
- if not hasattr(rule, 'context'):
+ if not hasattr(rule, "context"):
rule.context = dict()
- rule.context['user'] = user
+ rule.context["user"] = user
try:
rule_db = RuleAPI.to_model(rule)
- LOG.debug('/rules/ POST verified RuleAPI and formulated RuleDB=%s', rule_db)
+ LOG.debug("/rules/ POST verified RuleAPI and formulated RuleDB=%s", rule_db)
# Check referenced trigger and action permissions
# Note: This needs to happen after "to_model" call since to_model performs some
# validation (trigger exists, etc.)
- rbac_utils.assert_user_has_rule_trigger_and_action_permission(user_db=requester_user,
- rule_api=rule)
+ rbac_utils.assert_user_has_rule_trigger_and_action_permission(
+ user_db=requester_user, rule_api=rule
+ )
rule_db = Rule.add_or_update(rule_db)
# After the rule has been added modify the ref_count. This way a failure to add
# the rule due to violated constraints will have no impact on ref_count.
increment_trigger_ref_count(rule_api=rule)
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for rule data=%s.', rule)
+ LOG.exception("Validation failed for rule data=%s.", rule)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
except (ValueValidationException, jsonschema.ValidationError) as e:
- LOG.exception('Validation failed for rule data=%s.', rule)
+ LOG.exception("Validation failed for rule data=%s.", rule)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
except TriggerDoesNotExistException:
- msg = ('Trigger "%s" defined in the rule does not exist in system or it\'s missing '
- 'required "parameters" attribute' % (rule.trigger['type']))
+ msg = (
+ 'Trigger "%s" defined in the rule does not exist in system or it\'s missing '
+ 'required "parameters" attribute' % (rule.trigger["type"])
+ )
LOG.exception(msg)
abort(http_client.BAD_REQUEST, msg)
return
- extra = {'rule_db': rule_db}
- LOG.audit('Rule created. Rule.id=%s' % (rule_db.id), extra=extra)
+ extra = {"rule_db": rule_db}
+ LOG.audit("Rule created. Rule.id=%s" % (rule_db.id), extra=extra)
rule_api = RuleAPI.from_model(rule_db)
return Response(json=rule_api, status=exc.HTTPCreated.code)
@@ -161,27 +186,33 @@ def put(self, rule, rule_ref_or_id, requester_user):
rbac_utils = get_rbac_backend().get_utils_class()
permission_type = PermissionType.RULE_MODIFY
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=rule,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user, resource_db=rule, permission_type=permission_type
+ )
- LOG.debug('PUT /rules/ lookup with id=%s found object: %s', rule_ref_or_id, rule_db)
+ LOG.debug(
+ "PUT /rules/ lookup with id=%s found object: %s", rule_ref_or_id, rule_db
+ )
if not requester_user:
requester_user = UserDB(cfg.CONF.system_user.user)
# Validate that the authenticated user is admin if user query param is provided
user = requester_user.name
- rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user,
- user=user)
+ rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(
+ user_db=requester_user, user=user
+ )
- if not hasattr(rule, 'context'):
+ if not hasattr(rule, "context"):
rule.context = dict()
- rule.context['user'] = user
+ rule.context["user"] = user
try:
- if rule.id is not None and rule.id != '' and rule.id != rule_ref_or_id:
- LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.',
- rule.id, rule_ref_or_id)
+ if rule.id is not None and rule.id != "" and rule.id != rule_ref_or_id:
+ LOG.warning(
+ "Discarding mismatched id=%s found in payload and using uri_id=%s.",
+ rule.id,
+ rule_ref_or_id,
+ )
old_rule_db = rule_db
try:
@@ -193,8 +224,9 @@ def put(self, rule, rule_ref_or_id, requester_user):
# Check referenced trigger and action permissions
# Note: This needs to happen after "to_model" call since to_model performs some
# validation (trigger exists, etc.)
- rbac_utils.assert_user_has_rule_trigger_and_action_permission(user_db=requester_user,
- rule_api=rule)
+ rbac_utils.assert_user_has_rule_trigger_and_action_permission(
+ user_db=requester_user, rule_api=rule
+ )
rule_db.id = rule_ref_or_id
rule_db = Rule.add_or_update(rule_db)
@@ -202,48 +234,52 @@ def put(self, rule, rule_ref_or_id, requester_user):
# the rule due to violated constraints will have no impact on ref_count.
increment_trigger_ref_count(rule_api=rule)
except (ValueValidationException, jsonschema.ValidationError, ValueError) as e:
- LOG.exception('Validation failed for rule data=%s', rule)
+ LOG.exception("Validation failed for rule data=%s", rule)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
# use old_rule_db for cleanup.
cleanup_trigger_db_for_rule(old_rule_db)
- extra = {'old_rule_db': old_rule_db, 'new_rule_db': rule_db}
- LOG.audit('Rule updated. Rule.id=%s.' % (rule_db.id), extra=extra)
+ extra = {"old_rule_db": old_rule_db, "new_rule_db": rule_db}
+ LOG.audit("Rule updated. Rule.id=%s." % (rule_db.id), extra=extra)
rule_api = RuleAPI.from_model(rule_db)
return rule_api
def delete(self, rule_ref_or_id, requester_user):
"""
- Delete a rule.
+ Delete a rule.
- Handles requests:
- DELETE /rules/1
+ Handles requests:
+ DELETE /rules/1
"""
rule_db = self._get_by_ref_or_id(ref_or_id=rule_ref_or_id)
permission_type = PermissionType.RULE_DELETE
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=rule_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user, resource_db=rule_db, permission_type=permission_type
+ )
- LOG.debug('DELETE /rules/ lookup with id=%s found object: %s', rule_ref_or_id, rule_db)
+ LOG.debug(
+ "DELETE /rules/ lookup with id=%s found object: %s", rule_ref_or_id, rule_db
+ )
try:
Rule.delete(rule_db)
except Exception as e:
- LOG.exception('Database delete encountered exception during delete of id="%s".',
- rule_ref_or_id)
+ LOG.exception(
+ 'Database delete encountered exception during delete of id="%s".',
+ rule_ref_or_id,
+ )
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
return
# Use the fetched rule_db for cleanup.
cleanup_trigger_db_for_rule(rule_db)
- extra = {'rule_db': rule_db}
- LOG.audit('Rule deleted. Rule.id=%s.' % (rule_db.id), extra=extra)
+ extra = {"rule_db": rule_db}
+ LOG.audit("Rule deleted. Rule.id=%s." % (rule_db.id), extra=extra)
return Response(status=http_client.NO_CONTENT)
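
For context on the handler above: post() requires an authenticated request, runs RBAC checks against the referenced trigger and action, and stamps rule.context["user"] with the requester. A hedged client-side sketch, assuming a local st2api on port 9101, a valid auth token, and the stock examples pack; the exact payload schema may differ across st2 versions:

    import requests

    ST2_API = "http://127.0.0.1:9101/v1"  # assumed local deployment
    HEADERS = {"X-Auth-Token": "<token>"}  # placeholder token

    rule = {
        "name": "example_rule",
        "pack": "examples",
        "enabled": True,
        "trigger": {"type": "core.st2.webhook", "parameters": {"url": "example"}},
        "criteria": {},
        "action": {"ref": "core.local", "parameters": {"cmd": "date"}},
    }

    # On success the controller returns 201 Created with the stored rule,
    # including the context["user"] field filled in by post() above.
    resp = requests.post(ST2_API + "/rules", json=rule, headers=HEADERS)
    print(resp.status_code, resp.json())
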
diff --git a/st2api/st2api/controllers/v1/ruletypes.py b/st2api/st2api/controllers/v1/ruletypes.py
index dcf62069ef..267c192534 100644
--- a/st2api/st2api/controllers/v1/ruletypes.py
+++ b/st2api/st2api/controllers/v1/ruletypes.py
@@ -28,8 +28,8 @@
class RuleTypesController(object):
"""
- Implements the RESTful web endpoint that handles
- the lifecycle of a RuleType in the system.
+ Implements the RESTful web endpoint that handles
+ the lifecycle of a RuleType in the system.
"""
@staticmethod
@@ -46,15 +46,17 @@ def __get_by_name(name):
try:
return [RuleType.get_by_name(name)]
except ValueError as e:
- LOG.debug('Database lookup for name="%s" resulted in exception : %s.', name, e)
+ LOG.debug(
+ 'Database lookup for name="%s" resulted in exception : %s.', name, e
+ )
return []
def get_one(self, id):
"""
- List RuleType objects by id.
+ List RuleType objects by id.
- Handle:
- GET /ruletypes/1
+ Handle:
+ GET /ruletypes/1
"""
ruletype_db = RuleTypesController.__get_by_id(id)
ruletype_api = RuleTypeAPI.from_model(ruletype_db)
@@ -62,14 +64,15 @@ def get_one(self, id):
def get_all(self):
"""
- List all RuleType objects.
+ List all RuleType objects.
- Handles requests:
- GET /ruletypes/
+ Handles requests:
+ GET /ruletypes/
"""
ruletype_dbs = RuleType.get_all()
- ruletype_apis = [RuleTypeAPI.from_model(runnertype_db)
- for runnertype_db in ruletype_dbs]
+ ruletype_apis = [
+ RuleTypeAPI.from_model(runnertype_db) for runnertype_db in ruletype_dbs
+ ]
return ruletype_apis
diff --git a/st2api/st2api/controllers/v1/runnertypes.py b/st2api/st2api/controllers/v1/runnertypes.py
index b947babd94..1c84b4425c 100644
--- a/st2api/st2api/controllers/v1/runnertypes.py
+++ b/st2api/st2api/controllers/v1/runnertypes.py
@@ -31,34 +31,42 @@
class RunnerTypesController(ResourceController):
"""
- Implements the RESTful web endpoint that handles
- the lifecycle of an RunnerType in the system.
+ Implements the RESTful web endpoint that handles
+ the lifecycle of a RunnerType in the system.
"""
model = RunnerTypeAPI
access = RunnerType
- supported_filters = {
- 'name': 'name'
- }
-
- query_options = {
- 'sort': ['name']
- }
-
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- return super(RunnerTypesController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ supported_filters = {"name": "name"}
+
+ query_options = {"sort": ["name"]}
+
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return super(RunnerTypesController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, name_or_id, requester_user):
- return self._get_one_by_name_or_id(name_or_id,
- requester_user=requester_user,
- permission_type=PermissionType.RUNNER_VIEW)
+ return self._get_one_by_name_or_id(
+ name_or_id,
+ requester_user=requester_user,
+ permission_type=PermissionType.RUNNER_VIEW,
+ )
def put(self, runner_type_api, name_or_id, requester_user):
# Note: We only allow "enabled" attribute of the runner to be changed
@@ -66,28 +74,41 @@ def put(self, runner_type_api, name_or_id, requester_user):
permission_type = PermissionType.RUNNER_MODIFY
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=runner_type_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=runner_type_db,
+ permission_type=permission_type,
+ )
old_runner_type_db = runner_type_db
- LOG.debug('PUT /runnertypes/ lookup with id=%s found object: %s', name_or_id,
- runner_type_db)
+ LOG.debug(
+ "PUT /runnertypes/ lookup with id=%s found object: %s",
+ name_or_id,
+ runner_type_db,
+ )
try:
if runner_type_api.id and runner_type_api.id != name_or_id:
- LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.',
- runner_type_api.id, name_or_id)
+ LOG.warning(
+ "Discarding mismatched id=%s found in payload and using uri_id=%s.",
+ runner_type_api.id,
+ name_or_id,
+ )
runner_type_db.enabled = runner_type_api.enabled
runner_type_db = RunnerType.add_or_update(runner_type_db)
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for runner type data=%s', runner_type_api)
+ LOG.exception("Validation failed for runner type data=%s", runner_type_api)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- extra = {'old_runner_type_db': old_runner_type_db, 'new_runner_type_db': runner_type_db}
- LOG.audit('Runner Type updated. RunnerType.id=%s.' % (runner_type_db.id), extra=extra)
+ extra = {
+ "old_runner_type_db": old_runner_type_db,
+ "new_runner_type_db": runner_type_db,
+ }
+ LOG.audit(
+ "Runner Type updated. RunnerType.id=%s." % (runner_type_db.id), extra=extra
+ )
runner_type_api = RunnerTypeAPI.from_model(runner_type_db)
return runner_type_api
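
The PUT handlers for rules, runner types, and triggers all keep the guard reformatted in this file: the id from the URI wins, and a mismatched id found in the payload is only logged and discarded. A reduced sketch of that pattern, using a hypothetical resolve_resource_id() helper that does not exist in the codebase:

    import logging

    LOG = logging.getLogger(__name__)

    def resolve_resource_id(payload_id, uri_id):
        # The URI id is authoritative; a conflicting payload id is only logged.
        if payload_id and payload_id != uri_id:
            LOG.warning(
                "Discarding mismatched id=%s found in payload and using uri_id=%s.",
                payload_id,
                uri_id,
            )
        return uri_id

    assert resolve_resource_id("abc123", "def456") == "def456"
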
diff --git a/st2api/st2api/controllers/v1/sensors.py b/st2api/st2api/controllers/v1/sensors.py
index a3a71853d8..b62b56c92d 100644
--- a/st2api/st2api/controllers/v1/sensors.py
+++ b/st2api/st2api/controllers/v1/sensors.py
@@ -36,35 +36,41 @@ class SensorTypeController(resource.ContentPackResourceController):
model = SensorTypeAPI
access = SensorType
supported_filters = {
- 'name': 'name',
- 'pack': 'pack',
- 'enabled': 'enabled',
- 'trigger': 'trigger_types'
+ "name": "name",
+ "pack": "pack",
+ "enabled": "enabled",
+ "trigger": "trigger_types",
}
- filter_transform_functions = {
- 'enabled': transform_to_bool
- }
-
- options = {
- 'sort': ['pack', 'name']
- }
-
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- return super(SensorTypeController, self)._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ filter_transform_functions = {"enabled": transform_to_bool}
+
+ options = {"sort": ["pack", "name"]}
+
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return super(SensorTypeController, self)._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, ref_or_id, requester_user):
permission_type = PermissionType.SENSOR_VIEW
- return super(SensorTypeController, self)._get_one(ref_or_id,
- requester_user=requester_user,
- permission_type=permission_type)
+ return super(SensorTypeController, self)._get_one(
+ ref_or_id, requester_user=requester_user, permission_type=permission_type
+ )
def put(self, sensor_type, ref_or_id, requester_user):
# Note: Right now this function only supports updating of "enabled"
@@ -76,9 +82,11 @@ def put(self, sensor_type, ref_or_id, requester_user):
permission_type = PermissionType.SENSOR_MODIFY
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=sensor_type_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=sensor_type_db,
+ permission_type=permission_type,
+ )
sensor_type_id = sensor_type_db.id
@@ -88,23 +96,23 @@ def put(self, sensor_type, ref_or_id, requester_user):
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- if not getattr(sensor_type, 'pack', None):
+ if not getattr(sensor_type, "pack", None):
sensor_type.pack = sensor_type_db.pack
try:
old_sensor_type_db = sensor_type_db
sensor_type_db.id = sensor_type_id
- sensor_type_db.enabled = getattr(sensor_type, 'enabled', False)
+ sensor_type_db.enabled = getattr(sensor_type, "enabled", False)
sensor_type_db = SensorType.add_or_update(sensor_type_db)
except (ValidationError, ValueError) as e:
- LOG.exception('Unable to update sensor_type data=%s', sensor_type)
+ LOG.exception("Unable to update sensor_type data=%s", sensor_type)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
extra = {
- 'old_sensor_type_db': old_sensor_type_db,
- 'new_sensor_type_db': sensor_type_db
+ "old_sensor_type_db": old_sensor_type_db,
+ "new_sensor_type_db": sensor_type_db,
}
- LOG.audit('Sensor updated. Sensor.id=%s.' % (sensor_type_db.id), extra=extra)
+ LOG.audit("Sensor updated. Sensor.id=%s." % (sensor_type_db.id), extra=extra)
sensor_type_api = SensorTypeAPI.from_model(sensor_type_db)
return sensor_type_api
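
As the note in put() above says, only the "enabled" attribute of a sensor is writable through this endpoint, and it defaults to False when missing from the payload. A minimal sketch of that extraction, with a hypothetical payload object:

    class SensorTypePayload(object):
        pass

    def pick_enabled(sensor_type):
        # Only "enabled" is read from the payload; everything else is ignored.
        return getattr(sensor_type, "enabled", False)

    payload = SensorTypePayload()
    assert pick_enabled(payload) is False
    payload.enabled = True
    assert pick_enabled(payload) is True
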
diff --git a/st2api/st2api/controllers/v1/service_registry.py b/st2api/st2api/controllers/v1/service_registry.py
index d9ee9d542b..3a54563b25 100644
--- a/st2api/st2api/controllers/v1/service_registry.py
+++ b/st2api/st2api/controllers/v1/service_registry.py
@@ -22,8 +22,8 @@
from st2common.rbac.backends import get_rbac_backend
__all__ = [
- 'ServiceRegistryGroupsController',
- 'ServiceRegistryGroupMembersController',
+ "ServiceRegistryGroupsController",
+ "ServiceRegistryGroupMembersController",
]
@@ -35,11 +35,9 @@ def get_all(self, requester_user):
coordinator = coordination.get_coordinator()
group_ids = list(coordinator.get_groups().get())
- group_ids = [item.decode('utf-8') for item in group_ids]
+ group_ids = [item.decode("utf-8") for item in group_ids]
- result = {
- 'groups': group_ids
- }
+ result = {"groups": group_ids}
return result
@@ -51,26 +49,26 @@ def get_one(self, group_id, requester_user):
coordinator = coordination.get_coordinator()
if not isinstance(group_id, six.binary_type):
- group_id = group_id.encode('utf-8')
+ group_id = group_id.encode("utf-8")
try:
member_ids = list(coordinator.get_members(group_id).get())
except GroupNotCreated:
- msg = ('Group with ID "%s" not found.' % (group_id.decode('utf-8')))
+ msg = 'Group with ID "%s" not found.' % (group_id.decode("utf-8"))
raise StackStormDBObjectNotFoundError(msg)
- result = {
- 'members': []
- }
+ result = {"members": []}
for member_id in member_ids:
- capabilities = coordinator.get_member_capabilities(group_id, member_id).get()
+ capabilities = coordinator.get_member_capabilities(
+ group_id, member_id
+ ).get()
item = {
- 'group_id': group_id.decode('utf-8'),
- 'member_id': member_id.decode('utf-8'),
- 'capabilities': capabilities
+ "group_id": group_id.decode("utf-8"),
+ "member_id": member_id.decode("utf-8"),
+ "capabilities": capabilities,
}
- result['members'].append(item)
+ result["members"].append(item)
return result
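
The handlers above normalize ids because the coordination backend traffics in byte strings: text group ids are encoded on the way in and member ids decoded on the way out. A minimal sketch of that round-trip, with a hypothetical to_binary() helper:

    import six

    def to_binary(value):
        # Encode text ids for the coordinator; already-binary ids pass through.
        if not isinstance(value, six.binary_type):
            value = value.encode("utf-8")
        return value

    assert to_binary("st2actionrunner") == b"st2actionrunner"
    assert to_binary(b"st2actionrunner") == b"st2actionrunner"
    assert b"st2actionrunner".decode("utf-8") == "st2actionrunner"
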
diff --git a/st2api/st2api/controllers/v1/timers.py b/st2api/st2api/controllers/v1/timers.py
index c91b80fec1..541957a099 100644
--- a/st2api/st2api/controllers/v1/timers.py
+++ b/st2api/st2api/controllers/v1/timers.py
@@ -30,17 +30,13 @@
from st2common.services.triggerwatcher import TriggerWatcher
from st2common.router import abort
-__all__ = [
- 'TimersController',
- 'TimersHolder'
-]
+__all__ = ["TimersController", "TimersHolder"]
LOG = logging.getLogger(__name__)
class TimersHolder(object):
-
def __init__(self):
self._timers = {}
@@ -54,7 +50,7 @@ def get_all(self, timer_type=None):
timer_triggers = []
for _, timer in iteritems(self._timers):
- if not timer_type or timer['type'] == timer_type:
+ if not timer_type or timer["type"] == timer_type:
timer_triggers.append(timer)
return timer_triggers
@@ -65,35 +61,37 @@ class TimersController(resource.ContentPackResourceController):
access = Trigger
supported_filters = {
- 'type': 'type',
+ "type": "type",
}
- query_options = {
- 'sort': ['type']
- }
+ query_options = {"sort": ["type"]}
def __init__(self):
self._timers = TimersHolder()
self._trigger_types = TIMER_TRIGGER_TYPES.keys()
queue_suffix = self.__class__.__name__
- self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger,
- update_handler=self._handle_update_trigger,
- delete_handler=self._handle_delete_trigger,
- trigger_types=self._trigger_types,
- queue_suffix=queue_suffix,
- exclusive=True)
+ self._trigger_watcher = TriggerWatcher(
+ create_handler=self._handle_create_trigger,
+ update_handler=self._handle_update_trigger,
+ delete_handler=self._handle_delete_trigger,
+ trigger_types=self._trigger_types,
+ queue_suffix=queue_suffix,
+ exclusive=True,
+ )
self._trigger_watcher.start()
self._register_timer_trigger_types()
self._allowed_timer_types = TIMER_TRIGGER_TYPES.keys()
def get_all(self, timer_type=None):
if timer_type and timer_type not in self._allowed_timer_types:
- msg = 'Timer type %s not in supported types - %s.' % (timer_type,
- self._allowed_timer_types)
+ msg = "Timer type %s not in supported types - %s." % (
+ timer_type,
+ self._allowed_timer_types,
+ )
abort(http_client.BAD_REQUEST, msg)
t_all = self._timers.get_all(timer_type=timer_type)
- LOG.debug('Got timers: %s', t_all)
+ LOG.debug("Got timers: %s", t_all)
return t_all
def get_one(self, ref_or_id, requester_user):
@@ -108,9 +106,11 @@ def get_one(self, ref_or_id, requester_user):
resource_db = TimerDB(pack=trigger_db.pack, name=trigger_db.name)
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=resource_db,
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=resource_db,
+ permission_type=permission_type,
+ )
result = self.model.from_model(trigger_db)
return result
@@ -119,7 +119,7 @@ def add_trigger(self, trigger):
# Note: Permission checking for creating and deleting a timer is done during rule
# creation
ref = self._get_timer_ref(trigger)
- LOG.info('Started timer %s with parameters %s', ref, trigger['parameters'])
+ LOG.info("Started timer %s with parameters %s", ref, trigger["parameters"])
self._timers.add_trigger(ref, trigger)
def update_trigger(self, trigger):
@@ -130,14 +130,16 @@ def remove_trigger(self, trigger):
# creation
ref = self._get_timer_ref(trigger)
self._timers.remove_trigger(ref, trigger)
- LOG.info('Stopped timer %s with parameters %s.', ref, trigger['parameters'])
+ LOG.info("Stopped timer %s with parameters %s.", ref, trigger["parameters"])
def _register_timer_trigger_types(self):
for trigger_type in TIMER_TRIGGER_TYPES.values():
trigger_service.create_trigger_type_db(trigger_type)
def _get_timer_ref(self, trigger):
- return ResourceReference.to_string_reference(pack=trigger['pack'], name=trigger['name'])
+ return ResourceReference.to_string_reference(
+ pack=trigger["pack"], name=trigger["name"]
+ )
##############################################
# Event handler methods for the trigger events
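
Timers are keyed by a "pack.name" reference string built in _get_timer_ref() above. A simplified stand-in for ResourceReference.to_string_reference() (the real implementation also validates its inputs):

    def to_string_reference(pack, name):
        # "pack.name" is the canonical reference format used as the timer key.
        return "%s.%s" % (pack, name)

    trigger = {"pack": "core", "name": "st2.IntervalTimer", "parameters": {}}
    assert to_string_reference(trigger["pack"], trigger["name"]) == "core.st2.IntervalTimer"
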
diff --git a/st2api/st2api/controllers/v1/traces.py b/st2api/st2api/controllers/v1/traces.py
index 91c6e95e4f..4ab1d02aa5 100644
--- a/st2api/st2api/controllers/v1/traces.py
+++ b/st2api/st2api/controllers/v1/traces.py
@@ -18,47 +18,53 @@
from st2common.persistence.trace import Trace
from st2common.rbac.types import PermissionType
-__all__ = [
- 'TracesController'
-]
+__all__ = ["TracesController"]
class TracesController(ResourceController):
model = TraceAPI
access = Trace
supported_filters = {
- 'trace_tag': 'trace_tag',
- 'execution': 'action_executions.object_id',
- 'rule': 'rules.object_id',
- 'trigger_instance': 'trigger_instances.object_id',
+ "trace_tag": "trace_tag",
+ "execution": "action_executions.object_id",
+ "rule": "rules.object_id",
+ "trigger_instance": "trigger_instances.object_id",
}
- query_options = {
- 'sort': ['-start_timestamp', 'trace_tag']
- }
+ query_options = {"sort": ["-start_timestamp", "trace_tag"]}
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
# Use a custom sort order when filtering on a timestamp so we return a correct result as
# expected by the user
query_options = None
- if 'sort_desc' in raw_filters and raw_filters['sort_desc'] == 'True':
- query_options = {'sort': ['-start_timestamp', 'trace_tag']}
- elif 'sort_asc' in raw_filters and raw_filters['sort_asc'] == 'True':
- query_options = {'sort': ['+start_timestamp', 'trace_tag']}
- return self._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- query_options=query_options,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ if "sort_desc" in raw_filters and raw_filters["sort_desc"] == "True":
+ query_options = {"sort": ["-start_timestamp", "trace_tag"]}
+ elif "sort_asc" in raw_filters and raw_filters["sort_asc"] == "True":
+ query_options = {"sort": ["+start_timestamp", "trace_tag"]}
+ return self._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ query_options=query_options,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, id, requester_user):
- return self._get_one_by_id(id,
- requester_user=requester_user,
- permission_type=PermissionType.TRACE_VIEW)
+ return self._get_one_by_id(
+ id, requester_user=requester_user, permission_type=PermissionType.TRACE_VIEW
+ )
traces_controller = TracesController()
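
get_all() above lets explicit sort_desc/sort_asc query filters override the controller's default sort order. A reduced sketch of that selection logic, as a hypothetical pure function:

    def pick_sort_options(raw_filters):
        # Note the filters arrive as strings, hence the comparison with "True".
        if raw_filters.get("sort_desc") == "True":
            return {"sort": ["-start_timestamp", "trace_tag"]}
        if raw_filters.get("sort_asc") == "True":
            return {"sort": ["+start_timestamp", "trace_tag"]}
        return None  # fall back to the class-level query_options

    assert pick_sort_options({"sort_asc": "True"}) == {
        "sort": ["+start_timestamp", "trace_tag"]
    }
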
diff --git a/st2api/st2api/controllers/v1/triggers.py b/st2api/st2api/controllers/v1/triggers.py
index 12c3f133ec..cbdc5ca66b 100644
--- a/st2api/st2api/controllers/v1/triggers.py
+++ b/st2api/st2api/controllers/v1/triggers.py
@@ -39,55 +39,64 @@
class TriggerTypeController(resource.ContentPackResourceController):
"""
- Implements the RESTful web endpoint that handles
- the lifecycle of TriggerTypes in the system.
+ Implements the RESTful web endpoint that handles
+ the lifecycle of TriggerTypes in the system.
"""
+
model = TriggerTypeAPI
access = TriggerType
- supported_filters = {
- 'name': 'name',
- 'pack': 'pack'
- }
-
- options = {
- 'sort': ['pack', 'name']
- }
-
- query_options = {
- 'sort': ['ref']
- }
-
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
- return self._get_all(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ supported_filters = {"name": "name", "pack": "pack"}
+
+ options = {"sort": ["pack", "name"]}
+
+ query_options = {"sort": ["ref"]}
+
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
+ return self._get_all(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
def get_one(self, triggertype_ref_or_id):
- return self._get_one(triggertype_ref_or_id, permission_type=None, requester_user=None)
+ return self._get_one(
+ triggertype_ref_or_id, permission_type=None, requester_user=None
+ )
def post(self, triggertype):
"""
- Create a new triggertype.
+ Create a new triggertype.
- Handles requests:
- POST /triggertypes/
+ Handles requests:
+ POST /triggertypes/
"""
try:
triggertype_db = TriggerTypeAPI.to_model(triggertype)
triggertype_db = TriggerType.add_or_update(triggertype_db)
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for triggertype data=%s.', triggertype)
+ LOG.exception("Validation failed for triggertype data=%s.", triggertype)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
else:
- extra = {'triggertype_db': triggertype_db}
- LOG.audit('TriggerType created. TriggerType.id=%s' % (triggertype_db.id), extra=extra)
+ extra = {"triggertype_db": triggertype_db}
+ LOG.audit(
+ "TriggerType created. TriggerType.id=%s" % (triggertype_db.id),
+ extra=extra,
+ )
if not triggertype_db.parameters_schema:
TriggerTypeController._create_shadow_trigger(triggertype_db)
@@ -106,34 +115,44 @@ def put(self, triggertype, triggertype_ref_or_id):
try:
triggertype_db = TriggerTypeAPI.to_model(triggertype)
- if triggertype.id is not None and len(triggertype.id) > 0 and \
- triggertype.id != triggertype_id:
- LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.',
- triggertype.id, triggertype_id)
+ if (
+ triggertype.id is not None
+ and len(triggertype.id) > 0
+ and triggertype.id != triggertype_id
+ ):
+ LOG.warning(
+ "Discarding mismatched id=%s found in payload and using uri_id=%s.",
+ triggertype.id,
+ triggertype_id,
+ )
triggertype_db.id = triggertype_id
old_triggertype_db = triggertype_db
triggertype_db = TriggerType.add_or_update(triggertype_db)
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for triggertype data=%s', triggertype)
+ LOG.exception("Validation failed for triggertype data=%s", triggertype)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- extra = {'old_triggertype_db': old_triggertype_db, 'new_triggertype_db': triggertype_db}
- LOG.audit('TriggerType updated. TriggerType.id=%s' % (triggertype_db.id), extra=extra)
+ extra = {
+ "old_triggertype_db": old_triggertype_db,
+ "new_triggertype_db": triggertype_db,
+ }
+ LOG.audit(
+ "TriggerType updated. TriggerType.id=%s" % (triggertype_db.id), extra=extra
+ )
triggertype_api = TriggerTypeAPI.from_model(triggertype_db)
return triggertype_api
def delete(self, triggertype_ref_or_id):
"""
- Delete a triggertype.
+ Delete a triggertype.
- Handles requests:
- DELETE /triggertypes/1
- DELETE /triggertypes/pack.name
+ Handles requests:
+ DELETE /triggertypes/1
+ DELETE /triggertypes/pack.name
"""
- LOG.info('DELETE /triggertypes/ with ref_or_id=%s',
- triggertype_ref_or_id)
+ LOG.info("DELETE /triggertypes/ with ref_or_id=%s", triggertype_ref_or_id)
triggertype_db = self._get_by_ref_or_id(ref_or_id=triggertype_ref_or_id)
triggertype_id = triggertype_db.id
@@ -146,13 +165,18 @@ def delete(self, triggertype_ref_or_id):
try:
TriggerType.delete(triggertype_db)
except Exception as e:
- LOG.exception('Database delete encountered exception during delete of id="%s". ',
- triggertype_id)
+ LOG.exception(
+ 'Database delete encountered exception during delete of id="%s". ',
+ triggertype_id,
+ )
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
return
else:
- extra = {'triggertype': triggertype_db}
- LOG.audit('TriggerType deleted. TriggerType.id=%s' % (triggertype_db.id), extra=extra)
+ extra = {"triggertype": triggertype_db}
+ LOG.audit(
+ "TriggerType deleted. TriggerType.id=%s" % (triggertype_db.id),
+ extra=extra,
+ )
if not triggertype_db.parameters_schema:
TriggerTypeController._delete_shadow_trigger(triggertype_db)
@@ -162,55 +186,70 @@ def delete(self, triggertype_ref_or_id):
def _create_shadow_trigger(triggertype_db):
try:
trigger_type_ref = triggertype_db.get_reference().ref
- trigger = {'name': triggertype_db.name,
- 'pack': triggertype_db.pack,
- 'type': trigger_type_ref,
- 'parameters': {}}
+ trigger = {
+ "name": triggertype_db.name,
+ "pack": triggertype_db.pack,
+ "type": trigger_type_ref,
+ "parameters": {},
+ }
trigger_db = TriggerService.create_or_update_trigger_db(trigger)
- extra = {'trigger_db': trigger_db}
- LOG.audit('Trigger created for parameter-less TriggerType. Trigger.id=%s' %
- (trigger_db.id), extra=extra)
+ extra = {"trigger_db": trigger_db}
+ LOG.audit(
+ "Trigger created for parameter-less TriggerType. Trigger.id=%s"
+ % (trigger_db.id),
+ extra=extra,
+ )
except (ValidationError, ValueError):
- LOG.exception('Validation failed for trigger data=%s.', trigger)
+ LOG.exception("Validation failed for trigger data=%s.", trigger)
# Not aborting as this is convenience.
return
except StackStormDBObjectConflictError as e:
- LOG.warn('Trigger creation of "%s" failed with uniqueness conflict. Exception: %s',
- trigger, six.text_type(e))
+ LOG.warn(
+ 'Trigger creation of "%s" failed with uniqueness conflict. Exception: %s',
+ trigger,
+ six.text_type(e),
+ )
# Not aborting as this is convenience.
return
@staticmethod
def _delete_shadow_trigger(triggertype_db):
# Shadow Triggers have the same name as the shadowed TriggerType.
- triggertype_ref = ResourceReference(name=triggertype_db.name, pack=triggertype_db.pack)
+ triggertype_ref = ResourceReference(
+ name=triggertype_db.name, pack=triggertype_db.pack
+ )
trigger_db = TriggerService.get_trigger_db_by_ref(triggertype_ref.ref)
if not trigger_db:
- LOG.warn('No shadow trigger found for %s. Will skip delete.', triggertype_db)
+ LOG.warn(
+ "No shadow trigger found for %s. Will skip delete.", triggertype_db
+ )
return
try:
Trigger.delete(trigger_db)
except Exception:
- LOG.exception('Database delete encountered exception during delete of id="%s". ',
- trigger_db.id)
+ LOG.exception(
+ 'Database delete encountered exception during delete of id="%s". ',
+ trigger_db.id,
+ )
- extra = {'trigger_db': trigger_db}
- LOG.audit('Trigger deleted. Trigger.id=%s' % (trigger_db.id), extra=extra)
+ extra = {"trigger_db": trigger_db}
+ LOG.audit("Trigger deleted. Trigger.id=%s" % (trigger_db.id), extra=extra)
class TriggerController(object):
"""
- Implements the RESTful web endpoint that handles
- the lifecycle of Triggers in the system.
+ Implements the RESTful web endpoint that handles
+ the lifecycle of Triggers in the system.
"""
+
def get_one(self, trigger_id):
"""
- List trigger by id.
+ List trigger by id.
- Handle:
- GET /triggers/1
+ Handle:
+ GET /triggers/1
"""
trigger_db = TriggerController.__get_by_id(trigger_id)
trigger_api = TriggerAPI.from_model(trigger_db)
@@ -218,10 +257,10 @@ def get_one(self, trigger_id):
def get_all(self, requester_user=None):
"""
- List all triggers.
+ List all triggers.
- Handles requests:
- GET /triggers/
+ Handles requests:
+ GET /triggers/
"""
trigger_dbs = Trigger.get_all()
trigger_apis = [TriggerAPI.from_model(trigger_db) for trigger_db in trigger_dbs]
@@ -229,20 +268,20 @@ def get_all(self, requester_user=None):
def post(self, trigger):
"""
- Create a new trigger.
+ Create a new trigger.
- Handles requests:
- POST /triggers/
+ Handles requests:
+ POST /triggers/
"""
try:
trigger_db = TriggerService.create_trigger_db(trigger)
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for trigger data=%s.', trigger)
+ LOG.exception("Validation failed for trigger data=%s.", trigger)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- extra = {'trigger': trigger_db}
- LOG.audit('Trigger created. Trigger.id=%s' % (trigger_db.id), extra=extra)
+ extra = {"trigger": trigger_db}
+ LOG.audit("Trigger created. Trigger.id=%s" % (trigger_db.id), extra=extra)
trigger_api = TriggerAPI.from_model(trigger_db)
return Response(json=trigger_api, status=http_client.CREATED)
@@ -250,42 +289,47 @@ def post(self, trigger):
def put(self, trigger, trigger_id):
trigger_db = TriggerController.__get_by_id(trigger_id)
try:
- if trigger.id is not None and trigger.id != '' and trigger.id != trigger_id:
- LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.',
- trigger.id, trigger_id)
+ if trigger.id is not None and trigger.id != "" and trigger.id != trigger_id:
+ LOG.warning(
+ "Discarding mismatched id=%s found in payload and using uri_id=%s.",
+ trigger.id,
+ trigger_id,
+ )
trigger_db = TriggerAPI.to_model(trigger)
trigger_db.id = trigger_id
trigger_db = Trigger.add_or_update(trigger_db)
except (ValidationError, ValueError) as e:
- LOG.exception('Validation failed for trigger data=%s', trigger)
+ LOG.exception("Validation failed for trigger data=%s", trigger)
abort(http_client.BAD_REQUEST, six.text_type(e))
return
- extra = {'old_trigger_db': trigger, 'new_trigger_db': trigger_db}
- LOG.audit('Trigger updated. Trigger.id=%s' % (trigger.id), extra=extra)
+ extra = {"old_trigger_db": trigger, "new_trigger_db": trigger_db}
+ LOG.audit("Trigger updated. Trigger.id=%s" % (trigger.id), extra=extra)
trigger_api = TriggerAPI.from_model(trigger_db)
return trigger_api
def delete(self, trigger_id):
"""
- Delete a trigger.
+ Delete a trigger.
- Handles requests:
- DELETE /triggers/1
+ Handles requests:
+ DELETE /triggers/1
"""
- LOG.info('DELETE /triggers/ with id=%s', trigger_id)
+ LOG.info("DELETE /triggers/ with id=%s", trigger_id)
trigger_db = TriggerController.__get_by_id(trigger_id)
try:
Trigger.delete(trigger_db)
except Exception as e:
- LOG.exception('Database delete encountered exception during delete of id="%s". ',
- trigger_id)
+ LOG.exception(
+ 'Database delete encountered exception during delete of id="%s". ',
+ trigger_id,
+ )
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
return
- extra = {'trigger_db': trigger_db}
- LOG.audit('Trigger deleted. Trigger.id=%s' % (trigger_db.id), extra=extra)
+ extra = {"trigger_db": trigger_db}
+ LOG.audit("Trigger deleted. Trigger.id=%s" % (trigger_db.id), extra=extra)
return Response(status=http_client.NO_CONTENT)
@@ -294,7 +338,9 @@ def __get_by_id(trigger_id):
try:
return Trigger.get_by_id(trigger_id)
except (ValueError, ValidationError):
- LOG.exception('Database lookup for id="%s" resulted in exception.', trigger_id)
+ LOG.exception(
+ 'Database lookup for id="%s" resulted in exception.', trigger_id
+ )
abort(http_client.NOT_FOUND)
@staticmethod
@@ -302,7 +348,11 @@ def __get_by_name(trigger_name):
try:
return [Trigger.get_by_name(trigger_name)]
except ValueError as e:
- LOG.debug('Database lookup for name="%s" resulted in exception : %s.', trigger_name, e)
+ LOG.debug(
+ 'Database lookup for name="%s" resulted in exception : %s.',
+ trigger_name,
+ e,
+ )
return []
@@ -311,7 +361,9 @@ class TriggerInstanceControllerMixin(object):
access = TriggerInstance
-class TriggerInstanceResendController(TriggerInstanceControllerMixin, resource.ResourceController):
+class TriggerInstanceResendController(
+ TriggerInstanceControllerMixin, resource.ResourceController
+):
supported_filters = {}
def __init__(self, *args, **kwargs):
@@ -338,106 +390,130 @@ def post(self, trigger_instance_id):
POST /triggerinstance/<trigger_instance_id>/re_send
"""
# Note: We only really need parameters here
- existing_trigger_instance = self._get_one_by_id(id=trigger_instance_id,
- permission_type=None,
- requester_user=None)
+ existing_trigger_instance = self._get_one_by_id(
+ id=trigger_instance_id, permission_type=None, requester_user=None
+ )
new_payload = copy.deepcopy(existing_trigger_instance.payload)
- new_payload['__context'] = {
- 'original_id': trigger_instance_id
- }
+ new_payload["__context"] = {"original_id": trigger_instance_id}
try:
- self.trigger_dispatcher.dispatch(existing_trigger_instance.trigger,
- new_payload)
+ self.trigger_dispatcher.dispatch(
+ existing_trigger_instance.trigger, new_payload
+ )
return {
- 'message': 'Trigger instance %s succesfully re-sent.' % trigger_instance_id,
- 'payload': new_payload
+ "message": "Trigger instance %s succesfully re-sent."
+ % trigger_instance_id,
+ "payload": new_payload,
}
except Exception as e:
abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e))
-class TriggerInstanceController(TriggerInstanceControllerMixin, resource.ResourceController):
+class TriggerInstanceController(
+ TriggerInstanceControllerMixin, resource.ResourceController
+):
"""
- Implements the RESTful web endpoint that handles
- the lifecycle of TriggerInstances in the system.
+ Implements the RESTful web endpoint that handles
+ the lifecycle of TriggerInstances in the system.
"""
+
supported_filters = {
- 'timestamp_gt': 'occurrence_time.gt',
- 'timestamp_lt': 'occurrence_time.lt',
- 'status': 'status',
- 'trigger': 'trigger.in'
+ "timestamp_gt": "occurrence_time.gt",
+ "timestamp_lt": "occurrence_time.lt",
+ "status": "status",
+ "trigger": "trigger.in",
}
filter_transform_functions = {
- 'timestamp_gt': lambda value: isotime.parse(value=value),
- 'timestamp_lt': lambda value: isotime.parse(value=value)
+ "timestamp_gt": lambda value: isotime.parse(value=value),
+ "timestamp_lt": lambda value: isotime.parse(value=value),
}
- query_options = {
- 'sort': ['-occurrence_time', 'trigger']
- }
+ query_options = {"sort": ["-occurrence_time", "trigger"]}
def __init__(self):
super(TriggerInstanceController, self).__init__()
def get_one(self, instance_id):
"""
- List triggerinstance by instance_id.
+ List triggerinstance by instance_id.
- Handle:
- GET /triggerinstances/1
+ Handle:
+ GET /triggerinstances/1
"""
- return self._get_one_by_id(instance_id, permission_type=None, requester_user=None)
-
- def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0,
- limit=None, requester_user=None, **raw_filters):
+ return self._get_one_by_id(
+ instance_id, permission_type=None, requester_user=None
+ )
+
+ def get_all(
+ self,
+ exclude_attributes=None,
+ include_attributes=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ requester_user=None,
+ **raw_filters,
+ ):
"""
- List all triggerinstances.
+ List all triggerinstances.
- Handles requests:
- GET /triggerinstances/
+ Handles requests:
+ GET /triggerinstances/
"""
# If trigger_type filter is provided, filter based on the TriggerType via Trigger object
- trigger_type_ref = raw_filters.get('trigger_type', None)
+ trigger_type_ref = raw_filters.get("trigger_type", None)
if trigger_type_ref:
# 1. Retrieve TriggerType object id which match this trigger_type ref
- trigger_dbs = Trigger.query(type=trigger_type_ref,
- only_fields=['ref', 'name', 'pack', 'type'])
+ trigger_dbs = Trigger.query(
+ type=trigger_type_ref, only_fields=["ref", "name", "pack", "type"]
+ )
trigger_refs = [trigger_db.ref for trigger_db in trigger_dbs]
- raw_filters['trigger'] = trigger_refs
+ raw_filters["trigger"] = trigger_refs
- if trigger_type_ref and len(raw_filters.get('trigger', [])) == 0:
+ if trigger_type_ref and len(raw_filters.get("trigger", [])) == 0:
# Empty list means trigger_type_ref filter was provided, but we matched no Triggers so
# we should return an empty result
return []
- trigger_instances = self._get_trigger_instances(exclude_fields=exclude_attributes,
- include_fields=include_attributes,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ trigger_instances = self._get_trigger_instances(
+ exclude_fields=exclude_attributes,
+ include_fields=include_attributes,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
return trigger_instances
- def _get_trigger_instances(self, exclude_fields=None, include_fields=None, sort=None, offset=0,
- limit=None, raw_filters=None, requester_user=None):
+ def _get_trigger_instances(
+ self,
+ exclude_fields=None,
+ include_fields=None,
+ sort=None,
+ offset=0,
+ limit=None,
+ raw_filters=None,
+ requester_user=None,
+ ):
if limit is None:
limit = self.default_limit
limit = int(limit)
- LOG.debug('Retrieving all trigger instances with filters=%s', raw_filters)
- return super(TriggerInstanceController, self)._get_all(exclude_fields=exclude_fields,
- include_fields=include_fields,
- sort=sort,
- offset=offset,
- limit=limit,
- raw_filters=raw_filters,
- requester_user=requester_user)
+ LOG.debug("Retrieving all trigger instances with filters=%s", raw_filters)
+ return super(TriggerInstanceController, self)._get_all(
+ exclude_fields=exclude_fields,
+ include_fields=include_fields,
+ sort=sort,
+ offset=offset,
+ limit=limit,
+ raw_filters=raw_filters,
+ requester_user=requester_user,
+ )
triggertype_controller = TriggerTypeController()
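
TriggerInstanceController.get_all() above resolves a trigger_type filter in two steps: first map the TriggerType ref to the refs of matching Triggers, then filter instances on those refs (returning early if nothing matched). A reduced sketch over plain dicts standing in for TriggerDB objects:

    def apply_trigger_type_filter(raw_filters, trigger_dbs):
        # trigger_dbs: list of {"ref": ..., "type": ...} dicts (illustrative).
        trigger_type_ref = raw_filters.get("trigger_type", None)
        if trigger_type_ref:
            refs = [t["ref"] for t in trigger_dbs if t["type"] == trigger_type_ref]
            raw_filters["trigger"] = refs
        return raw_filters

    filters = apply_trigger_type_filter(
        {"trigger_type": "core.st2.webhook"},
        [{"ref": "core.example", "type": "core.st2.webhook"}],
    )
    assert filters["trigger"] == ["core.example"]
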
diff --git a/st2api/st2api/controllers/v1/user.py b/st2api/st2api/controllers/v1/user.py
index e3de60b978..0593a13384 100644
--- a/st2api/st2api/controllers/v1/user.py
+++ b/st2api/st2api/controllers/v1/user.py
@@ -17,9 +17,7 @@
from st2common.rbac.backends import get_rbac_backend
-__all__ = [
- 'UserController'
-]
+__all__ = ["UserController"]
class UserController(object):
@@ -43,21 +41,21 @@ def get(self, requester_user, auth_info):
roles = []
data = {
- 'username': requester_user.name,
- 'authentication': {
- 'method': auth_info['method'],
- 'location': auth_info['location']
+ "username": requester_user.name,
+ "authentication": {
+ "method": auth_info["method"],
+ "location": auth_info["location"],
+ },
+ "rbac": {
+ "enabled": cfg.CONF.rbac.enable,
+ "roles": roles,
+ "is_admin": rbac_utils.user_is_admin(user_db=requester_user),
},
- 'rbac': {
- 'enabled': cfg.CONF.rbac.enable,
- 'roles': roles,
- 'is_admin': rbac_utils.user_is_admin(user_db=requester_user)
- }
}
- if auth_info.get('token_expire', None):
- token_expire = auth_info['token_expire'].strftime('%Y-%m-%dT%H:%M:%SZ')
- data['authentication']['token_expire'] = token_expire
+ if auth_info.get("token_expire", None):
+ token_expire = auth_info["token_expire"].strftime("%Y-%m-%dT%H:%M:%SZ")
+ data["authentication"]["token_expire"] = token_expire
return data
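
The token_expire timestamp above is serialized with a fixed strftime format. A quick check of what that format yields:

    from datetime import datetime

    token_expire = datetime(2021, 3, 2, 12, 0, 0).strftime("%Y-%m-%dT%H:%M:%SZ")
    assert token_expire == "2021-03-02T12:00:00Z"
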
diff --git a/st2api/st2api/controllers/v1/webhooks.py b/st2api/st2api/controllers/v1/webhooks.py
index 35af0c8337..1985bb4dad 100644
--- a/st2api/st2api/controllers/v1/webhooks.py
+++ b/st2api/st2api/controllers/v1/webhooks.py
@@ -19,7 +19,10 @@
from six.moves import http_client
from st2common import log as logging
-from st2common.constants.auth import HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME
+from st2common.constants.auth import (
+ HEADER_API_KEY_ATTRIBUTE_NAME,
+ HEADER_ATTRIBUTE_NAME,
+)
from st2common.constants.triggers import WEBHOOK_TRIGGER_TYPES
from st2common.models.api.trace import TraceContext
from st2common.models.api.trigger import TriggerAPI
@@ -35,13 +38,14 @@
LOG = logging.getLogger(__name__)
-TRACE_TAG_HEADER = 'St2-Trace-Tag'
+TRACE_TAG_HEADER = "St2-Trace-Tag"
class HooksHolder(object):
"""
Maintains a hook to TriggerDB mapping.
"""
+
def __init__(self):
self._triggers_by_hook = {}
@@ -58,7 +62,7 @@ def remove_hook(self, hook, trigger):
return False
remove_index = -1
for idx, item in enumerate(self._triggers_by_hook[hook]):
- if item['id'] == trigger['id']:
+ if item["id"] == trigger["id"]:
remove_index = idx
break
if remove_index < 0:
@@ -81,17 +85,19 @@ def get_all(self):
class WebhooksController(object):
def __init__(self, *args, **kwargs):
self._hooks = HooksHolder()
- self._base_url = '/webhooks/'
+ self._base_url = "/webhooks/"
self._trigger_types = list(WEBHOOK_TRIGGER_TYPES.keys())
self._trigger_dispatcher_service = TriggerDispatcherService(LOG)
queue_suffix = self.__class__.__name__
- self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger,
- update_handler=self._handle_update_trigger,
- delete_handler=self._handle_delete_trigger,
- trigger_types=self._trigger_types,
- queue_suffix=queue_suffix,
- exclusive=True)
+ self._trigger_watcher = TriggerWatcher(
+ create_handler=self._handle_create_trigger,
+ update_handler=self._handle_update_trigger,
+ delete_handler=self._handle_delete_trigger,
+ trigger_types=self._trigger_types,
+ queue_suffix=queue_suffix,
+ exclusive=True,
+ )
self._trigger_watcher.start()
self._register_webhook_trigger_types()
@@ -108,9 +114,11 @@ def get_one(self, url, requester_user):
permission_type = PermissionType.WEBHOOK_VIEW
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=WebhookDB(name=url),
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=WebhookDB(name=url),
+ permission_type=permission_type,
+ )
# For demonstration purposes, return the 1st
return triggers[0]
@@ -120,55 +128,62 @@ def post(self, hook, webhook_body_api, headers, requester_user):
permission_type = PermissionType.WEBHOOK_SEND
rbac_utils = get_rbac_backend().get_utils_class()
- rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user,
- resource_db=WebhookDB(name=hook),
- permission_type=permission_type)
+ rbac_utils.assert_user_has_resource_db_permission(
+ user_db=requester_user,
+ resource_db=WebhookDB(name=hook),
+ permission_type=permission_type,
+ )
headers = self._get_headers_as_dict(headers)
headers = self._filter_authentication_headers(headers)
# If the webhook contains a trace-tag use that, else create a unique trace-tag.
- trace_context = self._create_trace_context(trace_tag=headers.pop(TRACE_TAG_HEADER, None),
- hook=hook)
+ trace_context = self._create_trace_context(
+ trace_tag=headers.pop(TRACE_TAG_HEADER, None), hook=hook
+ )
- if hook == 'st2' or hook == 'st2/':
+ if hook == "st2" or hook == "st2/":
# When using st2 or system webhook, body needs to always be a dict
if not isinstance(body, dict):
type_string = get_json_type_for_python_value(body)
- msg = ('Webhook body needs to be an object, got: %s' % (type_string))
+ msg = "Webhook body needs to be an object, got: %s" % (type_string)
raise ValueError(msg)
- trigger = body.get('trigger', None)
- payload = body.get('payload', None)
+ trigger = body.get("trigger", None)
+ payload = body.get("payload", None)
if not trigger:
- msg = 'Trigger not specified.'
+ msg = "Trigger not specified."
return abort(http_client.BAD_REQUEST, msg)
- self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger,
- payload=payload,
- trace_context=trace_context,
- throw_on_validation_error=True)
+ self._trigger_dispatcher_service.dispatch_with_context(
+ trigger=trigger,
+ payload=payload,
+ trace_context=trace_context,
+ throw_on_validation_error=True,
+ )
else:
if not self._is_valid_hook(hook):
- self._log_request('Invalid hook.', headers, body)
- msg = 'Webhook %s not registered with st2' % hook
+ self._log_request("Invalid hook.", headers, body)
+ msg = "Webhook %s not registered with st2" % hook
return abort(http_client.NOT_FOUND, msg)
triggers = self._hooks.get_triggers_for_hook(hook)
payload = {}
- payload['headers'] = headers
- payload['body'] = body
+ payload["headers"] = headers
+ payload["body"] = body
# Dispatch trigger instance for each of the trigger found
for trigger_dict in triggers:
# TODO: Instead of dispatching the whole dict we should just
# dispatch TriggerDB.ref or similar
- self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger_dict,
- payload=payload,
- trace_context=trace_context,
- throw_on_validation_error=True)
+ self._trigger_dispatcher_service.dispatch_with_context(
+ trigger=trigger_dict,
+ payload=payload,
+ trace_context=trace_context,
+ throw_on_validation_error=True,
+ )
return Response(json=body, status=http_client.ACCEPTED)
@@ -183,7 +198,7 @@ def _register_webhook_trigger_types(self):
def _create_trace_context(self, trace_tag, hook):
# if no trace_tag then create a unique one
if not trace_tag:
- trace_tag = 'webhook-%s-%s' % (hook, uuid.uuid4().hex)
+ trace_tag = "webhook-%s-%s" % (hook, uuid.uuid4().hex)
return TraceContext(trace_tag=trace_tag)
def add_trigger(self, trigger):
@@ -191,7 +206,7 @@ def add_trigger(self, trigger):
# Note: Permission checking for creating and deleting a webhook is done during rule
# creation
url = self._get_normalized_url(trigger)
- LOG.info('Listening to endpoint: %s', urlparse.urljoin(self._base_url, url))
+ LOG.info("Listening to endpoint: %s", urlparse.urljoin(self._base_url, url))
self._hooks.add_hook(url, trigger)
def update_trigger(self, trigger):
@@ -204,14 +219,16 @@ def remove_trigger(self, trigger):
removed = self._hooks.remove_hook(url, trigger)
if removed:
- LOG.info('Stop listening to endpoint: %s', urlparse.urljoin(self._base_url, url))
+ LOG.info(
+ "Stop listening to endpoint: %s", urlparse.urljoin(self._base_url, url)
+ )
def _get_normalized_url(self, trigger):
"""
remove the trailing and leading / so that the hook url and those coming
from trigger parameters end up being the same.
"""
- return trigger['parameters']['url'].strip('/')
+ return trigger["parameters"]["url"].strip("/")
def _get_headers_as_dict(self, headers):
headers_dict = {}
@@ -220,13 +237,13 @@ def _get_headers_as_dict(self, headers):
return headers_dict
def _filter_authentication_headers(self, headers):
- auth_headers = [HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME, 'Cookie']
+ auth_headers = [HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME, "Cookie"]
return {key: value for key, value in headers.items() if key not in auth_headers}
def _log_request(self, msg, headers, body, log_method=LOG.debug):
headers = self._get_headers_as_dict(headers)
body = str(body)
- log_method('%s\n\trequest.header: %s.\n\trequest.body: %s.', msg, headers, body)
+ log_method("%s\n\trequest.header: %s.\n\trequest.body: %s.", msg, headers, body)
##############################################
# Event handler methods for the trigger events
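
Two of the webhook helpers above are easy to exercise in isolation: authentication headers are stripped before the body is dispatched as a trigger payload, and hook URLs are normalized by trimming slashes. A sketch with the header names hard-coded for illustration (the real values come from st2common.constants.auth):

    AUTH_HEADERS = ["X-Auth-Token", "St2-Api-Key", "Cookie"]  # assumed values

    def filter_authentication_headers(headers):
        return {k: v for k, v in headers.items() if k not in AUTH_HEADERS}

    def normalized_url(trigger):
        return trigger["parameters"]["url"].strip("/")

    headers = {"X-Auth-Token": "secret", "Content-Type": "application/json"}
    assert filter_authentication_headers(headers) == {"Content-Type": "application/json"}
    assert normalized_url({"parameters": {"url": "/my/hook/"}}) == "my/hook"
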
diff --git a/st2api/st2api/controllers/v1/workflow_inspection.py b/st2api/st2api/controllers/v1/workflow_inspection.py
index 1e5ee53d85..04d60dd2b1 100644
--- a/st2api/st2api/controllers/v1/workflow_inspection.py
+++ b/st2api/st2api/controllers/v1/workflow_inspection.py
@@ -30,13 +30,12 @@
class WorkflowInspectionController(object):
-
def mock_st2_ctx(self):
st2_ctx = {
- 'st2': {
- 'api_url': api_utils.get_full_public_api_url(),
- 'action_execution_id': uuid.uuid4().hex,
- 'user': cfg.CONF.system_user.user
+ "st2": {
+ "api_url": api_utils.get_full_public_api_url(),
+ "action_execution_id": uuid.uuid4().hex,
+ "user": cfg.CONF.system_user.user,
}
}
@@ -44,7 +43,7 @@ def mock_st2_ctx(self):
def post(self, wf_def):
# Load workflow definition into workflow spec model.
- spec_module = specs_loader.get_spec_module('native')
+ spec_module = specs_loader.get_spec_module("native")
wf_spec = spec_module.instantiate(wf_def)
# Mock the st2 context that is typically passed to the workflow engine.
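
mock_st2_ctx() above fabricates the context normally injected by the workflow engine, so a definition can be inspected without a live execution. A reduced version with illustrative defaults (the real api_url and user come from api_utils.get_full_public_api_url() and cfg.CONF.system_user.user):

    import uuid

    def mock_st2_ctx(api_url="http://127.0.0.1:9101/v1", user="stanley"):
        return {
            "st2": {
                "api_url": api_url,
                "action_execution_id": uuid.uuid4().hex,  # random placeholder
                "user": user,
            }
        }

    ctx = mock_st2_ctx()
    assert len(ctx["st2"]["action_execution_id"]) == 32
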
diff --git a/st2api/st2api/validation.py b/st2api/st2api/validation.py
index ae92d1d9cb..42120c57bf 100644
--- a/st2api/st2api/validation.py
+++ b/st2api/st2api/validation.py
@@ -15,9 +15,7 @@
from oslo_config import cfg
-__all__ = [
- 'validate_rbac_is_correctly_configured'
-]
+__all__ = ["validate_rbac_is_correctly_configured"]
def validate_rbac_is_correctly_configured():
@@ -28,24 +26,29 @@ def validate_rbac_is_correctly_configured():
return True
from st2common.rbac.backends import get_available_backends
+
available_rbac_backends = get_available_backends()
# 1. Verify auth is enabled
if not cfg.CONF.auth.enable:
- msg = ('Authentication is not enabled. RBAC only works when authentication is enabled. '
- 'You can either enable authentication or disable RBAC.')
+ msg = (
+ "Authentication is not enabled. RBAC only works when authentication is enabled. "
+ "You can either enable authentication or disable RBAC."
+ )
raise ValueError(msg)
# 2. Verify default backend is set
- if cfg.CONF.rbac.backend != 'default':
- msg = ('You have enabled RBAC, but RBAC backend is not set to "default". '
- 'For RBAC to work, you need to set '
- '"rbac.backend" config option to "default" and restart st2api service.')
+ if cfg.CONF.rbac.backend != "default":
+ msg = (
+ 'You have enabled RBAC, but RBAC backend is not set to "default". '
+ "For RBAC to work, you need to set "
+ '"rbac.backend" config option to "default" and restart st2api service.'
+ )
raise ValueError(msg)
# 3. Verify default RBAC backend is available
- if 'default' not in available_rbac_backends:
- msg = ('"default" RBAC backend is not available.')
+ if "default" not in available_rbac_backends:
+ msg = '"default" RBAC backend is not available.'
raise ValueError(msg)
return True
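
The validation.py hunk above condenses two black rules worth knowing when reading the rest of this patch: quotes are normalized to double quotes, except where a string literally contains a double quote (switching would force escaping, so single quotes stay), and long messages become a parenthesized group of implicitly concatenated literals. A minimal standalone sketch, not part of the patch:

    plain = "single quotes become double quotes"  # was: 'single quotes ...'
    kept = 'left single-quoted to avoid escaping the "default" inside'

    # Long messages are parenthesized and split into adjacent string
    # literals, which Python concatenates at compile time.
    msg = (
        "Authentication is not enabled. RBAC only works when authentication "
        "is enabled. You can either enable authentication or disable RBAC."
    )
    assert "RBAC" in msg
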
diff --git a/st2api/st2api/wsgi.py b/st2api/st2api/wsgi.py
index b9c92b7bf4..79baf0f110 100644
--- a/st2api/st2api/wsgi.py
+++ b/st2api/st2api/wsgi.py
@@ -20,6 +20,7 @@
import os
from st2common.util.monkey_patch import monkey_patch
+
# Note: We need to perform monkey patching in the worker. If we do it in
# the master process (gunicorn_config.py), it breaks tons of things
# including shutdown
@@ -32,8 +33,11 @@
from st2api import app
config = {
- 'is_gunicorn': True,
- 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')]
+ "is_gunicorn": True,
+ "config_args": [
+ "--config-file",
+ os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"),
+ ],
}
application = app.setup_app(config)
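
The config dict in wsgi.py also shows black's "magic trailing comma": once a collection literal carries a trailing comma, black keeps one element per line even when the whole literal would fit within the limit. Illustrative sketch, assuming black's default 88-column line length:

    # No trailing comma: black keeps the literal on one line when it fits.
    compact = {"is_gunicorn": True}

    # Trailing comma present: black explodes it, one element per line.
    exploded = {
        "is_gunicorn": True,
        "config_args": [
            "--config-file",
            "/etc/st2/st2.conf",
        ],
    }
    assert exploded["config_args"][0] == "--config-file"
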
diff --git a/st2api/tests/integration/test_gunicorn_configs.py b/st2api/tests/integration/test_gunicorn_configs.py
index 65950bfa7c..9375cf3b85 100644
--- a/st2api/tests/integration/test_gunicorn_configs.py
+++ b/st2api/tests/integration/test_gunicorn_configs.py
@@ -28,38 +28,44 @@
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf')
+ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf")
class GunicornWSGIEntryPointTestCase(IntegrationTestCase):
- @unittest2.skipIf(profiling.is_enabled(), 'Profiling is enabled')
+ @unittest2.skipIf(profiling.is_enabled(), "Profiling is enabled")
def test_st2api_wsgi_entry_point(self):
port = random.randint(10000, 30000)
- cmd = ('gunicorn st2api.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' % port)
+ cmd = (
+ 'gunicorn st2api.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1'
+ % port
+ )
env = os.environ.copy()
- env['ST2_CONFIG_PATH'] = ST2_CONFIG_PATH
+ env["ST2_CONFIG_PATH"] = ST2_CONFIG_PATH
process = subprocess.Popen(cmd, env=env, shell=True, preexec_fn=os.setsid)
try:
self.add_process(process=process)
eventlet.sleep(8)
self.assertProcessIsRunning(process=process)
- response = requests.get('http://127.0.0.1:%s/v1/actions' % (port))
+ response = requests.get("http://127.0.0.1:%s/v1/actions" % (port))
self.assertEqual(response.status_code, http_client.OK)
finally:
kill_process(process)
- @unittest2.skipIf(profiling.is_enabled(), 'Profiling is enabled')
+ @unittest2.skipIf(profiling.is_enabled(), "Profiling is enabled")
def test_st2auth(self):
port = random.randint(10000, 30000)
- cmd = ('gunicorn st2auth.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' % port)
+ cmd = (
+ 'gunicorn st2auth.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1'
+ % port
+ )
env = os.environ.copy()
- env['ST2_CONFIG_PATH'] = ST2_CONFIG_PATH
+ env["ST2_CONFIG_PATH"] = ST2_CONFIG_PATH
process = subprocess.Popen(cmd, env=env, shell=True, preexec_fn=os.setsid)
try:
self.add_process(process=process)
eventlet.sleep(8)
self.assertProcessIsRunning(process=process)
- response = requests.post('http://127.0.0.1:%s/tokens' % (port))
+ response = requests.post("http://127.0.0.1:%s/tokens" % (port))
self.assertEqual(response.status_code, http_client.UNAUTHORIZED)
finally:
kill_process(process)
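
The pattern reformatted above is spawn, poll, kill: start gunicorn in a subprocess with an ST2_CONFIG_PATH override, wait, assert on an HTTP response, then tear the process down in a finally block. A reduced sketch of the same pattern, with `sleep 30` standing in for the real gunicorn command and a hypothetical config path:

    import os
    import subprocess
    import time

    env = os.environ.copy()
    env["ST2_CONFIG_PATH"] = "/tmp/st2.tests.conf"  # hypothetical path
    process = subprocess.Popen(
        "sleep 30", env=env, shell=True, preexec_fn=os.setsid
    )
    try:
        time.sleep(1)
        assert process.poll() is None, "process exited early"
    finally:
        process.kill()
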
diff --git a/st2api/tests/unit/controllers/test_root.py b/st2api/tests/unit/controllers/test_root.py
index d4172ce155..db4ea01713 100644
--- a/st2api/tests/unit/controllers/test_root.py
+++ b/st2api/tests/unit/controllers/test_root.py
@@ -15,15 +15,13 @@
from st2tests.api import FunctionalTest
-__all__ = [
- 'RootControllerTestCase'
-]
+__all__ = ["RootControllerTestCase"]
class RootControllerTestCase(FunctionalTest):
def test_get_index(self):
- paths = ['/', '/v1/', '/v1']
+ paths = ["/", "/v1/", "/v1"]
for path in paths:
resp = self.app.get(path)
- self.assertIn('version', resp.json)
- self.assertIn('docs_url', resp.json)
+ self.assertIn("version", resp.json)
+ self.assertIn("docs_url", resp.json)
diff --git a/st2api/tests/unit/controllers/v1/test_action_alias.py b/st2api/tests/unit/controllers/v1/test_action_alias.py
index 299ce530e3..208ed082be 100644
--- a/st2api/tests/unit/controllers/v1/test_action_alias.py
+++ b/st2api/tests/unit/controllers/v1/test_action_alias.py
@@ -21,31 +21,33 @@
from st2tests.api import FunctionalTest
from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase
-FIXTURES_PACK = 'aliases'
+FIXTURES_PACK = "aliases"
TEST_MODELS = {
- 'aliases': ['alias1.yaml', 'alias2.yaml', 'alias_with_undefined_jinja_in_ack_format.yaml'],
- 'actions': ['action3.yaml', 'action4.yaml']
+ "aliases": [
+ "alias1.yaml",
+ "alias2.yaml",
+ "alias_with_undefined_jinja_in_ack_format.yaml",
+ ],
+ "actions": ["action3.yaml", "action4.yaml"],
}
TEST_LOAD_MODELS = {
- 'aliases': ['alias3.yaml'],
+ "aliases": ["alias3.yaml"],
}
-GENERIC_FIXTURES_PACK = 'generic'
+GENERIC_FIXTURES_PACK = "generic"
-TEST_LOAD_MODELS_GENERIC = {
- 'aliases': ['alias3.yaml'],
- 'runners': ['testrunner1.yaml']
-}
+TEST_LOAD_MODELS_GENERIC = {"aliases": ["alias3.yaml"], "runners": ["testrunner1.yaml"]}
-class ActionAliasControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/actionalias'
+class ActionAliasControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/actionalias"
controller_cls = ActionAliasController
- include_attribute_field_name = 'formats'
- exclude_attribute_field_name = 'result'
+ include_attribute_field_name = "formats"
+ exclude_attribute_field_name = "result"
models = None
alias1 = None
@@ -56,153 +58,186 @@ class ActionAliasControllerTestCase(FunctionalTest,
@classmethod
def setUpClass(cls):
super(ActionAliasControllerTestCase, cls).setUpClass()
- cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
- cls.alias1 = cls.models['aliases']['alias1.yaml']
- cls.alias2 = cls.models['aliases']['alias2.yaml']
-
- loaded_models = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_LOAD_MODELS)
- cls.alias3 = loaded_models['aliases']['alias3.yaml']
-
- FixturesLoader().save_fixtures_to_db(fixtures_pack=GENERIC_FIXTURES_PACK,
- fixtures_dict={'aliases': ['alias7.yaml']})
-
- loaded_models = FixturesLoader().load_models(fixtures_pack=GENERIC_FIXTURES_PACK,
- fixtures_dict=TEST_LOAD_MODELS_GENERIC)
- cls.alias3_generic = loaded_models['aliases']['alias3.yaml']
+ cls.models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+ )
+ cls.alias1 = cls.models["aliases"]["alias1.yaml"]
+ cls.alias2 = cls.models["aliases"]["alias2.yaml"]
+
+ loaded_models = FixturesLoader().load_models(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_LOAD_MODELS
+ )
+ cls.alias3 = loaded_models["aliases"]["alias3.yaml"]
+
+ FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=GENERIC_FIXTURES_PACK,
+ fixtures_dict={"aliases": ["alias7.yaml"]},
+ )
+
+ loaded_models = FixturesLoader().load_models(
+ fixtures_pack=GENERIC_FIXTURES_PACK, fixtures_dict=TEST_LOAD_MODELS_GENERIC
+ )
+ cls.alias3_generic = loaded_models["aliases"]["alias3.yaml"]
def test_get_all(self):
- resp = self.app.get('/v1/actionalias')
+ resp = self.app.get("/v1/actionalias")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 4, '/v1/actionalias did not return all aliases.')
-
- retrieved_names = [alias['name'] for alias in resp.json]
-
- self.assertEqual(retrieved_names, [self.alias1.name, self.alias2.name,
- 'alias_with_undefined_jinja_in_ack_format',
- 'alias7'],
- 'Incorrect aliases retrieved.')
+ self.assertEqual(
+ len(resp.json), 4, "/v1/actionalias did not return all aliases."
+ )
+
+ retrieved_names = [alias["name"] for alias in resp.json]
+
+ self.assertEqual(
+ retrieved_names,
+ [
+ self.alias1.name,
+ self.alias2.name,
+ "alias_with_undefined_jinja_in_ack_format",
+ "alias7",
+ ],
+ "Incorrect aliases retrieved.",
+ )
def test_get_all_query_param_filters(self):
- resp = self.app.get('/v1/actionalias?pack=doesntexist')
+ resp = self.app.get("/v1/actionalias?pack=doesntexist")
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 0)
- resp = self.app.get('/v1/actionalias?pack=aliases')
+ resp = self.app.get("/v1/actionalias?pack=aliases")
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 3)
for alias_api in resp.json:
- self.assertEqual(alias_api['pack'], 'aliases')
+ self.assertEqual(alias_api["pack"], "aliases")
- resp = self.app.get('/v1/actionalias?pack=generic')
+ resp = self.app.get("/v1/actionalias?pack=generic")
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 1)
for alias_api in resp.json:
- self.assertEqual(alias_api['pack'], 'generic')
+ self.assertEqual(alias_api["pack"], "generic")
- resp = self.app.get('/v1/actionalias?name=doesntexist')
+ resp = self.app.get("/v1/actionalias?name=doesntexist")
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 0)
- resp = self.app.get('/v1/actionalias?name=alias2')
+ resp = self.app.get("/v1/actionalias?name=alias2")
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 1)
- self.assertEqual(resp.json[0]['name'], 'alias2')
+ self.assertEqual(resp.json[0]["name"], "alias2")
def test_get_one(self):
- resp = self.app.get('/v1/actionalias/%s' % self.alias1.id)
+ resp = self.app.get("/v1/actionalias/%s" % self.alias1.id)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['name'], self.alias1.name,
- 'Incorrect aliases retrieved.')
+ self.assertEqual(
+ resp.json["name"], self.alias1.name, "Incorrect aliases retrieved."
+ )
def test_post_delete(self):
post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3)))
self.assertEqual(post_resp.status_int, 201)
- get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id'])
+ get_resp = self.app.get("/v1/actionalias/%s" % post_resp.json["id"])
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['name'], self.alias3.name,
- 'Incorrect aliases retrieved.')
+ self.assertEqual(
+ get_resp.json["name"], self.alias3.name, "Incorrect aliases retrieved."
+ )
- del_resp = self.__do_delete(post_resp.json['id'])
+ del_resp = self.__do_delete(post_resp.json["id"])
self.assertEqual(del_resp.status_int, 204)
- get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id'], expect_errors=True)
+ get_resp = self.app.get(
+ "/v1/actionalias/%s" % post_resp.json["id"], expect_errors=True
+ )
self.assertEqual(get_resp.status_int, 404)
def test_update_existing_alias(self):
post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3)))
self.assertEqual(post_resp.status_int, 201)
- self.assertEqual(post_resp.json['name'], self.alias3['name'])
+ self.assertEqual(post_resp.json["name"], self.alias3["name"])
data = vars(ActionAliasAPI.from_model(self.alias3))
- data['name'] = 'updated-alias-name'
+ data["name"] = "updated-alias-name"
- put_resp = self.app.put_json('/v1/actionalias/%s' % post_resp.json['id'], data)
- self.assertEqual(put_resp.json['name'], data['name'])
+ put_resp = self.app.put_json("/v1/actionalias/%s" % post_resp.json["id"], data)
+ self.assertEqual(put_resp.json["name"], data["name"])
- get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id'])
- self.assertEqual(get_resp.json['name'], data['name'])
+ get_resp = self.app.get("/v1/actionalias/%s" % post_resp.json["id"])
+ self.assertEqual(get_resp.json["name"], data["name"])
- del_resp = self.__do_delete(post_resp.json['id'])
+ del_resp = self.__do_delete(post_resp.json["id"])
self.assertEqual(del_resp.status_int, 204)
def test_post_dup_name(self):
post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3)))
self.assertEqual(post_resp.status_int, 201)
- post_resp_dup_name = self._do_post(vars(ActionAliasAPI.from_model(self.alias3_generic)))
+ post_resp_dup_name = self._do_post(
+ vars(ActionAliasAPI.from_model(self.alias3_generic))
+ )
self.assertEqual(post_resp_dup_name.status_int, 201)
- self.__do_delete(post_resp.json['id'])
- self.__do_delete(post_resp_dup_name.json['id'])
+ self.__do_delete(post_resp.json["id"])
+ self.__do_delete(post_resp_dup_name.json["id"])
def test_match(self):
# No matching patterns
- data = {'command': 'hello donny'}
+ data = {"command": "hello donny"}
resp = self.app.post_json("/v1/actionalias/match", data, expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(str(resp.json['faultstring']),
- "Command 'hello donny' matched no patterns")
+ self.assertEqual(
+ str(resp.json["faultstring"]), "Command 'hello donny' matched no patterns"
+ )
# More than one matching pattern
- data = {'command': 'Lorem ipsum banana dolor sit pineapple amet.'}
+ data = {"command": "Lorem ipsum banana dolor sit pineapple amet."}
resp = self.app.post_json("/v1/actionalias/match", data, expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(str(resp.json['faultstring']),
- "Command 'Lorem ipsum banana dolor sit pineapple amet.' "
- "matched more than 1 pattern")
+ self.assertEqual(
+ str(resp.json["faultstring"]),
+ "Command 'Lorem ipsum banana dolor sit pineapple amet.' "
+ "matched more than 1 pattern",
+ )
# Single matching pattern - success
- data = {'command': 'run whoami on localhost1'}
+ data = {"command": "run whoami on localhost1"}
resp = self.app.post_json("/v1/actionalias/match", data)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['actionalias']['name'],
- 'alias_with_undefined_jinja_in_ack_format')
+ self.assertEqual(
+ resp.json["actionalias"]["name"], "alias_with_undefined_jinja_in_ack_format"
+ )
def test_help(self):
resp = self.app.get("/v1/actionalias/help")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json.get('available'), 5)
+ self.assertEqual(resp.json.get("available"), 5)
def test_help_args(self):
- resp = self.app.get("/v1/actionalias/help?filter=.*&pack=aliases&limit=1&offset=0")
+ resp = self.app.get(
+ "/v1/actionalias/help?filter=.*&pack=aliases&limit=1&offset=0"
+ )
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json.get('available'), 3)
- self.assertEqual(len(resp.json.get('helpstrings')), 1)
+ self.assertEqual(resp.json.get("available"), 3)
+ self.assertEqual(len(resp.json.get("helpstrings")), 1)
def _insert_mock_models(self):
- alias_ids = [self.alias1['id'], self.alias2['id'], self.alias3['id'],
- self.alias3_generic['id']]
+ alias_ids = [
+ self.alias1["id"],
+ self.alias2["id"],
+ self.alias3["id"],
+ self.alias3_generic["id"],
+ ]
return alias_ids
def _delete_mock_models(self, object_ids):
return None
def _do_post(self, actionalias, expect_errors=False):
- return self.app.post_json('/v1/actionalias', actionalias, expect_errors=expect_errors)
+ return self.app.post_json(
+ "/v1/actionalias", actionalias, expect_errors=expect_errors
+ )
def __do_delete(self, actionalias_id, expect_errors=False):
- return self.app.delete('/v1/actionalias/%s' % actionalias_id, expect_errors=expect_errors)
+ return self.app.delete(
+ "/v1/actionalias/%s" % actionalias_id, expect_errors=expect_errors
+ )
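
test_action_alias.py is dominated by black's call-splitting rule: a call stays on one line if it fits in 88 columns; otherwise every argument moves to its own indented line with a trailing comma, and the closing parenthesis lines up with the call. A sketch of that split layout, with a hypothetical stand-in for FixturesLoader().save_fixtures_to_db():

    def save_fixtures_to_db(fixtures_pack=None, fixtures_dict=None):
        # Hypothetical helper: echoes back the requested fixture names.
        return {"aliases": {name: None for name in fixtures_dict["aliases"]}}

    # The layout black produces once the one-line form is too long:
    models = save_fixtures_to_db(
        fixtures_pack="aliases",
        fixtures_dict={"aliases": ["alias1.yaml", "alias2.yaml"]},
    )
    assert "alias1.yaml" in models["aliases"]
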
diff --git a/st2api/tests/unit/controllers/v1/test_action_views.py b/st2api/tests/unit/controllers/v1/test_action_views.py
index dbb9346662..a28219c04d 100644
--- a/st2api/tests/unit/controllers/v1/test_action_views.py
+++ b/st2api/tests/unit/controllers/v1/test_action_views.py
@@ -25,42 +25,44 @@
# ACTION_1: Good action definition.
ACTION_1 = {
- 'name': 'st2.dummy.action1',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'wolfpack',
- 'entry_point': 'test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action1",
+ "description": "test description",
+ "enabled": True,
+ "pack": "wolfpack",
+ "entry_point": "test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
# ACTION_2: Good action definition.
ACTION_2 = {
- 'name': 'st2.dummy.action2',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'wolfpack',
- 'entry_point': 'test/action2.py',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'c': {'type': 'string', 'default': 'C1', 'position': 0},
- 'd': {'type': 'string', 'default': 'D1', 'immutable': True}
- }
+ "name": "st2.dummy.action2",
+ "description": "test description",
+ "enabled": True,
+ "pack": "wolfpack",
+ "entry_point": "test/action2.py",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "c": {"type": "string", "default": "C1", "position": 0},
+ "d": {"type": "string", "default": "D1", "immutable": True},
+ },
}
-class ActionViewsOverviewControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/actions/views/overview'
+class ActionViewsOverviewControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/actions/views/overview"
controller_cls = OverviewController
- include_attribute_field_name = 'entry_point'
- exclude_attribute_field_name = 'parameters'
+ include_attribute_field_name = "entry_point"
+ exclude_attribute_field_name = "parameters"
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one(self):
post_resp = self._do_post(ACTION_1)
action_id = self._get_action_id(post_resp)
@@ -71,8 +73,9 @@ def test_get_one(self):
finally:
self._do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one_ref(self):
post_resp = self._do_post(ACTION_1)
action_id = self._get_action_id(post_resp)
@@ -80,66 +83,85 @@ def test_get_one_ref(self):
try:
get_resp = self._do_get_one(action_ref)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['ref'], action_ref)
+ self.assertEqual(get_resp.json["ref"], action_ref)
finally:
self._do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_all_and_limit_minus_one(self):
action_1_id = self._get_action_id(self._do_post(ACTION_1))
action_2_id = self._get_action_id(self._do_post(ACTION_2))
try:
- resp = self.app.get('/v1/actions/views/overview')
+ resp = self.app.get("/v1/actions/views/overview")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 2,
- '/v1/actions/views/overview did not return all actions.')
- resp = self.app.get('/v1/actions/views/overview/?limit=-1')
+ self.assertEqual(
+ len(resp.json),
+ 2,
+ "/v1/actions/views/overview did not return all actions.",
+ )
+ resp = self.app.get("/v1/actions/views/overview/?limit=-1")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 2,
- '/v1/actions/views/overview did not return all actions.')
+ self.assertEqual(
+ len(resp.json),
+ 2,
+ "/v1/actions/views/overview did not return all actions.",
+ )
finally:
self._do_delete(action_1_id)
self._do_delete(action_2_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_all_negative_limit(self):
action_1_id = self._get_action_id(self._do_post(ACTION_1))
action_2_id = self._get_action_id(self._do_post(ACTION_2))
try:
- resp = self.app.get('/v1/actions/views/overview/?limit=-22', expect_errors=True)
+ resp = self.app.get(
+ "/v1/actions/views/overview/?limit=-22", expect_errors=True
+ )
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
finally:
self._do_delete(action_1_id)
self._do_delete(action_2_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_all_filter_by_name(self):
action_1_id = self._get_action_id(self._do_post(ACTION_1))
action_2_id = self._get_action_id(self._do_post(ACTION_2))
try:
- resp = self.app.get('/v1/actions/views/overview?name=%s' % str('st2.dummy.action2'))
+ resp = self.app.get(
+ "/v1/actions/views/overview?name=%s" % str("st2.dummy.action2")
+ )
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json[0]['id'], action_2_id, 'Filtering failed')
+ self.assertEqual(resp.json[0]["id"], action_2_id, "Filtering failed")
finally:
self._do_delete(action_1_id)
self._do_delete(action_2_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_all_include_attributes_filter(self):
- return super(ActionViewsOverviewControllerTestCase, self) \
- .test_get_all_include_attributes_filter()
+ return super(
+ ActionViewsOverviewControllerTestCase, self
+ ).test_get_all_include_attributes_filter()
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_all_exclude_attributes_filter(self):
- return super(ActionViewsOverviewControllerTestCase, self) \
- .test_get_all_include_attributes_filter()
+ return super(
+ ActionViewsOverviewControllerTestCase, self
+        ).test_get_all_exclude_attributes_filter()
def _insert_mock_models(self):
action_1_id = self._get_action_id(self._do_post(ACTION_1))
@@ -149,115 +171,141 @@ def _insert_mock_models(self):
@staticmethod
def _get_action_id(resp):
- return resp.json['id']
+ return resp.json["id"]
@staticmethod
def _get_action_ref(resp):
- return '.'.join((resp.json['pack'], resp.json['name']))
+ return ".".join((resp.json["pack"], resp.json["name"]))
@staticmethod
def _get_action_name(resp):
- return resp.json['name']
+ return resp.json["name"]
def _do_get_one(self, action_id, expect_errors=False):
- return self.app.get('/v1/actions/views/overview/%s' % action_id,
- expect_errors=expect_errors)
+ return self.app.get(
+ "/v1/actions/views/overview/%s" % action_id, expect_errors=expect_errors
+ )
def _do_post(self, action, expect_errors=False):
- return self.app.post_json('/v1/actions', action, expect_errors=expect_errors)
+ return self.app.post_json("/v1/actions", action, expect_errors=expect_errors)
def _do_delete(self, action_id, expect_errors=False):
- return self.app.delete('/v1/actions/%s' % action_id, expect_errors=expect_errors)
+ return self.app.delete(
+ "/v1/actions/%s" % action_id, expect_errors=expect_errors
+ )
class ActionViewsParametersControllerTestCase(FunctionalTest):
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one(self):
- post_resp = self.app.post_json('/v1/actions', ACTION_1)
- action_id = post_resp.json['id']
+ post_resp = self.app.post_json("/v1/actions", ACTION_1)
+ action_id = post_resp.json["id"]
try:
- get_resp = self.app.get('/v1/actions/views/parameters/%s' % action_id)
+ get_resp = self.app.get("/v1/actions/views/parameters/%s" % action_id)
self.assertEqual(get_resp.status_int, 200)
finally:
- self.app.delete('/v1/actions/%s' % action_id)
+ self.app.delete("/v1/actions/%s" % action_id)
class ActionEntryPointViewControllerTestCase(FunctionalTest):
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
- @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock(
- return_value='/path/to/file'))
- @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True)
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ content_utils,
+ "get_entry_point_abs_path",
+ mock.MagicMock(return_value="/path/to/file"),
+ )
+ @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True)
def test_get_one(self):
- post_resp = self.app.post_json('/v1/actions', ACTION_1)
- action_id = post_resp.json['id']
+ post_resp = self.app.post_json("/v1/actions", ACTION_1)
+ action_id = post_resp.json["id"]
try:
- get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_id)
+ get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_id)
self.assertEqual(get_resp.status_int, 200)
finally:
- self.app.delete('/v1/actions/%s' % action_id)
+ self.app.delete("/v1/actions/%s" % action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
- @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock(
- return_value='/path/to/file'))
- @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True)
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ content_utils,
+ "get_entry_point_abs_path",
+ mock.MagicMock(return_value="/path/to/file"),
+ )
+ @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True)
def test_get_one_ref(self):
- post_resp = self.app.post_json('/v1/actions', ACTION_1)
- action_id = post_resp.json['id']
- action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name']))
+ post_resp = self.app.post_json("/v1/actions", ACTION_1)
+ action_id = post_resp.json["id"]
+ action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"]))
try:
- get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref)
+ get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref)
self.assertEqual(get_resp.status_int, 200)
finally:
- self.app.delete('/v1/actions/%s' % action_id)
+ self.app.delete("/v1/actions/%s" % action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
- @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock(
- return_value='/path/to/file.yaml'))
- @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True)
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ content_utils,
+ "get_entry_point_abs_path",
+ mock.MagicMock(return_value="/path/to/file.yaml"),
+ )
+ @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True)
def test_get_one_ref_yaml_content_type(self):
- post_resp = self.app.post_json('/v1/actions', ACTION_1)
- action_id = post_resp.json['id']
- action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name']))
+ post_resp = self.app.post_json("/v1/actions", ACTION_1)
+ action_id = post_resp.json["id"]
+ action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"]))
try:
- get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref)
+ get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.headers['Content-Type'], 'application/x-yaml')
+ self.assertEqual(get_resp.headers["Content-Type"], "application/x-yaml")
finally:
- self.app.delete('/v1/actions/%s' % action_id)
+ self.app.delete("/v1/actions/%s" % action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
- @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock(
- return_value=__file__.replace('.pyc', '.py')))
- @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True)
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ content_utils,
+ "get_entry_point_abs_path",
+ mock.MagicMock(return_value=__file__.replace(".pyc", ".py")),
+ )
+ @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True)
def test_get_one_ref_python_content_type(self):
- post_resp = self.app.post_json('/v1/actions', ACTION_1)
- action_id = post_resp.json['id']
- action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name']))
+ post_resp = self.app.post_json("/v1/actions", ACTION_1)
+ action_id = post_resp.json["id"]
+ action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"]))
try:
- get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref)
+ get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref)
self.assertEqual(get_resp.status_int, 200)
- self.assertIn(get_resp.headers['Content-Type'], ['application/x-python',
- 'text/x-python'])
+ self.assertIn(
+ get_resp.headers["Content-Type"],
+ ["application/x-python", "text/x-python"],
+ )
finally:
- self.app.delete('/v1/actions/%s' % action_id)
+ self.app.delete("/v1/actions/%s" % action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
- @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock(
- return_value='/file/does/not/exist'))
- @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True)
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ content_utils,
+ "get_entry_point_abs_path",
+ mock.MagicMock(return_value="/file/does/not/exist"),
+ )
+ @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True)
def test_get_one_ref_text_plain_content_type(self):
- post_resp = self.app.post_json('/v1/actions', ACTION_1)
- action_id = post_resp.json['id']
- action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name']))
+ post_resp = self.app.post_json("/v1/actions", ACTION_1)
+ action_id = post_resp.json["id"]
+ action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"]))
try:
- get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref)
+ get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.headers['Content-Type'], 'text/plain')
+ self.assertEqual(get_resp.headers["Content-Type"], "text/plain")
finally:
- self.app.delete('/v1/actions/%s' % action_id)
+ self.app.delete("/v1/actions/%s" % action_id)
diff --git a/st2api/tests/unit/controllers/v1/test_actions.py b/st2api/tests/unit/controllers/v1/test_actions.py
index c189803d2c..40e973c0a1 100644
--- a/st2api/tests/unit/controllers/v1/test_actions.py
+++ b/st2api/tests/unit/controllers/v1/test_actions.py
@@ -41,257 +41,259 @@
# ACTION_1: Good action definition.
ACTION_1 = {
- 'name': 'st2.dummy.action1',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'wolfpack',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
+ "name": "st2.dummy.action1",
+ "description": "test description",
+ "enabled": True,
+ "pack": "wolfpack",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
},
- 'tags': [
- {'name': 'tag1', 'value': 'dont-care'},
- {'name': 'tag2', 'value': 'dont-care'}
- ]
+ "tags": [
+ {"name": "tag1", "value": "dont-care"},
+ {"name": "tag2", "value": "dont-care"},
+ ],
}
# ACTION_2: Good action definition. No content pack.
ACTION_2 = {
- 'name': 'st2.dummy.action2',
- 'description': 'test description',
- 'enabled': True,
- 'entry_point': '/tmp/test/action2.py',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'c': {'type': 'string', 'default': 'C1', 'position': 0},
- 'd': {'type': 'string', 'default': 'D1', 'immutable': True}
- }
+ "name": "st2.dummy.action2",
+ "description": "test description",
+ "enabled": True,
+ "entry_point": "/tmp/test/action2.py",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "c": {"type": "string", "default": "C1", "position": 0},
+ "d": {"type": "string", "default": "D1", "immutable": True},
+ },
}
# ACTION_3: No enabled field
ACTION_3 = {
- 'name': 'st2.dummy.action3',
- 'description': 'test description',
- 'pack': 'wolfpack',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action3",
+ "description": "test description",
+ "pack": "wolfpack",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
# ACTION_4: Enabled field is False
ACTION_4 = {
- 'name': 'st2.dummy.action4',
- 'description': 'test description',
- 'enabled': False,
- 'pack': 'wolfpack',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action4",
+ "description": "test description",
+ "enabled": False,
+ "pack": "wolfpack",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
# ACTION_5: Invalid runner_type
ACTION_5 = {
- 'name': 'st2.dummy.action5',
- 'description': 'test description',
- 'enabled': False,
- 'pack': 'wolfpack',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'xyzxyz',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action5",
+ "description": "test description",
+ "enabled": False,
+ "pack": "wolfpack",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "xyzxyz",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
# ACTION_6: No description field.
ACTION_6 = {
- 'name': 'st2.dummy.action6',
- 'enabled': False,
- 'pack': 'wolfpack',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action6",
+ "enabled": False,
+ "pack": "wolfpack",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
# ACTION_7: id field provided
ACTION_7 = {
- 'id': 'foobar',
- 'name': 'st2.dummy.action7',
- 'description': 'test description',
- 'enabled': False,
- 'pack': 'wolfpack',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "id": "foobar",
+ "name": "st2.dummy.action7",
+ "description": "test description",
+ "enabled": False,
+ "pack": "wolfpack",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
# ACTION_8: 'cmd' parameter provided.
ACTION_8 = {
- 'name': 'st2.dummy.action8',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'wolfpack',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'cmd': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action8",
+ "description": "test description",
+ "enabled": True,
+ "pack": "wolfpack",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "cmd": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
# ACTION_9: Parameter dict has fields not part of JSONSchema spec.
ACTION_9 = {
- 'name': 'st2.dummy.action9',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'wolfpack',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1', 'dummyfield': True}, # dummyfield is invalid.
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action9",
+ "description": "test description",
+ "enabled": True,
+ "pack": "wolfpack",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {
+ "type": "string",
+ "default": "A1",
+ "dummyfield": True,
+ }, # dummyfield is invalid.
+ "b": {"type": "string", "default": "B1"},
+ },
}
# Same name as ACTION_1, but in a different pack.
# Ensure that this remains the only action with pack == wolfpack1;
# otherwise, update the test_get_one_using_pack_parameter test accordingly.
ACTION_10 = {
- 'name': 'st2.dummy.action1',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'wolfpack1',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action1",
+ "description": "test description",
+ "enabled": True,
+ "pack": "wolfpack1",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
# Good action with a system pack
ACTION_11 = {
- 'name': 'st2.dummy.action11',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'test description',
- 'enabled': True,
- 'entry_point': '/tmp/test/action2.py',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'c': {'type': 'string', 'default': 'C1', 'position': 0},
- 'd': {'type': 'string', 'default': 'D1', 'immutable': True}
- }
+ "name": "st2.dummy.action11",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "test description",
+ "enabled": True,
+ "entry_point": "/tmp/test/action2.py",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "c": {"type": "string", "default": "C1", "position": 0},
+ "d": {"type": "string", "default": "D1", "immutable": True},
+ },
}
# Good action inside dummy pack
ACTION_12 = {
- 'name': 'st2.dummy.action1',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'dummy_pack_1',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
+ "name": "st2.dummy.action1",
+ "description": "test description",
+ "enabled": True,
+ "pack": "dummy_pack_1",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
},
- 'tags': [
- {'name': 'tag1', 'value': 'dont-care'},
- {'name': 'tag2', 'value': 'dont-care'}
- ]
+ "tags": [
+ {"name": "tag1", "value": "dont-care"},
+ {"name": "tag2", "value": "dont-care"},
+ ],
}
# Action with invalid parameter type attribute
ACTION_13 = {
- 'name': 'st2.dummy.action2',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'dummy_pack_1',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': ['string', 'object'], 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'}
- }
+ "name": "st2.dummy.action2",
+ "description": "test description",
+ "enabled": True,
+ "pack": "dummy_pack_1",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": ["string", "object"], "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ },
}
ACTION_14 = {
- 'name': 'st2.dummy.action14',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'dummy_pack_1',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'},
- 'sudo': {'type': 'string'}
- }
+ "name": "st2.dummy.action14",
+ "description": "test description",
+ "enabled": True,
+ "pack": "dummy_pack_1",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ "sudo": {"type": "string"},
+ },
}
ACTION_15 = {
- 'name': 'st2.dummy.action15',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'dummy_pack_1',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'},
- 'sudo': {'default': True, 'immutable': True}
- }
+ "name": "st2.dummy.action15",
+ "description": "test description",
+ "enabled": True,
+ "pack": "dummy_pack_1",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ "sudo": {"default": True, "immutable": True},
+ },
}
ACTION_WITH_NOTIFY = {
- 'name': 'st2.dummy.action_notify_test',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'dummy_pack_1',
- 'entry_point': '/tmp/test/action1.sh',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'a': {'type': 'string', 'default': 'A1'},
- 'b': {'type': 'string', 'default': 'B1'},
- 'sudo': {'default': True, 'immutable': True}
+ "name": "st2.dummy.action_notify_test",
+ "description": "test description",
+ "enabled": True,
+ "pack": "dummy_pack_1",
+ "entry_point": "/tmp/test/action1.sh",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "a": {"type": "string", "default": "A1"},
+ "b": {"type": "string", "default": "B1"},
+ "sudo": {"default": True, "immutable": True},
},
- 'notify': {
- 'on-complete': {
- 'message': 'Woohoo! I completed!!!'
- }
- }
+ "notify": {"on-complete": {"message": "Woohoo! I completed!!!"}},
}
-class ActionsControllerTestCase(FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase,
- CleanFilesTestCase):
- get_all_path = '/v1/actions'
+class ActionsControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase, CleanFilesTestCase
+):
+ get_all_path = "/v1/actions"
controller_cls = ActionsController
- include_attribute_field_name = 'entry_point'
- exclude_attribute_field_name = 'parameters'
+ include_attribute_field_name = "entry_point"
+ exclude_attribute_field_name = "parameters"
register_packs = True
to_delete_files = [
- os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1/actions/filea.txt')
+ os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1/actions/filea.txt")
]
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one_using_id(self):
post_resp = self.__do_post(ACTION_1)
action_id = self.__get_action_id(post_resp)
@@ -300,146 +302,169 @@ def test_get_one_using_id(self):
self.assertEqual(self.__get_action_id(get_resp), action_id)
self.__do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one_using_ref(self):
- ref = '.'.join([ACTION_1['pack'], ACTION_1['name']])
+ ref = ".".join([ACTION_1["pack"], ACTION_1["name"]])
action_id = self.__get_action_id(self.__do_post(ACTION_1))
get_resp = self.__do_get_one(ref)
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self.__get_action_id(get_resp), action_id)
- self.assertEqual(get_resp.json['ref'], ref)
+ self.assertEqual(get_resp.json["ref"], ref)
self.__do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one_validate_params(self):
post_resp = self.__do_post(ACTION_1)
action_id = self.__get_action_id(post_resp)
get_resp = self.__do_get_one(action_id)
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self.__get_action_id(get_resp), action_id)
- expected_args = ACTION_1['parameters']
- self.assertEqual(get_resp.json['parameters'], expected_args)
+ expected_args = ACTION_1["parameters"]
+ self.assertEqual(get_resp.json["parameters"], expected_args)
self.__do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_all_and_with_minus_one(self):
- action_1_ref = '.'.join([ACTION_1['pack'], ACTION_1['name']])
+ action_1_ref = ".".join([ACTION_1["pack"], ACTION_1["name"]])
action_1_id = self.__get_action_id(self.__do_post(ACTION_1))
action_2_id = self.__get_action_id(self.__do_post(ACTION_2))
- resp = self.app.get('/v1/actions')
+ resp = self.app.get("/v1/actions")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 2, '/v1/actions did not return all actions.')
+ self.assertEqual(len(resp.json), 2, "/v1/actions did not return all actions.")
- item = [i for i in resp.json if i['id'] == action_1_id][0]
- self.assertEqual(item['ref'], action_1_ref)
+ item = [i for i in resp.json if i["id"] == action_1_id][0]
+ self.assertEqual(item["ref"], action_1_ref)
- resp = self.app.get('/v1/actions?limit=-1')
+ resp = self.app.get("/v1/actions?limit=-1")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 2, '/v1/actions did not return all actions.')
+ self.assertEqual(len(resp.json), 2, "/v1/actions did not return all actions.")
- item = [i for i in resp.json if i['id'] == action_1_id][0]
- self.assertEqual(item['ref'], action_1_ref)
+ item = [i for i in resp.json if i["id"] == action_1_id][0]
+ self.assertEqual(item["ref"], action_1_ref)
self.__do_delete(action_1_id)
self.__do_delete(action_2_id)
- @mock.patch('st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin',
- mock.Mock(return_value=False))
+ @mock.patch(
+ "st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin",
+ mock.Mock(return_value=False),
+ )
def test_get_all_invalid_limit_too_large_none_admin(self):
# limit > max_page_size, but user is not admin
- resp = self.app.get('/v1/actions?limit=1000', expect_errors=True)
+ resp = self.app.get("/v1/actions?limit=1000", expect_errors=True)
self.assertEqual(resp.status_int, http_client.FORBIDDEN)
- self.assertEqual(resp.json['faultstring'], 'Limit "1000" specified, maximum value is'
- ' "100"')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit "1000" specified, maximum value is' ' "100"',
+ )
def test_get_all_limit_negative_number(self):
- resp = self.app.get('/v1/actions?limit=-22', expect_errors=True)
+ resp = self.app.get("/v1/actions?limit=-22", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
-
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
+
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_all_include_attributes_filter(self):
- return super(ActionsControllerTestCase, self).test_get_all_include_attributes_filter()
+ return super(
+ ActionsControllerTestCase, self
+ ).test_get_all_include_attributes_filter()
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_all_exclude_attributes_filter(self):
- return super(ActionsControllerTestCase, self).test_get_all_include_attributes_filter()
+ return super(
+ ActionsControllerTestCase, self
+        ).test_get_all_exclude_attributes_filter()
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_query(self):
action_1_id = self.__get_action_id(self.__do_post(ACTION_1))
action_2_id = self.__get_action_id(self.__do_post(ACTION_2))
- resp = self.app.get('/v1/actions?name=%s' % ACTION_1['name'])
+ resp = self.app.get("/v1/actions?name=%s" % ACTION_1["name"])
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 1, '/v1/actions did not return all actions.')
+ self.assertEqual(len(resp.json), 1, "/v1/actions did not return all actions.")
self.__do_delete(action_1_id)
self.__do_delete(action_2_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one_fail(self):
- resp = self.app.get('/v1/actions/1', expect_errors=True)
+ resp = self.app.get("/v1/actions/1", expect_errors=True)
self.assertEqual(resp.status_int, 404)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_delete(self):
post_resp = self.__do_post(ACTION_1)
self.assertEqual(post_resp.status_int, 201)
self.__do_delete(self.__get_action_id(post_resp))
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_action_with_bad_params(self):
post_resp = self.__do_post(ACTION_9, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_no_description_field(self):
post_resp = self.__do_post(ACTION_6)
self.assertEqual(post_resp.status_int, 201)
self.__do_delete(self.__get_action_id(post_resp))
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_no_enable_field(self):
post_resp = self.__do_post(ACTION_3)
self.assertEqual(post_resp.status_int, 201)
- self.assertIn(b'enabled', post_resp.body)
+ self.assertIn(b"enabled", post_resp.body)
# If enabled field is not provided it should default to True
data = json.loads(post_resp.body)
- self.assertDictContainsSubset({'enabled': True}, data)
+ self.assertDictContainsSubset({"enabled": True}, data)
self.__do_delete(self.__get_action_id(post_resp))
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_false_enable_field(self):
post_resp = self.__do_post(ACTION_4)
self.assertEqual(post_resp.status_int, 201)
data = json.loads(post_resp.body)
- self.assertDictContainsSubset({'enabled': False}, data)
+ self.assertDictContainsSubset({"enabled": False}, data)
self.__do_delete(self.__get_action_id(post_resp))
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_name_unicode_action_already_exists(self):
# Verify that exception messages containing unicode characters don't result in internal
# server errors
action = copy.deepcopy(ACTION_1)
# NOTE: We explicitly don't prefix this string value with u""
- action['name'] = 'žactionćšžži💩'
+ action["name"] = "žactionćšžži💩"
# 1. Initial creation
post_resp = self.__do_post(action, expect_errors=True)
@@ -448,54 +473,64 @@ def test_post_name_unicode_action_already_exists(self):
# 2. Action already exists
post_resp = self.__do_post(action, expect_errors=True)
self.assertEqual(post_resp.status_int, 409)
- self.assertIn('Tried to save duplicate unique keys', post_resp.json['faultstring'])
+ self.assertIn(
+ "Tried to save duplicate unique keys", post_resp.json["faultstring"]
+ )
# 3. Action already exists (this time with unicode type)
- action['name'] = u'žactionćšžži💩'
+ action["name"] = "žactionćšžži💩"
post_resp = self.__do_post(action, expect_errors=True)
self.assertEqual(post_resp.status_int, 409)
- self.assertIn('Tried to save duplicate unique keys', post_resp.json['faultstring'])
+ self.assertIn(
+ "Tried to save duplicate unique keys", post_resp.json["faultstring"]
+ )
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_parameter_type_is_array_and_invalid(self):
post_resp = self.__do_post(ACTION_13, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
if six.PY3:
- expected_error = b'[\'string\', \'object\'] is not valid under any of the given schemas'
+ expected_error = (
+ b"['string', 'object'] is not valid under any of the given schemas"
+ )
else:
- expected_error = \
- b'[u\'string\', u\'object\'] is not valid under any of the given schemas'
+ expected_error = (
+ b"[u'string', u'object'] is not valid under any of the given schemas"
+ )
self.assertIn(expected_error, post_resp.body)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_discard_id_field(self):
post_resp = self.__do_post(ACTION_7)
self.assertEqual(post_resp.status_int, 201)
- self.assertIn(b'id', post_resp.body)
+ self.assertIn(b"id", post_resp.body)
data = json.loads(post_resp.body)
# Verify that user-provided id is discarded.
- self.assertNotEquals(data['id'], ACTION_7['id'])
+ self.assertNotEquals(data["id"], ACTION_7["id"])
self.__do_delete(self.__get_action_id(post_resp))
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_duplicate(self):
action_ids = []
post_resp = self.__do_post(ACTION_1)
self.assertEqual(post_resp.status_int, 201)
- action_in_db = Action.get_by_name(ACTION_1.get('name'))
- self.assertIsNotNone(action_in_db, 'Action must be in db.')
+ action_in_db = Action.get_by_name(ACTION_1.get("name"))
+ self.assertIsNotNone(action_in_db, "Action must be in db.")
action_ids.append(self.__get_action_id(post_resp))
post_resp = self.__do_post(ACTION_1, expect_errors=True)
# Verify name conflict
self.assertEqual(post_resp.status_int, 409)
- self.assertEqual(post_resp.json['conflict-id'], action_ids[0])
+ self.assertEqual(post_resp.json["conflict-id"], action_ids[0])
post_resp = self.__do_post(ACTION_10)
action_ids.append(self.__get_action_id(post_resp))
@@ -505,20 +540,16 @@ def test_post_duplicate(self):
for i in action_ids:
self.__do_delete(i)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_include_files(self):
# Verify initial state
- pack_db = Pack.get_by_ref(ACTION_12['pack'])
- self.assertNotIn('actions/filea.txt', pack_db.files)
+ pack_db = Pack.get_by_ref(ACTION_12["pack"])
+ self.assertNotIn("actions/filea.txt", pack_db.files)
action = copy.deepcopy(ACTION_12)
- action['data_files'] = [
- {
- 'file_path': 'filea.txt',
- 'content': 'test content'
- }
- ]
+ action["data_files"] = [{"file_path": "filea.txt", "content": "test content"}]
post_resp = self.__do_post(action)
# Verify file has been written on disk
@@ -526,29 +557,30 @@ def test_post_include_files(self):
self.assertTrue(os.path.exists(file_path))
# Verify PackDB.files has been updated
- pack_db = Pack.get_by_ref(ACTION_12['pack'])
- self.assertIn('actions/filea.txt', pack_db.files)
+ pack_db = Pack.get_by_ref(ACTION_12["pack"])
+ self.assertIn("actions/filea.txt", pack_db.files)
self.__do_delete(self.__get_action_id(post_resp))
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_post_put_delete(self):
action = copy.copy(ACTION_1)
post_resp = self.__do_post(action)
self.assertEqual(post_resp.status_int, 201)
- self.assertIn(b'id', post_resp.body)
+ self.assertIn(b"id", post_resp.body)
body = json.loads(post_resp.body)
- action['id'] = body['id']
- action['description'] = 'some other test description'
- pack = action['pack']
- del action['pack']
- self.assertNotIn('pack', action)
- put_resp = self.__do_put(action['id'], action)
+ action["id"] = body["id"]
+ action["description"] = "some other test description"
+ pack = action["pack"]
+ del action["pack"]
+ self.assertNotIn("pack", action)
+ put_resp = self.__do_put(action["id"], action)
self.assertEqual(put_resp.status_int, 200)
- self.assertIn(b'description', put_resp.body)
+ self.assertIn(b"description", put_resp.body)
body = json.loads(put_resp.body)
- self.assertEqual(body['description'], action['description'])
- self.assertEqual(body['pack'], pack)
+ self.assertEqual(body["description"], action["description"])
+ self.assertEqual(body["pack"], pack)
delete_resp = self.__do_delete(self.__get_action_id(post_resp))
self.assertEqual(delete_resp.status_int, 204)
@@ -559,94 +591,107 @@ def test_post_invalid_runner_type(self):
def test_post_override_runner_param_not_allowed(self):
post_resp = self.__do_post(ACTION_14, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
- expected = ('The attribute "type" for the runner parameter "sudo" '
- 'in action "dummy_pack_1.st2.dummy.action14" cannot be overridden.')
- self.assertEqual(post_resp.json.get('faultstring'), expected)
+ expected = (
+ 'The attribute "type" for the runner parameter "sudo" '
+ 'in action "dummy_pack_1.st2.dummy.action14" cannot be overridden.'
+ )
+ self.assertEqual(post_resp.json.get("faultstring"), expected)
def test_post_override_runner_param_allowed(self):
post_resp = self.__do_post(ACTION_15)
self.assertEqual(post_resp.status_int, 201)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_delete(self):
post_resp = self.__do_post(ACTION_1)
del_resp = self.__do_delete(self.__get_action_id(post_resp))
self.assertEqual(del_resp.status_int, 204)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_action_with_tags(self):
post_resp = self.__do_post(ACTION_1)
action_id = self.__get_action_id(post_resp)
get_resp = self.__do_get_one(action_id)
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self.__get_action_id(get_resp), action_id)
- self.assertEqual(get_resp.json['tags'], ACTION_1['tags'])
+ self.assertEqual(get_resp.json["tags"], ACTION_1["tags"])
self.__do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_action_with_notify_update(self):
post_resp = self.__do_post(ACTION_WITH_NOTIFY)
action_id = self.__get_action_id(post_resp)
get_resp = self.__do_get_one(action_id)
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self.__get_action_id(get_resp), action_id)
- self.assertIsNotNone(get_resp.json['notify']['on-complete'])
+ self.assertIsNotNone(get_resp.json["notify"]["on-complete"])
# Now post the same action with no notify
ACTION_WITHOUT_NOTIFY = copy.copy(ACTION_WITH_NOTIFY)
- del ACTION_WITHOUT_NOTIFY['notify']
+ del ACTION_WITHOUT_NOTIFY["notify"]
self.__do_put(action_id, ACTION_WITHOUT_NOTIFY)
# Validate that notify section has vanished
get_resp = self.__do_get_one(action_id)
- self.assertEqual(get_resp.json['notify'], {})
+ self.assertEqual(get_resp.json["notify"], {})
self.__do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one_using_name_parameter(self):
action_id, action_name = self.__get_action_id_and_additional_attribute(
- self.__do_post(ACTION_1), 'name')
- get_resp = self.__do_get_actions_by_url_parameter('name', action_name)
+ self.__do_post(ACTION_1), "name"
+ )
+ get_resp = self.__do_get_actions_by_url_parameter("name", action_name)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json[0]['id'], action_id)
- self.assertEqual(get_resp.json[0]['name'], action_name)
+ self.assertEqual(get_resp.json[0]["id"], action_id)
+ self.assertEqual(get_resp.json[0]["name"], action_name)
self.__do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one_using_pack_parameter(self):
action_id, action_pack = self.__get_action_id_and_additional_attribute(
- self.__do_post(ACTION_10), 'pack')
- get_resp = self.__do_get_actions_by_url_parameter('pack', action_pack)
+ self.__do_post(ACTION_10), "pack"
+ )
+ get_resp = self.__do_get_actions_by_url_parameter("pack", action_pack)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json[0]['id'], action_id)
- self.assertEqual(get_resp.json[0]['pack'], action_pack)
+ self.assertEqual(get_resp.json[0]["id"], action_id)
+ self.assertEqual(get_resp.json[0]["pack"], action_pack)
self.__do_delete(action_id)
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def test_get_one_using_tag_parameter(self):
action_id, action_tags = self.__get_action_id_and_additional_attribute(
- self.__do_post(ACTION_1), 'tags')
- get_resp = self.__do_get_actions_by_url_parameter('tags', action_tags[0]['name'])
+ self.__do_post(ACTION_1), "tags"
+ )
+ get_resp = self.__do_get_actions_by_url_parameter(
+ "tags", action_tags[0]["name"]
+ )
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json[0]['id'], action_id)
- self.assertEqual(get_resp.json[0]['tags'], action_tags)
+ self.assertEqual(get_resp.json[0]["id"], action_id)
+ self.assertEqual(get_resp.json[0]["tags"], action_tags)
self.__do_delete(action_id)
# TODO: Re-enable those tests after we ensure DB is flushed in setUp
# and each test starts in a clean state
- @unittest2.skip('Skip because of test polution')
+ @unittest2.skip("Skip because of test polution")
def test_update_action_belonging_to_system_pack(self):
post_resp = self.__do_post(ACTION_11)
action_id = self.__get_action_id(post_resp)
put_resp = self.__do_put(action_id, ACTION_11, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- @unittest2.skip('Skip because of test polution')
+ @unittest2.skip("Skip because of test polution")
def test_delete_action_belonging_to_system_pack(self):
post_resp = self.__do_post(ACTION_11)
action_id = self.__get_action_id(post_resp)
@@ -664,31 +709,37 @@ def _do_delete(self, action_id, expect_errors=False):
@staticmethod
def __get_action_id(resp):
- return resp.json['id']
+ return resp.json["id"]
@staticmethod
def __get_action_name(resp):
- return resp.json['name']
+ return resp.json["name"]
@staticmethod
def __get_action_tags(resp):
- return resp.json['tags']
+ return resp.json["tags"]
@staticmethod
def __get_action_id_and_additional_attribute(resp, attribute):
- return resp.json['id'], resp.json[attribute]
+ return resp.json["id"], resp.json[attribute]
def __do_get_one(self, action_id, expect_errors=False):
- return self.app.get('/v1/actions/%s' % action_id, expect_errors=expect_errors)
+ return self.app.get("/v1/actions/%s" % action_id, expect_errors=expect_errors)
def __do_get_actions_by_url_parameter(self, filter, value, expect_errors=False):
- return self.app.get('/v1/actions?%s=%s' % (filter, value), expect_errors=expect_errors)
+ return self.app.get(
+ "/v1/actions?%s=%s" % (filter, value), expect_errors=expect_errors
+ )
def __do_post(self, action, expect_errors=False):
- return self.app.post_json('/v1/actions', action, expect_errors=expect_errors)
+ return self.app.post_json("/v1/actions", action, expect_errors=expect_errors)
def __do_put(self, action_id, action, expect_errors=False):
- return self.app.put_json('/v1/actions/%s' % action_id, action, expect_errors=expect_errors)
+ return self.app.put_json(
+ "/v1/actions/%s" % action_id, action, expect_errors=expect_errors
+ )
def __do_delete(self, action_id, expect_errors=False):
- return self.app.delete('/v1/actions/%s' % action_id, expect_errors=expect_errors)
+ return self.app.delete(
+ "/v1/actions/%s" % action_id, expect_errors=expect_errors
+ )
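The rewrites in the file above are mechanical: black normalizes quotes to double and re-wraps any statement that exceeds its default 88-character line limit (assuming this project keeps the default), first trying a single indented argument line and falling back to one argument per line. A minimal sketch of that behavior against black's own API, using an illustrative input rather than a line from this commit:

```python
# Sketch only: feed black a long call like the ones rewritten above and
# inspect the result. black.format_str() is black's public string-level
# entry point; FileMode() carries the default 88-character line length.
import black

src = (
    "resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, "
    "expect_errors=True)\n"
)

print(black.format_str(src, mode=black.FileMode()))
# resp = self.app.post_json(
#     "/v1/aliasexecution/match_and_execute", data, expect_errors=True
# )
```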
diff --git a/st2api/tests/unit/controllers/v1/test_alias_execution.py b/st2api/tests/unit/controllers/v1/test_alias_execution.py
index e7b1827f31..9806a46864 100644
--- a/st2api/tests/unit/controllers/v1/test_alias_execution.py
+++ b/st2api/tests/unit/controllers/v1/test_alias_execution.py
@@ -24,29 +24,32 @@
from st2tests.fixturesloader import FixturesLoader
from st2tests.api import FunctionalTest
-FIXTURES_PACK = 'aliases'
+FIXTURES_PACK = "aliases"
TEST_MODELS = {
- 'aliases': ['alias1.yaml', 'alias2.yaml', 'alias_with_undefined_jinja_in_ack_format.yaml',
- 'alias_with_immutable_list_param.yaml',
- 'alias_with_immutable_list_param_str_cast.yaml',
- 'alias4.yaml', 'alias5.yaml', 'alias_fixes1.yaml', 'alias_fixes2.yaml',
- 'alias_match_multiple.yaml'],
- 'actions': ['action1.yaml', 'action2.yaml', 'action3.yaml', 'action4.yaml'],
- 'runners': ['runner1.yaml']
+ "aliases": [
+ "alias1.yaml",
+ "alias2.yaml",
+ "alias_with_undefined_jinja_in_ack_format.yaml",
+ "alias_with_immutable_list_param.yaml",
+ "alias_with_immutable_list_param_str_cast.yaml",
+ "alias4.yaml",
+ "alias5.yaml",
+ "alias_fixes1.yaml",
+ "alias_fixes2.yaml",
+ "alias_match_multiple.yaml",
+ ],
+ "actions": ["action1.yaml", "action2.yaml", "action3.yaml", "action4.yaml"],
+ "runners": ["runner1.yaml"],
}
-TEST_LOAD_MODELS = {
- 'aliases': ['alias3.yaml']
-}
+TEST_LOAD_MODELS = {"aliases": ["alias3.yaml"]}
-EXECUTION = ActionExecutionDB(id='54e657d60640fd16887d6855',
- status=LIVEACTION_STATUS_SUCCEEDED,
- result='')
+EXECUTION = ActionExecutionDB(
+ id="54e657d60640fd16887d6855", status=LIVEACTION_STATUS_SUCCEEDED, result=""
+)
-__all__ = [
- 'AliasExecutionTestCase'
-]
+__all__ = ["AliasExecutionTestCase"]
class AliasExecutionTestCase(FunctionalTest):
@@ -59,193 +62,217 @@ class AliasExecutionTestCase(FunctionalTest):
@classmethod
def setUpClass(cls):
super(AliasExecutionTestCase, cls).setUpClass()
- cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
-
- cls.runner1 = cls.models['runners']['runner1.yaml']
- cls.action1 = cls.models['actions']['action1.yaml']
- cls.alias1 = cls.models['aliases']['alias1.yaml']
- cls.alias2 = cls.models['aliases']['alias2.yaml']
- cls.alias4 = cls.models['aliases']['alias4.yaml']
- cls.alias5 = cls.models['aliases']['alias5.yaml']
- cls.alias_with_undefined_jinja_in_ack_format = \
- cls.models['aliases']['alias_with_undefined_jinja_in_ack_format.yaml']
-
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ cls.models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+ )
+
+ cls.runner1 = cls.models["runners"]["runner1.yaml"]
+ cls.action1 = cls.models["actions"]["action1.yaml"]
+ cls.alias1 = cls.models["aliases"]["alias1.yaml"]
+ cls.alias2 = cls.models["aliases"]["alias2.yaml"]
+ cls.alias4 = cls.models["aliases"]["alias4.yaml"]
+ cls.alias5 = cls.models["aliases"]["alias5.yaml"]
+ cls.alias_with_undefined_jinja_in_ack_format = cls.models["aliases"][
+ "alias_with_undefined_jinja_in_ack_format.yaml"
+ ]
+
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_basic_execution(self, request):
command = 'Lorem ipsum value1 dolor sit "value2 value3" amet.'
post_resp = self._do_post(alias_execution=self.alias1, command=command)
self.assertEqual(post_resp.status_int, 201)
- expected_parameters = {'param1': 'value1', 'param2': 'value2 value3'}
+ expected_parameters = {"param1": "value1", "param2": "value2 value3"}
self.assertEqual(request.call_args[0][0].parameters, expected_parameters)
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_basic_execution_with_immutable_parameters(self, request):
- command = 'lorem ipsum'
+ command = "lorem ipsum"
post_resp = self._do_post(alias_execution=self.alias5, command=command)
self.assertEqual(post_resp.status_int, 201)
- expected_parameters = {'param1': 'value1', 'param2': 'value2'}
+ expected_parameters = {"param1": "value1", "param2": "value2"}
self.assertEqual(request.call_args[0][0].parameters, expected_parameters)
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_invalid_format_string_referenced_in_request(self, request):
command = 'Lorem ipsum value1 dolor sit "value2 value3" amet.'
- format_str = 'some invalid not supported string'
- post_resp = self._do_post(alias_execution=self.alias1, command=command,
- format_str=format_str, expect_errors=True)
+ format_str = "some invalid not supported string"
+ post_resp = self._do_post(
+ alias_execution=self.alias1,
+ command=command,
+ format_str=format_str,
+ expect_errors=True,
+ )
self.assertEqual(post_resp.status_int, 400)
- expected_msg = ('Format string "some invalid not supported string" is '
- 'not available on the alias "alias1"')
- self.assertIn(expected_msg, post_resp.json['faultstring'])
+ expected_msg = (
+ 'Format string "some invalid not supported string" is '
+ 'not available on the alias "alias1"'
+ )
+ self.assertIn(expected_msg, post_resp.json["faultstring"])
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_execution_with_array_type_single_value(self, request):
- command = 'Lorem ipsum value1 dolor sit value2 amet.'
+ command = "Lorem ipsum value1 dolor sit value2 amet."
self._do_post(alias_execution=self.alias2, command=command)
- expected_parameters = {'param1': 'value1', 'param3': ['value2']}
+ expected_parameters = {"param1": "value1", "param3": ["value2"]}
self.assertEqual(request.call_args[0][0].parameters, expected_parameters)
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_execution_with_array_type_multi_value(self, request):
command = 'Lorem ipsum value1 dolor sit "value2, value3" amet.'
post_resp = self._do_post(alias_execution=self.alias2, command=command)
self.assertEqual(post_resp.status_int, 201)
- expected_parameters = {'param1': 'value1', 'param3': ['value2', 'value3']}
+ expected_parameters = {"param1": "value1", "param3": ["value2", "value3"]}
self.assertEqual(request.call_args[0][0].parameters, expected_parameters)
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_invalid_jinja_var_in_ack_format(self, request):
- command = 'run date on localhost'
+ command = "run date on localhost"
# print(self.alias_with_undefined_jinja_in_ack_format)
post_resp = self._do_post(
alias_execution=self.alias_with_undefined_jinja_in_ack_format,
command=command,
- expect_errors=False
+ expect_errors=False,
)
self.assertEqual(post_resp.status_int, 201)
- expected_parameters = {'cmd': 'date', 'hosts': 'localhost'}
+ expected_parameters = {"cmd": "date", "hosts": "localhost"}
self.assertEqual(request.call_args[0][0].parameters, expected_parameters)
self.assertEqual(
- post_resp.json['message'],
- 'Cannot render "format" in field "ack" for alias. \'cmd\' is undefined'
+ post_resp.json["message"],
+ 'Cannot render "format" in field "ack" for alias. \'cmd\' is undefined',
)
- @mock.patch.object(action_service, 'request')
+ @mock.patch.object(action_service, "request")
def test_execution_secret_parameter(self, request):
- execution = ActionExecutionDB(id='54e657d60640fd16887d6855',
- status=LIVEACTION_STATUS_SUCCEEDED,
- action={'parameters': self.action1.parameters},
- runner={'runner_parameters': self.runner1.runner_parameters},
- parameters={
- 'param4': SUPER_SECRET_PARAMETER
- },
- result='')
+ execution = ActionExecutionDB(
+ id="54e657d60640fd16887d6855",
+ status=LIVEACTION_STATUS_SUCCEEDED,
+ action={"parameters": self.action1.parameters},
+ runner={"runner_parameters": self.runner1.runner_parameters},
+ parameters={"param4": SUPER_SECRET_PARAMETER},
+ result="",
+ )
request.return_value = (None, execution)
- command = 'Lorem ipsum value1 dolor sit ' + SUPER_SECRET_PARAMETER + ' amet.'
+ command = "Lorem ipsum value1 dolor sit " + SUPER_SECRET_PARAMETER + " amet."
post_resp = self._do_post(alias_execution=self.alias4, command=command)
self.assertEqual(post_resp.status_int, 201)
- expected_parameters = {'param1': 'value1', 'param4': SUPER_SECRET_PARAMETER}
+ expected_parameters = {"param1": "value1", "param4": SUPER_SECRET_PARAMETER}
self.assertEqual(request.call_args[0][0].parameters, expected_parameters)
- post_resp = self._do_post(alias_execution=self.alias4, command=command, show_secrets=True,
- expect_errors=True)
+ post_resp = self._do_post(
+ alias_execution=self.alias4,
+ command=command,
+ show_secrets=True,
+ expect_errors=True,
+ )
self.assertEqual(post_resp.status_int, 201)
- self.assertEqual(post_resp.json['execution']['parameters']['param4'],
- SUPER_SECRET_PARAMETER)
+ self.assertEqual(
+ post_resp.json["execution"]["parameters"]["param4"], SUPER_SECRET_PARAMETER
+ )
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_match_and_execute_doesnt_match(self, mock_request):
base_data = {
- 'source_channel': 'chat',
- 'notification_route': 'hubot',
- 'user': 'chat-user'
+ "source_channel": "chat",
+ "notification_route": "hubot",
+ "user": "chat-user",
}
        # Command doesn't match any patterns
data = copy.deepcopy(base_data)
- data['command'] = 'hello donny'
- resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True)
+ data["command"] = "hello donny"
+ resp = self.app.post_json(
+ "/v1/aliasexecution/match_and_execute", data, expect_errors=True
+ )
self.assertEqual(resp.status_int, 400)
- self.assertEqual(str(resp.json['faultstring']),
- "Command 'hello donny' matched no patterns")
+ self.assertEqual(
+ str(resp.json["faultstring"]), "Command 'hello donny' matched no patterns"
+ )
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_match_and_execute_matches_many(self, mock_request):
base_data = {
- 'source_channel': 'chat',
- 'notification_route': 'hubot',
- 'user': 'chat-user'
+ "source_channel": "chat",
+ "notification_route": "hubot",
+ "user": "chat-user",
}
# Command matches more than one pattern
data = copy.deepcopy(base_data)
- data['command'] = 'Lorem ipsum banana dolor sit pineapple amet.'
- resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True)
+ data["command"] = "Lorem ipsum banana dolor sit pineapple amet."
+ resp = self.app.post_json(
+ "/v1/aliasexecution/match_and_execute", data, expect_errors=True
+ )
self.assertEqual(resp.status_int, 400)
- self.assertEqual(str(resp.json['faultstring']),
- "Command 'Lorem ipsum banana dolor sit pineapple amet.' "
- "matched more than 1 pattern")
+ self.assertEqual(
+ str(resp.json["faultstring"]),
+ "Command 'Lorem ipsum banana dolor sit pineapple amet.' "
+ "matched more than 1 pattern",
+ )
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_match_and_execute_matches_one(self, mock_request):
base_data = {
- 'source_channel': 'chat-channel',
- 'notification_route': 'hubot',
- 'user': 'chat-user',
+ "source_channel": "chat-channel",
+ "notification_route": "hubot",
+ "user": "chat-user",
}
# Command matches - should result in action execution
data = copy.deepcopy(base_data)
- data['command'] = 'run date on localhost'
- resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data)
+ data["command"] = "run date on localhost"
+ resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data)
self.assertEqual(resp.status_int, 201)
- self.assertEqual(len(resp.json['results']), 1)
- self.assertEqual(resp.json['results'][0]['execution']['id'], str(EXECUTION['id']))
- self.assertEqual(resp.json['results'][0]['execution']['status'], EXECUTION['status'])
+ self.assertEqual(len(resp.json["results"]), 1)
+ self.assertEqual(
+ resp.json["results"][0]["execution"]["id"], str(EXECUTION["id"])
+ )
+ self.assertEqual(
+ resp.json["results"][0]["execution"]["status"], EXECUTION["status"]
+ )
- expected_parameters = {'cmd': 'date', 'hosts': 'localhost'}
+ expected_parameters = {"cmd": "date", "hosts": "localhost"}
self.assertEqual(mock_request.call_args[0][0].parameters, expected_parameters)
# Also check for source_channel - see
# https://github.com/StackStorm/st2/issues/4650
actual_context = mock_request.call_args[0][0].context
- self.assertIn('source_channel', mock_request.call_args[0][0].context.keys())
- self.assertEqual(actual_context['source_channel'], 'chat-channel')
- self.assertEqual(actual_context['api_user'], 'chat-user')
- self.assertEqual(actual_context['user'], 'stanley')
+ self.assertIn("source_channel", mock_request.call_args[0][0].context.keys())
+ self.assertEqual(actual_context["source_channel"], "chat-channel")
+ self.assertEqual(actual_context["api_user"], "chat-user")
+ self.assertEqual(actual_context["user"], "stanley")
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_match_and_execute_matches_one_multiple_match(self, mock_request):
base_data = {
- 'source_channel': 'chat',
- 'notification_route': 'hubot',
- 'user': 'chat-user'
+ "source_channel": "chat",
+ "notification_route": "hubot",
+ "user": "chat-user",
}
        # Command matches multiple times - should result in multiple action executions
data = copy.deepcopy(base_data)
- data['command'] = ('JKROWLING-4 is a duplicate of JRRTOLKIEN-24 which '
- 'is a duplicate of DRSEUSS-12')
- resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data)
+ data["command"] = (
+ "JKROWLING-4 is a duplicate of JRRTOLKIEN-24 which "
+ "is a duplicate of DRSEUSS-12"
+ )
+ resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data)
self.assertEqual(resp.status_int, 201)
- self.assertEqual(len(resp.json['results']), 2)
- self.assertEqual(resp.json['results'][0]['execution']['id'], str(EXECUTION['id']))
- self.assertEqual(resp.json['results'][0]['execution']['status'], EXECUTION['status'])
- self.assertEqual(resp.json['results'][1]['execution']['id'], str(EXECUTION['id']))
- self.assertEqual(resp.json['results'][1]['execution']['status'], EXECUTION['status'])
+ self.assertEqual(len(resp.json["results"]), 2)
+ self.assertEqual(
+ resp.json["results"][0]["execution"]["id"], str(EXECUTION["id"])
+ )
+ self.assertEqual(
+ resp.json["results"][0]["execution"]["status"], EXECUTION["status"]
+ )
+ self.assertEqual(
+ resp.json["results"][1]["execution"]["id"], str(EXECUTION["id"])
+ )
+ self.assertEqual(
+ resp.json["results"][1]["execution"]["status"], EXECUTION["status"]
+ )
# The mock object only stores the parameters of the _last_ time it was called, so that's
# what we assert on. Luckily re.finditer() processes groups in order, so if this was the
@@ -255,34 +282,39 @@ def test_match_and_execute_matches_one_multiple_match(self, mock_request):
#
# We've also already checked the results array
#
- expected_parameters = {'issue_key': 'DRSEUSS-12'}
+ expected_parameters = {"issue_key": "DRSEUSS-12"}
self.assertEqual(mock_request.call_args[0][0].parameters, expected_parameters)
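The comment above leans on the fact that a `Mock` keeps only its most recent invocation in `call_args`; the full, ordered history lives in `call_args_list`. A minimal sketch of that behaviour with hypothetical values:

```python
# Sketch only: Mock.call_args reflects the last call; call_args_list keeps all.
from unittest import mock

request = mock.Mock()
request(issue_key="JKROWLING-4")
request(issue_key="JRRTOLKIEN-24")
request(issue_key="DRSEUSS-12")

assert request.call_args == mock.call(issue_key="DRSEUSS-12")  # last call only
assert len(request.call_args_list) == 3  # complete, ordered history
```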
- @mock.patch.object(action_service, 'request',
- return_value=(None, EXECUTION))
+ @mock.patch.object(action_service, "request", return_value=(None, EXECUTION))
def test_match_and_execute_matches_many_multiple_match(self, mock_request):
base_data = {
- 'source_channel': 'chat',
- 'notification_route': 'hubot',
- 'user': 'chat-user'
+ "source_channel": "chat",
+ "notification_route": "hubot",
+ "user": "chat-user",
}
        # Command matches multiple times - should result in multiple action executions
data = copy.deepcopy(base_data)
- data['command'] = 'JKROWLING-4 fixes JRRTOLKIEN-24 which fixes DRSEUSS-12'
- resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True)
+ data["command"] = "JKROWLING-4 fixes JRRTOLKIEN-24 which fixes DRSEUSS-12"
+ resp = self.app.post_json(
+ "/v1/aliasexecution/match_and_execute", data, expect_errors=True
+ )
self.assertEqual(resp.status_int, 400)
- self.assertEqual(str(resp.json['faultstring']),
- "Command '{command}' "
- "matched more than 1 (multi) pattern".format(command=data['command']))
+ self.assertEqual(
+ str(resp.json["faultstring"]),
+ "Command '{command}' "
+ "matched more than 1 (multi) pattern".format(command=data["command"]),
+ )
def test_match_and_execute_list_action_param_str_cast_to_list(self):
data = {
- 'command': 'test alias list param str cast',
- 'source_channel': 'hubot',
- 'user': 'foo',
+ "command": "test alias list param str cast",
+ "source_channel": "hubot",
+ "user": "foo",
}
- resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data, expect_errors=True)
+ resp = self.app.post_json(
+ "/v1/aliasexecution/match_and_execute", data, expect_errors=True
+ )
# Param is a comma delimited string - our custom cast function should cast it to a list.
# I assume that was done to make specifying complex params in chat easier.
@@ -300,15 +332,19 @@ def test_match_and_execute_list_action_param_str_cast_to_list(self):
self.assertEqual(live_action["parameters"]["array_param"][1], "two")
self.assertEqual(live_action["parameters"]["array_param"][2], "three")
self.assertEqual(live_action["parameters"]["array_param"][3], "four")
- self.assertTrue(isinstance(action_alias["immutable_parameters"]["array_param"], str))
+ self.assertTrue(
+ isinstance(action_alias["immutable_parameters"]["array_param"], str)
+ )
def test_match_and_execute_list_action_param_already_a_list(self):
data = {
- 'command': 'test alias foo',
- 'source_channel': 'hubot',
- 'user': 'foo',
+ "command": "test alias foo",
+ "source_channel": "hubot",
+ "user": "foo",
}
- resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data, expect_errors=True)
+ resp = self.app.post_json(
+ "/v1/aliasexecution/match_and_execute", data, expect_errors=True
+ )
# immutable_param is already a list - verify no casting is performed
self.assertEqual(resp.status_int, 201)
@@ -323,37 +359,53 @@ def test_match_and_execute_list_action_param_already_a_list(self):
self.assertEqual(live_action["parameters"]["array_param"][0]["key2"], "two")
self.assertEqual(live_action["parameters"]["array_param"][1]["key3"], "three")
self.assertEqual(live_action["parameters"]["array_param"][1]["key4"], "four")
- self.assertTrue(isinstance(action_alias["immutable_parameters"]["array_param"], list))
+ self.assertTrue(
+ isinstance(action_alias["immutable_parameters"]["array_param"], list)
+ )
def test_match_and_execute_success(self):
data = {
- 'command': 'run whoami on localhost1',
- 'source_channel': 'hubot',
- 'user': "user",
+ "command": "run whoami on localhost1",
+ "source_channel": "hubot",
+ "user": "user",
}
resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data)
self.assertEqual(resp.status_int, 201)
self.assertEqual(len(resp.json["results"]), 1)
- self.assertTrue(resp.json["results"][0]["actionalias"]["ref"],
- "aliases.alias_with_undefined_jinja_in_ack_format")
-
- def _do_post(self, alias_execution, command, format_str=None, expect_errors=False,
- show_secrets=False):
- if (isinstance(alias_execution.formats[0], dict) and
- alias_execution.formats[0].get('representation')):
- representation = alias_execution.formats[0].get('representation')[0]
+ self.assertTrue(
+ resp.json["results"][0]["actionalias"]["ref"],
+ "aliases.alias_with_undefined_jinja_in_ack_format",
+ )
+
+ def _do_post(
+ self,
+ alias_execution,
+ command,
+ format_str=None,
+ expect_errors=False,
+ show_secrets=False,
+ ):
+ if isinstance(alias_execution.formats[0], dict) and alias_execution.formats[
+ 0
+ ].get("representation"):
+ representation = alias_execution.formats[0].get("representation")[0]
else:
representation = alias_execution.formats[0]
if not format_str:
format_str = representation
- execution = {'name': alias_execution.name,
- 'format': format_str,
- 'command': command,
- 'user': 'stanley',
- 'source_channel': 'test',
- 'notification_route': 'test'}
- url = show_secrets and '/v1/aliasexecution?show_secrets=true' or '/v1/aliasexecution'
- return self.app.post_json(url, execution,
- expect_errors=expect_errors)
+ execution = {
+ "name": alias_execution.name,
+ "format": format_str,
+ "command": command,
+ "user": "stanley",
+ "source_channel": "test",
+ "notification_route": "test",
+ }
+ url = (
+ show_secrets
+ and "/v1/aliasexecution?show_secrets=true"
+ or "/v1/aliasexecution"
+ )
+ return self.app.post_json(url, execution, expect_errors=expect_errors)
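The `show_secrets and ... or ...` expression that black re-wrapped at the end of `_do_post` is the pre-Python-2.5 stand-in for a conditional expression; it only works here because both URL strings are truthy. A short sketch of the equivalence and the pitfall, with hypothetical values:

```python
# Sketch only: `cond and a or b` equals `a if cond else b` only while `a` is truthy.
show_secrets = True

url = (
    show_secrets
    and "/v1/aliasexecution?show_secrets=true"
    or "/v1/aliasexecution"
)
assert url == "/v1/aliasexecution?show_secrets=true"

# The modern form does not depend on the truthiness of the first branch:
url = "/v1/aliasexecution?show_secrets=true" if show_secrets else "/v1/aliasexecution"

# Pitfall: a falsy middle operand silently falls through to the alternative.
assert (True and "" or "fallback") == "fallback"
```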
diff --git a/st2api/tests/unit/controllers/v1/test_auth.py b/st2api/tests/unit/controllers/v1/test_auth.py
index fb5a203929..d6f3602c3c 100644
--- a/st2api/tests/unit/controllers/v1/test_auth.py
+++ b/st2api/tests/unit/controllers/v1/test_auth.py
@@ -27,7 +27,7 @@
from st2tests.fixturesloader import FixturesLoader
OBJ_ID = bson.ObjectId()
-USER = 'stanley'
+USER = "stanley"
USER_DB = UserDB(name=USER)
TOKEN = uuid.uuid4().hex
NOW = date_utils.get_datetime_utc_now()
@@ -40,67 +40,84 @@ class TestTokenBasedAuth(FunctionalTest):
enable_auth = True
@mock.patch.object(
- Token, 'get',
- mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE)))
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB))
+ Token,
+ "get",
+ mock.Mock(
+ return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE)
+ ),
+ )
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB))
def test_token_validation_token_in_headers(self):
- response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN},
- expect_errors=False)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=False
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 200)
@mock.patch.object(
- Token, 'get',
- mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE)))
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB))
+ Token,
+ "get",
+ mock.Mock(
+ return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE)
+ ),
+ )
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB))
def test_token_validation_token_in_query_params(self):
- response = self.app.get('/v1/actions?x-auth-token=%s' % (TOKEN), expect_errors=False)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions?x-auth-token=%s" % (TOKEN), expect_errors=False
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 200)
@mock.patch.object(
- Token, 'get',
- mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE)))
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB))
+ Token,
+ "get",
+ mock.Mock(
+ return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE)
+ ),
+ )
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB))
def test_token_validation_token_in_cookies(self):
- response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN},
- expect_errors=False)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=False
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 200)
- with mock.patch.object(self.app.cookiejar, 'clear', return_value=None):
- response = self.app.get('/v1/actions', expect_errors=False)
- self.assertIn('application/json', response.headers['content-type'])
+ with mock.patch.object(self.app.cookiejar, "clear", return_value=None):
+ response = self.app.get("/v1/actions", expect_errors=False)
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 200)
@mock.patch.object(
- Token, 'get',
- mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=PAST)))
+ Token,
+ "get",
+ mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=PAST)),
+ )
def test_token_expired(self):
- response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN},
- expect_errors=True)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=True
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 401)
- @mock.patch.object(
- Token, 'get', mock.MagicMock(side_effect=TokenNotFoundError()))
+ @mock.patch.object(Token, "get", mock.MagicMock(side_effect=TokenNotFoundError()))
def test_token_not_found(self):
- response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN},
- expect_errors=True)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=True
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 401)
def test_token_not_provided(self):
- response = self.app.get('/v1/actions', expect_errors=True)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get("/v1/actions", expect_errors=True)
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 401)
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
-TEST_MODELS = {
- 'apikeys': ['apikey1.yaml', 'apikey_disabled.yaml']
-}
+TEST_MODELS = {"apikeys": ["apikey1.yaml", "apikey_disabled.yaml"]}
# Hardcoded keys matching the fixtures. Lazy way to work around the one-way hash and still use fixtures.
KEY1_KEY = "1234"
@@ -117,62 +134,83 @@ class TestApiKeyBasedAuth(FunctionalTest):
@classmethod
def setUpClass(cls):
super(TestApiKeyBasedAuth, cls).setUpClass()
- models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
- cls.apikey1 = models['apikeys']['apikey1.yaml']
- cls.apikey_disabled = models['apikeys']['apikey_disabled.yaml']
+ models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+ )
+ cls.apikey1 = models["apikeys"]["apikey1.yaml"]
+ cls.apikey_disabled = models["apikeys"]["apikey_disabled.yaml"]
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill')))
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill")))
def test_apikey_validation_apikey_in_headers(self):
- response = self.app.get('/v1/actions', headers={'St2-Api-key': KEY1_KEY},
- expect_errors=False)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions", headers={"St2-Api-key": KEY1_KEY}, expect_errors=False
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 200)
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill')))
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill")))
def test_apikey_validation_apikey_in_query_params(self):
- response = self.app.get('/v1/actions?st2-api-key=%s' % (KEY1_KEY), expect_errors=False)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions?st2-api-key=%s" % (KEY1_KEY), expect_errors=False
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 200)
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill')))
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill")))
def test_apikey_validation_apikey_in_cookies(self):
- response = self.app.get('/v1/actions', headers={'St2-Api-key': KEY1_KEY},
- expect_errors=False)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions", headers={"St2-Api-key": KEY1_KEY}, expect_errors=False
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 200)
- with mock.patch.object(self.app.cookiejar, 'clear', return_value=None):
- response = self.app.get('/v1/actions', expect_errors=True)
+ with mock.patch.object(self.app.cookiejar, "clear", return_value=None):
+ response = self.app.get("/v1/actions", expect_errors=True)
self.assertEqual(response.status_int, 401)
- self.assertEqual(response.json_body['faultstring'],
- 'Unauthorized - One of Token or API key required.')
+ self.assertEqual(
+ response.json_body["faultstring"],
+ "Unauthorized - One of Token or API key required.",
+ )
def test_apikey_disabled(self):
- response = self.app.get('/v1/actions', headers={'St2-Api-key': DISABLED_KEY},
- expect_errors=True)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions", headers={"St2-Api-key": DISABLED_KEY}, expect_errors=True
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 401)
- self.assertEqual(response.json_body['faultstring'], 'Unauthorized - API key is disabled.')
+ self.assertEqual(
+ response.json_body["faultstring"], "Unauthorized - API key is disabled."
+ )
def test_apikey_not_found(self):
- response = self.app.get('/v1/actions', headers={'St2-Api-key': 'UNKNOWN'},
- expect_errors=True)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions", headers={"St2-Api-key": "UNKNOWN"}, expect_errors=True
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 401)
- self.assertRegexpMatches(response.json_body['faultstring'],
- '^Unauthorized - ApiKey with key_hash=([a-zA-Z0-9]+) not found.$')
+ self.assertRegexpMatches(
+ response.json_body["faultstring"],
+ "^Unauthorized - ApiKey with key_hash=([a-zA-Z0-9]+) not found.$",
+ )
@mock.patch.object(
- Token, 'get',
- mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE)))
+ Token,
+ "get",
+ mock.Mock(
+ return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE)
+ ),
+ )
@mock.patch.object(
- ApiKey, 'get',
- mock.Mock(return_value=ApiKeyDB(user=USER, key_hash=KEY1_KEY, enabled=True)))
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB))
+ ApiKey,
+ "get",
+ mock.Mock(return_value=ApiKeyDB(user=USER, key_hash=KEY1_KEY, enabled=True)),
+ )
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB))
def test_multiple_auth_sources(self):
- response = self.app.get('/v1/actions',
- headers={'X-Auth-Token': TOKEN, 'St2-Api-key': KEY1_KEY},
- expect_errors=True)
- self.assertIn('application/json', response.headers['content-type'])
+ response = self.app.get(
+ "/v1/actions",
+ headers={"X-Auth-Token": TOKEN, "St2-Api-key": KEY1_KEY},
+ expect_errors=True,
+ )
+ self.assertIn("application/json", response.headers["content-type"])
self.assertEqual(response.status_int, 200)
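Note the two patch styles sitting side by side in these auth tests: passing a prebuilt `mock.Mock(...)` as the replacement (as the `Token.get` patches above do) installs it without adding a parameter to the test method, while omitting the replacement (as the `action_service.request` patches in the alias tests do) creates a `MagicMock` and injects it into the test. A self-contained sketch with a hypothetical stand-in class:

```python
# Sketch only: the two mock.patch.object forms used in these tests.
import unittest
from unittest import mock


class Token:  # hypothetical stand-in for the real model class
    @staticmethod
    def get(value):
        raise RuntimeError("would hit the database")


class PatchStyles(unittest.TestCase):
    @mock.patch.object(Token, "get", mock.Mock(return_value="token-db"))
    def test_explicit_replacement(self):  # no extra parameter is injected
        self.assertEqual(Token.get("abc"), "token-db")

    @mock.patch.object(Token, "get", return_value="token-db")
    def test_injected_mock(self, get_mock):  # the created MagicMock is passed in
        self.assertEqual(Token.get("abc"), "token-db")
        get_mock.assert_called_once_with("abc")
```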
diff --git a/st2api/tests/unit/controllers/v1/test_auth_api_keys.py b/st2api/tests/unit/controllers/v1/test_auth_api_keys.py
index c172b22445..bf76d41276 100644
--- a/st2api/tests/unit/controllers/v1/test_auth_api_keys.py
+++ b/st2api/tests/unit/controllers/v1/test_auth_api_keys.py
@@ -22,11 +22,16 @@
from st2tests.fixturesloader import FixturesLoader
from st2tests.api import FunctionalTest
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
TEST_MODELS = {
- 'apikeys': ['apikey1.yaml', 'apikey2.yaml', 'apikey3.yaml', 'apikey_disabled.yaml',
- 'apikey_malformed.yaml']
+ "apikeys": [
+ "apikey1.yaml",
+ "apikey2.yaml",
+ "apikey3.yaml",
+ "apikey_disabled.yaml",
+ "apikey_malformed.yaml",
+ ]
}
# Hardcoded keys matching the fixtures. Lazy way to work around the one-way hash and still use fixtures.
@@ -45,205 +50,239 @@ class TestApiKeyController(FunctionalTest):
def setUpClass(cls):
super(TestApiKeyController, cls).setUpClass()
- cfg.CONF.set_override(name='mask_secrets', override=True, group='api')
- cfg.CONF.set_override(name='mask_secrets', override=True, group='log')
+ cfg.CONF.set_override(name="mask_secrets", override=True, group="api")
+ cfg.CONF.set_override(name="mask_secrets", override=True, group="log")
- models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
- cls.apikey1 = models['apikeys']['apikey1.yaml']
- cls.apikey2 = models['apikeys']['apikey2.yaml']
- cls.apikey3 = models['apikeys']['apikey3.yaml']
- cls.apikey4 = models['apikeys']['apikey_disabled.yaml']
- cls.apikey5 = models['apikeys']['apikey_malformed.yaml']
+ models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+ )
+ cls.apikey1 = models["apikeys"]["apikey1.yaml"]
+ cls.apikey2 = models["apikeys"]["apikey2.yaml"]
+ cls.apikey3 = models["apikeys"]["apikey3.yaml"]
+ cls.apikey4 = models["apikeys"]["apikey_disabled.yaml"]
+ cls.apikey5 = models["apikeys"]["apikey_malformed.yaml"]
def test_get_all_and_minus_one(self):
- resp = self.app.get('/v1/apikeys')
+ resp = self.app.get("/v1/apikeys")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.headers['X-Total-Count'], "5")
- self.assertEqual(resp.headers['X-Limit'], "50")
- self.assertEqual(len(resp.json), 5, '/v1/apikeys did not return all apikeys.')
-
- retrieved_ids = [apikey['id'] for apikey in resp.json]
- self.assertEqual(retrieved_ids,
- [str(self.apikey1.id), str(self.apikey2.id), str(self.apikey3.id),
- str(self.apikey4.id), str(self.apikey5.id)],
- 'Incorrect api keys retrieved.')
-
- resp = self.app.get('/v1/apikeys/?limit=-1')
+ self.assertEqual(resp.headers["X-Total-Count"], "5")
+ self.assertEqual(resp.headers["X-Limit"], "50")
+ self.assertEqual(len(resp.json), 5, "/v1/apikeys did not return all apikeys.")
+
+ retrieved_ids = [apikey["id"] for apikey in resp.json]
+ self.assertEqual(
+ retrieved_ids,
+ [
+ str(self.apikey1.id),
+ str(self.apikey2.id),
+ str(self.apikey3.id),
+ str(self.apikey4.id),
+ str(self.apikey5.id),
+ ],
+ "Incorrect api keys retrieved.",
+ )
+
+ resp = self.app.get("/v1/apikeys/?limit=-1")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.headers['X-Total-Count'], "5")
- self.assertEqual(len(resp.json), 5, '/v1/apikeys did not return all apikeys.')
+ self.assertEqual(resp.headers["X-Total-Count"], "5")
+ self.assertEqual(len(resp.json), 5, "/v1/apikeys did not return all apikeys.")
def test_get_all_with_pagnination_with_offset_and_limit(self):
- resp = self.app.get('/v1/apikeys?offset=2&limit=1')
+ resp = self.app.get("/v1/apikeys?offset=2&limit=1")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.headers['X-Total-Count'], "5")
- self.assertEqual(resp.headers['X-Limit'], "1")
+ self.assertEqual(resp.headers["X-Total-Count"], "5")
+ self.assertEqual(resp.headers["X-Limit"], "1")
self.assertEqual(len(resp.json), 1)
- retrieved_ids = [apikey['id'] for apikey in resp.json]
+ retrieved_ids = [apikey["id"] for apikey in resp.json]
self.assertEqual(retrieved_ids, [str(self.apikey3.id)])
def test_get_all_with_pagnination_with_only_offset(self):
- resp = self.app.get('/v1/apikeys?offset=3')
+ resp = self.app.get("/v1/apikeys?offset=3")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.headers['X-Total-Count'], "5")
- self.assertEqual(resp.headers['X-Limit'], "50")
+ self.assertEqual(resp.headers["X-Total-Count"], "5")
+ self.assertEqual(resp.headers["X-Limit"], "50")
self.assertEqual(len(resp.json), 2)
- retrieved_ids = [apikey['id'] for apikey in resp.json]
+ retrieved_ids = [apikey["id"] for apikey in resp.json]
self.assertEqual(retrieved_ids, [str(self.apikey4.id), str(self.apikey5.id)])
def test_get_all_with_pagnination_with_only_limit(self):
- resp = self.app.get('/v1/apikeys?limit=2')
+ resp = self.app.get("/v1/apikeys?limit=2")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.headers['X-Total-Count'], "5")
- self.assertEqual(resp.headers['X-Limit'], "2")
+ self.assertEqual(resp.headers["X-Total-Count"], "5")
+ self.assertEqual(resp.headers["X-Limit"], "2")
self.assertEqual(len(resp.json), 2)
- retrieved_ids = [apikey['id'] for apikey in resp.json]
+ retrieved_ids = [apikey["id"] for apikey in resp.json]
self.assertEqual(retrieved_ids, [str(self.apikey1.id), str(self.apikey2.id)])
- @mock.patch('st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin',
- mock.Mock(return_value=False))
+ @mock.patch(
+ "st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin",
+ mock.Mock(return_value=False),
+ )
def test_get_all_invalid_limit_too_large_none_admin(self):
# limit > max_page_size, but user is not admin
- resp = self.app.get('/v1/apikeys?offset=2&limit=1000', expect_errors=True)
+ resp = self.app.get("/v1/apikeys?offset=2&limit=1000", expect_errors=True)
self.assertEqual(resp.status_int, http_client.FORBIDDEN)
- self.assertEqual(resp.json['faultstring'],
- 'Limit "1000" specified, maximum value is "100"')
+ self.assertEqual(
+ resp.json["faultstring"], 'Limit "1000" specified, maximum value is "100"'
+ )
def test_get_all_invalid_limit_negative_integer(self):
- resp = self.app.get('/v1/apikeys?offset=2&limit=-22', expect_errors=True)
+ resp = self.app.get("/v1/apikeys?offset=2&limit=-22", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- 'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
def test_get_all_invalid_offset_too_large(self):
- offset = '2141564789454123457895412237483648'
- resp = self.app.get('/v1/apikeys?offset=%s&limit=1' % (offset), expect_errors=True)
+ offset = "2141564789454123457895412237483648"
+ resp = self.app.get(
+ "/v1/apikeys?offset=%s&limit=1" % (offset), expect_errors=True
+ )
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- 'Offset "%s" specified is more than 32 bit int' % (offset))
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Offset "%s" specified is more than 32 bit int' % (offset),
+ )
def test_get_one_by_id(self):
- resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id)
+ resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['id'], str(self.apikey1.id),
- 'Incorrect api key retrieved.')
- self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE,
- 'Key should be masked.')
+ self.assertEqual(
+ resp.json["id"], str(self.apikey1.id), "Incorrect api key retrieved."
+ )
+ self.assertEqual(
+ resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked."
+ )
def test_get_one_by_key(self):
# key1
- resp = self.app.get('/v1/apikeys/%s' % KEY1_KEY)
+ resp = self.app.get("/v1/apikeys/%s" % KEY1_KEY)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['id'], str(self.apikey1.id),
- 'Incorrect api key retrieved.')
- self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE,
- 'Key should be masked.')
+ self.assertEqual(
+ resp.json["id"], str(self.apikey1.id), "Incorrect api key retrieved."
+ )
+ self.assertEqual(
+ resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked."
+ )
# key2
- resp = self.app.get('/v1/apikeys/%s' % KEY2_KEY)
+ resp = self.app.get("/v1/apikeys/%s" % KEY2_KEY)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['id'], str(self.apikey2.id),
- 'Incorrect api key retrieved.')
- self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE,
- 'Key should be masked.')
+ self.assertEqual(
+ resp.json["id"], str(self.apikey2.id), "Incorrect api key retrieved."
+ )
+ self.assertEqual(
+ resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked."
+ )
# key3
- resp = self.app.get('/v1/apikeys/%s' % KEY3_KEY)
+ resp = self.app.get("/v1/apikeys/%s" % KEY3_KEY)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['id'], str(self.apikey3.id),
- 'Incorrect api key retrieved.')
- self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE,
- 'Key should be masked.')
+ self.assertEqual(
+ resp.json["id"], str(self.apikey3.id), "Incorrect api key retrieved."
+ )
+ self.assertEqual(
+ resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked."
+ )
def test_get_show_secrets(self):
- resp = self.app.get('/v1/apikeys?show_secrets=True', expect_errors=True)
+ resp = self.app.get("/v1/apikeys?show_secrets=True", expect_errors=True)
self.assertEqual(resp.status_int, 200)
for key in resp.json:
- self.assertNotEqual(key['key_hash'], MASKED_ATTRIBUTE_VALUE)
- self.assertNotEqual(key['uid'], MASKED_ATTRIBUTE_VALUE)
+ self.assertNotEqual(key["key_hash"], MASKED_ATTRIBUTE_VALUE)
+ self.assertNotEqual(key["uid"], MASKED_ATTRIBUTE_VALUE)
def test_post_delete_key(self):
- api_key = {
- 'user': 'herge'
- }
- resp1 = self.app.post_json('/v1/apikeys', api_key)
+ api_key = {"user": "herge"}
+ resp1 = self.app.post_json("/v1/apikeys", api_key)
self.assertEqual(resp1.status_int, 201)
- self.assertTrue(resp1.json['key'], 'Key should be non-None.')
- self.assertNotEqual(resp1.json['key'], MASKED_ATTRIBUTE_VALUE,
- 'Key should not be masked.')
+ self.assertTrue(resp1.json["key"], "Key should be non-None.")
+ self.assertNotEqual(
+ resp1.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked."
+ )
# should lead to creation of another key
- resp2 = self.app.post_json('/v1/apikeys', api_key)
+ resp2 = self.app.post_json("/v1/apikeys", api_key)
self.assertEqual(resp2.status_int, 201)
- self.assertTrue(resp2.json['key'], 'Key should be non-None.')
- self.assertNotEqual(resp2.json['key'], MASKED_ATTRIBUTE_VALUE, 'Key should not be masked.')
- self.assertNotEqual(resp1.json['key'], resp2.json['key'], 'Should be different')
+ self.assertTrue(resp2.json["key"], "Key should be non-None.")
+ self.assertNotEqual(
+ resp2.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked."
+ )
+ self.assertNotEqual(resp1.json["key"], resp2.json["key"], "Should be different")
- resp = self.app.delete('/v1/apikeys/%s' % resp1.json['id'])
+ resp = self.app.delete("/v1/apikeys/%s" % resp1.json["id"])
self.assertEqual(resp.status_int, 204)
- resp = self.app.delete('/v1/apikeys/%s' % resp2.json['key'])
+ resp = self.app.delete("/v1/apikeys/%s" % resp2.json["key"])
self.assertEqual(resp.status_int, 204)
# With auth disabled, use system_user
- resp3 = self.app.post_json('/v1/apikeys', {})
+ resp3 = self.app.post_json("/v1/apikeys", {})
self.assertEqual(resp3.status_int, 201)
- self.assertTrue(resp3.json['key'], 'Key should be non-None.')
- self.assertNotEqual(resp3.json['key'], MASKED_ATTRIBUTE_VALUE,
- 'Key should not be masked.')
- self.assertTrue(resp3.json['user'], cfg.CONF.system_user.user)
+ self.assertTrue(resp3.json["key"], "Key should be non-None.")
+ self.assertNotEqual(
+ resp3.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked."
+ )
+ self.assertTrue(resp3.json["user"], cfg.CONF.system_user.user)
def test_post_delete_same_key_hash(self):
api_key = {
- 'id': '5c5dbb576cb8de06a2d79a4d',
- 'user': 'herge',
- 'key_hash': 'ABCDE'
+ "id": "5c5dbb576cb8de06a2d79a4d",
+ "user": "herge",
+ "key_hash": "ABCDE",
}
- resp1 = self.app.post_json('/v1/apikeys', api_key)
+ resp1 = self.app.post_json("/v1/apikeys", api_key)
self.assertEqual(resp1.status_int, 201)
- self.assertEqual(resp1.json['key'], None, 'Key should be None.')
+ self.assertEqual(resp1.json["key"], None, "Key should be None.")
# drop into the DB since API will be masking this value.
- api_key_db = ApiKey.get_by_id(resp1.json['id'])
+ api_key_db = ApiKey.get_by_id(resp1.json["id"])
- self.assertEqual(resp1.json['id'], api_key['id'], 'PK ID of created API should match.')
- self.assertEqual(api_key_db.key_hash, api_key['key_hash'], 'Key_hash should match.')
- self.assertEqual(api_key_db.user, api_key['user'], 'User should match.')
+ self.assertEqual(
+ resp1.json["id"], api_key["id"], "PK ID of created API should match."
+ )
+ self.assertEqual(
+ api_key_db.key_hash, api_key["key_hash"], "Key_hash should match."
+ )
+ self.assertEqual(api_key_db.user, api_key["user"], "User should match.")
- resp = self.app.delete('/v1/apikeys/%s' % resp1.json['id'])
+ resp = self.app.delete("/v1/apikeys/%s" % resp1.json["id"])
self.assertEqual(resp.status_int, 204)
def test_put_api_key(self):
- resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id)
+ resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id)
self.assertEqual(resp.status_int, 200)
update_input = resp.json
- update_input['enabled'] = not update_input['enabled']
- put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input,
- expect_errors=True)
+ update_input["enabled"] = not update_input["enabled"]
+ put_resp = self.app.put_json(
+ "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['enabled'], not resp.json['enabled'])
+ self.assertEqual(put_resp.json["enabled"], not resp.json["enabled"])
update_input = put_resp.json
- update_input['enabled'] = not update_input['enabled']
- put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input,
- expect_errors=True)
+ update_input["enabled"] = not update_input["enabled"]
+ put_resp = self.app.put_json(
+ "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['enabled'], resp.json['enabled'])
+ self.assertEqual(put_resp.json["enabled"], resp.json["enabled"])
def test_put_api_key_fail(self):
- resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id)
+ resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id)
self.assertEqual(resp.status_int, 200)
update_input = resp.json
- update_input['key_hash'] = '1'
- put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input,
- expect_errors=True)
+ update_input["key_hash"] = "1"
+ put_resp = self.app.put_json(
+ "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True
+ )
self.assertEqual(put_resp.status_int, 400)
- self.assertTrue(put_resp.json['faultstring'])
+ self.assertTrue(put_resp.json["faultstring"])
def test_post_no_user_fail(self):
- self.app.post_json('/v1/apikeys', {}, expect_errors=True)
+ self.app.post_json("/v1/apikeys", {}, expect_errors=True)
diff --git a/st2api/tests/unit/controllers/v1/test_base.py b/st2api/tests/unit/controllers/v1/test_base.py
index fa8b4f1c92..cbfe3e54c2 100644
--- a/st2api/tests/unit/controllers/v1/test_base.py
+++ b/st2api/tests/unit/controllers/v1/test_base.py
@@ -19,77 +19,79 @@
class TestBase(FunctionalTest):
def test_defaults(self):
- response = self.app.get('/')
+ response = self.app.get("/")
self.assertEqual(response.status_int, 200)
- self.assertEqual(response.headers['Access-Control-Allow-Origin'],
- 'http://127.0.0.1:3000')
- self.assertEqual(response.headers['Access-Control-Allow-Methods'],
- 'GET,POST,PUT,DELETE,OPTIONS')
- self.assertEqual(response.headers['Access-Control-Allow-Headers'],
- 'Content-Type,Authorization,X-Auth-Token,St2-Api-Key,X-Request-ID')
- self.assertEqual(response.headers['Access-Control-Expose-Headers'],
- 'Content-Type,X-Limit,X-Total-Count,X-Request-ID')
+ self.assertEqual(
+ response.headers["Access-Control-Allow-Origin"], "http://127.0.0.1:3000"
+ )
+ self.assertEqual(
+ response.headers["Access-Control-Allow-Methods"],
+ "GET,POST,PUT,DELETE,OPTIONS",
+ )
+ self.assertEqual(
+ response.headers["Access-Control-Allow-Headers"],
+ "Content-Type,Authorization,X-Auth-Token,St2-Api-Key,X-Request-ID",
+ )
+ self.assertEqual(
+ response.headers["Access-Control-Expose-Headers"],
+ "Content-Type,X-Limit,X-Total-Count,X-Request-ID",
+ )
def test_origin(self):
- response = self.app.get('/', headers={
- 'origin': 'http://127.0.0.1:3000'
- })
+ response = self.app.get("/", headers={"origin": "http://127.0.0.1:3000"})
self.assertEqual(response.status_int, 200)
- self.assertEqual(response.headers['Access-Control-Allow-Origin'],
- 'http://127.0.0.1:3000')
+ self.assertEqual(
+ response.headers["Access-Control-Allow-Origin"], "http://127.0.0.1:3000"
+ )
def test_additional_origin(self):
- response = self.app.get('/', headers={
- 'origin': 'http://dev'
- })
+ response = self.app.get("/", headers={"origin": "http://dev"})
self.assertEqual(response.status_int, 200)
- self.assertEqual(response.headers['Access-Control-Allow-Origin'],
- 'http://dev')
+ self.assertEqual(response.headers["Access-Control-Allow-Origin"], "http://dev")
def test_wrong_origin(self):
        # Invalid origin (not specified in the config): we return the first allowed
        # origin specified in the config
- response = self.app.get('/', headers={
- 'origin': 'http://xss'
- })
+ response = self.app.get("/", headers={"origin": "http://xss"})
self.assertEqual(response.status_int, 200)
- self.assertEqual(response.headers.get('Access-Control-Allow-Origin'),
- 'http://127.0.0.1:3000')
+ self.assertEqual(
+ response.headers.get("Access-Control-Allow-Origin"), "http://127.0.0.1:3000"
+ )
invalid_origins = [
- 'http://',
- 'https://',
- 'https://www.example.com',
- 'null',
- '*'
+ "http://",
+ "https://",
+ "https://www.example.com",
+ "null",
+ "*",
]
for origin in invalid_origins:
- response = self.app.get('/', headers={
- 'origin': origin
- })
+ response = self.app.get("/", headers={"origin": origin})
self.assertEqual(response.status_int, 200)
- self.assertEqual(response.headers.get('Access-Control-Allow-Origin'),
- 'http://127.0.0.1:3000')
+ self.assertEqual(
+ response.headers.get("Access-Control-Allow-Origin"),
+ "http://127.0.0.1:3000",
+ )
def test_wildcard_origin(self):
try:
- cfg.CONF.set_override('allow_origin', ['*'], 'api')
- response = self.app.get('/', headers={
- 'origin': 'http://xss'
- })
+ cfg.CONF.set_override("allow_origin", ["*"], "api")
+ response = self.app.get("/", headers={"origin": "http://xss"})
finally:
- cfg.CONF.clear_override('allow_origin', 'api')
+ cfg.CONF.clear_override("allow_origin", "api")
self.assertEqual(response.status_int, 200)
- self.assertEqual(response.headers['Access-Control-Allow-Origin'],
- 'http://xss')
+ self.assertEqual(response.headers["Access-Control-Allow-Origin"], "http://xss")
def test_valid_status_code_is_returned_on_invalid_path(self):
# TypeError: get_all() takes exactly 1 argument (2 given)
- resp = self.app.get('/v1/executions/577f775b0640fd1451f2030b/re_run', expect_errors=True)
+ resp = self.app.get(
+ "/v1/executions/577f775b0640fd1451f2030b/re_run", expect_errors=True
+ )
self.assertEqual(resp.status_int, 404)
# get_one() takes exactly 2 arguments (4 given)
- resp = self.app.get('/v1/executions/577f775b0640fd1451f2030b/re_run/a/b',
- expect_errors=True)
+ resp = self.app.get(
+ "/v1/executions/577f775b0640fd1451f2030b/re_run/a/b", expect_errors=True
+ )
self.assertEqual(resp.status_int, 404)
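The CORS behavior asserted above reduces to a small decision rule: echo the request origin when it is explicitly allowed or when a "*" wildcard is configured, otherwise fall back to the first configured origin. Here is a minimal sketch of that rule (not st2's actual middleware, just the logic the tests pin down):

    def resolve_allow_origin(request_origin, allowed_origins):
        """Pick the Access-Control-Allow-Origin value for a response."""
        if "*" in allowed_origins:
            # Wildcard config: echo back whatever origin the client sent.
            return request_origin
        if request_origin in allowed_origins:
            return request_origin
        # Unknown or invalid origin: fall back to the first configured one.
        return allowed_origins[0]

    # Mirrors test_wrong_origin above:
    assert (
        resolve_allow_origin("http://xss", ["http://127.0.0.1:3000", "http://dev"])
        == "http://127.0.0.1:3000"
    )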
diff --git a/st2api/tests/unit/controllers/v1/test_executions.py b/st2api/tests/unit/controllers/v1/test_executions.py
index 57dad1f9f3..5a59f6aab5 100644
--- a/st2api/tests/unit/controllers/v1/test_executions.py
+++ b/st2api/tests/unit/controllers/v1/test_executions.py
@@ -55,324 +55,286 @@
from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase
__all__ = [
- 'ActionExecutionControllerTestCase',
- 'ActionExecutionOutputControllerTestCase'
+ "ActionExecutionControllerTestCase",
+ "ActionExecutionOutputControllerTestCase",
]
ACTION_1 = {
- 'name': 'st2.dummy.action1',
- 'description': 'test description',
- 'enabled': True,
- 'entry_point': '/tmp/test/action1.sh',
- 'pack': 'sixpack',
- 'runner_type': 'remote-shell-cmd',
- 'parameters': {
- 'a': {
- 'type': 'string',
- 'default': 'abc'
- },
- 'b': {
- 'type': 'number',
- 'default': 123
- },
- 'c': {
- 'type': 'number',
- 'default': 123,
- 'immutable': True
- },
- 'd': {
- 'type': 'string',
- 'secret': True
- }
- }
+ "name": "st2.dummy.action1",
+ "description": "test description",
+ "enabled": True,
+ "entry_point": "/tmp/test/action1.sh",
+ "pack": "sixpack",
+ "runner_type": "remote-shell-cmd",
+ "parameters": {
+ "a": {"type": "string", "default": "abc"},
+ "b": {"type": "number", "default": 123},
+ "c": {"type": "number", "default": 123, "immutable": True},
+ "d": {"type": "string", "secret": True},
+ },
}
ACTION_2 = {
- 'name': 'st2.dummy.action2',
- 'description': 'another test description',
- 'enabled': True,
- 'entry_point': '/tmp/test/action2.sh',
- 'pack': 'familypack',
- 'runner_type': 'remote-shell-cmd',
- 'parameters': {
- 'c': {
- 'type': 'object',
- 'properties': {
- 'c1': {
- 'type': 'string'
- }
- }
- },
- 'd': {
- 'type': 'boolean',
- 'default': False
- }
- }
+ "name": "st2.dummy.action2",
+ "description": "another test description",
+ "enabled": True,
+ "entry_point": "/tmp/test/action2.sh",
+ "pack": "familypack",
+ "runner_type": "remote-shell-cmd",
+ "parameters": {
+ "c": {"type": "object", "properties": {"c1": {"type": "string"}}},
+ "d": {"type": "boolean", "default": False},
+ },
}
ACTION_3 = {
- 'name': 'st2.dummy.action3',
- 'description': 'another test description',
- 'enabled': True,
- 'entry_point': '/tmp/test/action3.sh',
- 'pack': 'wolfpack',
- 'runner_type': 'remote-shell-cmd',
- 'parameters': {
- 'e': {},
- 'f': {}
- }
+ "name": "st2.dummy.action3",
+ "description": "another test description",
+ "enabled": True,
+ "entry_point": "/tmp/test/action3.sh",
+ "pack": "wolfpack",
+ "runner_type": "remote-shell-cmd",
+ "parameters": {"e": {}, "f": {}},
}
ACTION_4 = {
- 'name': 'st2.dummy.action4',
- 'description': 'another test description',
- 'enabled': True,
- 'entry_point': '/tmp/test/workflows/action4.yaml',
- 'pack': 'starterpack',
- 'runner_type': 'orquesta',
- 'parameters': {
- 'a': {
- 'type': 'string',
- 'default': 'abc'
- },
- 'b': {
- 'type': 'number',
- 'default': 123
- }
- }
+ "name": "st2.dummy.action4",
+ "description": "another test description",
+ "enabled": True,
+ "entry_point": "/tmp/test/workflows/action4.yaml",
+ "pack": "starterpack",
+ "runner_type": "orquesta",
+ "parameters": {
+ "a": {"type": "string", "default": "abc"},
+ "b": {"type": "number", "default": 123},
+ },
}
ACTION_INQUIRY = {
- 'name': 'st2.dummy.ask',
- 'description': 'another test description',
- 'enabled': True,
- 'pack': 'wolfpack',
- 'runner_type': 'inquirer',
+ "name": "st2.dummy.ask",
+ "description": "another test description",
+ "enabled": True,
+ "pack": "wolfpack",
+ "runner_type": "inquirer",
}
ACTION_DEFAULT_TEMPLATE = {
- 'name': 'st2.dummy.default_template',
- 'description': 'An action that uses a jinja template as a default value for a parameter',
- 'enabled': True,
- 'pack': 'starterpack',
- 'runner_type': 'local-shell-cmd',
- 'parameters': {
- 'intparam': {
- 'type': 'integer',
- 'default': '{{ st2kv.system.test_int | int }}'
- }
- }
+ "name": "st2.dummy.default_template",
+ "description": "An action that uses a jinja template as a default value for a parameter",
+ "enabled": True,
+ "pack": "starterpack",
+ "runner_type": "local-shell-cmd",
+ "parameters": {
+ "intparam": {"type": "integer", "default": "{{ st2kv.system.test_int | int }}"}
+ },
}
ACTION_DEFAULT_ENCRYPT = {
- 'name': 'st2.dummy.default_encrypted_value',
- 'description': 'An action that uses a jinja template with decrypt_kv filter '
- 'in default parameter',
- 'enabled': True,
- 'pack': 'starterpack',
- 'runner_type': 'local-shell-cmd',
- 'parameters': {
- 'encrypted_param': {
- 'type': 'string',
- 'default': '{{ st2kv.system.secret | decrypt_kv }}'
+ "name": "st2.dummy.default_encrypted_value",
+ "description": "An action that uses a jinja template with decrypt_kv filter "
+ "in default parameter",
+ "enabled": True,
+ "pack": "starterpack",
+ "runner_type": "local-shell-cmd",
+ "parameters": {
+ "encrypted_param": {
+ "type": "string",
+ "default": "{{ st2kv.system.secret | decrypt_kv }}",
},
- 'encrypted_user_param': {
- 'type': 'string',
- 'default': '{{ st2kv.user.secret | decrypt_kv }}'
- }
- }
+ "encrypted_user_param": {
+ "type": "string",
+ "default": "{{ st2kv.user.secret | decrypt_kv }}",
+ },
+ },
}
ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS = {
- 'name': 'st2.dummy.default_encrypted_value_secret_param',
- 'description': 'An action that uses a jinja template with decrypt_kv filter '
- 'in default parameter',
- 'enabled': True,
- 'pack': 'starterpack',
- 'runner_type': 'local-shell-cmd',
- 'parameters': {
- 'encrypted_param': {
- 'type': 'string',
- 'default': '{{ st2kv.system.secret | decrypt_kv }}',
- 'secret': True
+ "name": "st2.dummy.default_encrypted_value_secret_param",
+ "description": "An action that uses a jinja template with decrypt_kv filter "
+ "in default parameter",
+ "enabled": True,
+ "pack": "starterpack",
+ "runner_type": "local-shell-cmd",
+ "parameters": {
+ "encrypted_param": {
+ "type": "string",
+ "default": "{{ st2kv.system.secret | decrypt_kv }}",
+ "secret": True,
},
- 'encrypted_user_param': {
- 'type': 'string',
- 'default': '{{ st2kv.user.secret | decrypt_kv }}',
- 'secret': True
- }
- }
+ "encrypted_user_param": {
+ "type": "string",
+ "default": "{{ st2kv.user.secret | decrypt_kv }}",
+ "secret": True,
+ },
+ },
}
LIVE_ACTION_1 = {
- 'action': 'sixpack.st2.dummy.action1',
- 'parameters': {
- 'hosts': 'localhost',
- 'cmd': 'uname -a',
- 'd': SUPER_SECRET_PARAMETER
- }
+ "action": "sixpack.st2.dummy.action1",
+ "parameters": {
+ "hosts": "localhost",
+ "cmd": "uname -a",
+ "d": SUPER_SECRET_PARAMETER,
+ },
}
LIVE_ACTION_2 = {
- 'action': 'familypack.st2.dummy.action2',
- 'parameters': {
- 'hosts': 'localhost',
- 'cmd': 'ls -l'
- }
+ "action": "familypack.st2.dummy.action2",
+ "parameters": {"hosts": "localhost", "cmd": "ls -l"},
}
LIVE_ACTION_3 = {
- 'action': 'wolfpack.st2.dummy.action3',
- 'parameters': {
- 'hosts': 'localhost',
- 'cmd': 'ls -l',
- 'e': 'abcde',
- 'f': 12345
- }
+ "action": "wolfpack.st2.dummy.action3",
+ "parameters": {"hosts": "localhost", "cmd": "ls -l", "e": "abcde", "f": 12345},
}
LIVE_ACTION_4 = {
- 'action': 'starterpack.st2.dummy.action4',
+ "action": "starterpack.st2.dummy.action4",
}
LIVE_ACTION_DELAY = {
- 'action': 'sixpack.st2.dummy.action1',
- 'parameters': {
- 'hosts': 'localhost',
- 'cmd': 'uname -a',
- 'd': SUPER_SECRET_PARAMETER
+ "action": "sixpack.st2.dummy.action1",
+ "parameters": {
+ "hosts": "localhost",
+ "cmd": "uname -a",
+ "d": SUPER_SECRET_PARAMETER,
},
- 'delay': 100
+ "delay": 100,
}
LIVE_ACTION_INQUIRY = {
- 'parameters': {
- 'route': 'developers',
- 'schema': {
- 'type': 'object',
- 'properties': {
- 'secondfactor': {
- 'secret': True,
- 'required': True,
- 'type': u'string',
- 'description': 'Please enter second factor for authenticating to "foo" service'
+ "parameters": {
+ "route": "developers",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "secondfactor": {
+ "secret": True,
+ "required": True,
+ "type": "string",
+ "description": 'Please enter second factor for authenticating to "foo" service',
}
- }
- }
- },
- 'action': 'wolfpack.st2.dummy.ask',
- 'result': {
- 'users': [],
- 'roles': [],
- 'route': 'developers',
- 'ttl': 1440,
- 'response': {
- 'secondfactor': 'supersecretvalue'
+ },
},
- 'schema': {
- 'type': 'object',
- 'properties': {
- 'secondfactor': {
- 'secret': True,
- 'required': True,
- 'type': 'string',
- 'description': 'Please enter second factor for authenticating to "foo" service'
+ },
+ "action": "wolfpack.st2.dummy.ask",
+ "result": {
+ "users": [],
+ "roles": [],
+ "route": "developers",
+ "ttl": 1440,
+ "response": {"secondfactor": "supersecretvalue"},
+ "schema": {
+ "type": "object",
+ "properties": {
+ "secondfactor": {
+ "secret": True,
+ "required": True,
+ "type": "string",
+ "description": 'Please enter second factor for authenticating to "foo" service',
}
- }
- }
- }
+ },
+ },
+ },
}
LIVE_ACTION_WITH_SECRET_PARAM = {
- 'parameters': {
+ "parameters": {
# action params
- 'a': 'param a',
- 'd': 'secretpassword1',
-
+ "a": "param a",
+ "d": "secretpassword1",
# runner params
- 'password': 'secretpassword2',
- 'hosts': 'localhost'
+ "password": "secretpassword2",
+ "hosts": "localhost",
},
- 'action': 'sixpack.st2.dummy.action1'
+ "action": "sixpack.st2.dummy.action1",
}
# Do not add parameters to this. There are tests that will test first without params,
# then make a copy with params.
LIVE_ACTION_DEFAULT_TEMPLATE = {
- 'action': 'starterpack.st2.dummy.default_template',
+ "action": "starterpack.st2.dummy.default_template",
}
LIVE_ACTION_DEFAULT_ENCRYPT = {
- 'action': 'starterpack.st2.dummy.default_encrypted_value',
+ "action": "starterpack.st2.dummy.default_encrypted_value",
}
LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM = {
- 'action': 'starterpack.st2.dummy.default_encrypted_value_secret_param',
+ "action": "starterpack.st2.dummy.default_encrypted_value_secret_param",
}
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
TEST_FIXTURES = {
- 'runners': ['testrunner1.yaml'],
- 'actions': ['action1.yaml', 'local.yaml']
+ "runners": ["testrunner1.yaml"],
+ "actions": ["action1.yaml", "local.yaml"],
}
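The fixture constants above illustrate black's two layouts for collection literals: anything that fits within the default 88-column limit is collapsed onto one line, while longer literals are exploded to one element per line with a trailing comma (a "magic" comma that black then treats as a request to keep the exploded layout on later runs). A hand-made illustration with hypothetical dicts:

    # Fits in 88 columns, so black collapses it onto a single line:
    SHORT = {"a": {"type": "string", "default": "abc"}}

    # Too long for one line, so black gives each item its own line and adds a
    # trailing ("magic") comma that pins this layout on future runs:
    LONG = {
        "encrypted_param": {
            "type": "string",
            "default": "{{ st2kv.system.secret | decrypt_kv }}",
        },
    }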
-@mock.patch.object(content_utils, 'get_pack_base_path', mock.MagicMock(return_value='/tmp/test'))
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
-class ActionExecutionControllerTestCase(BaseActionExecutionControllerTestCase, FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/executions'
+@mock.patch.object(
+ content_utils, "get_pack_base_path", mock.MagicMock(return_value="/tmp/test")
+)
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
+class ActionExecutionControllerTestCase(
+ BaseActionExecutionControllerTestCase,
+ FunctionalTest,
+ APIControllerWithIncludeAndExcludeFilterTestCase,
+):
+ get_all_path = "/v1/executions"
controller_cls = ActionExecutionsController
- include_attribute_field_name = 'status'
- exclude_attribute_field_name = 'status'
+ include_attribute_field_name = "status"
+ exclude_attribute_field_name = "status"
test_exact_object_count = False
@classmethod
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def setUpClass(cls):
super(BaseActionExecutionControllerTestCase, cls).setUpClass()
cls.action1 = copy.deepcopy(ACTION_1)
- post_resp = cls.app.post_json('/v1/actions', cls.action1)
- cls.action1['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.action1)
+ cls.action1["id"] = post_resp.json["id"]
cls.action2 = copy.deepcopy(ACTION_2)
- post_resp = cls.app.post_json('/v1/actions', cls.action2)
- cls.action2['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.action2)
+ cls.action2["id"] = post_resp.json["id"]
cls.action3 = copy.deepcopy(ACTION_3)
- post_resp = cls.app.post_json('/v1/actions', cls.action3)
- cls.action3['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.action3)
+ cls.action3["id"] = post_resp.json["id"]
cls.action4 = copy.deepcopy(ACTION_4)
- post_resp = cls.app.post_json('/v1/actions', cls.action4)
- cls.action4['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.action4)
+ cls.action4["id"] = post_resp.json["id"]
cls.action_inquiry = copy.deepcopy(ACTION_INQUIRY)
- post_resp = cls.app.post_json('/v1/actions', cls.action_inquiry)
- cls.action_inquiry['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.action_inquiry)
+ cls.action_inquiry["id"] = post_resp.json["id"]
cls.action_template = copy.deepcopy(ACTION_DEFAULT_TEMPLATE)
- post_resp = cls.app.post_json('/v1/actions', cls.action_template)
- cls.action_template['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.action_template)
+ cls.action_template["id"] = post_resp.json["id"]
cls.action_decrypt = copy.deepcopy(ACTION_DEFAULT_ENCRYPT)
- post_resp = cls.app.post_json('/v1/actions', cls.action_decrypt)
- cls.action_decrypt['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.action_decrypt)
+ cls.action_decrypt["id"] = post_resp.json["id"]
- cls.action_decrypt_secret_param = copy.deepcopy(ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS)
- post_resp = cls.app.post_json('/v1/actions', cls.action_decrypt_secret_param)
- cls.action_decrypt_secret_param['id'] = post_resp.json['id']
+ cls.action_decrypt_secret_param = copy.deepcopy(
+ ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS
+ )
+ post_resp = cls.app.post_json("/v1/actions", cls.action_decrypt_secret_param)
+ cls.action_decrypt_secret_param["id"] = post_resp.json["id"]
@classmethod
def tearDownClass(cls):
- cls.app.delete('/v1/actions/%s' % cls.action1['id'])
- cls.app.delete('/v1/actions/%s' % cls.action2['id'])
- cls.app.delete('/v1/actions/%s' % cls.action3['id'])
- cls.app.delete('/v1/actions/%s' % cls.action4['id'])
- cls.app.delete('/v1/actions/%s' % cls.action_inquiry['id'])
- cls.app.delete('/v1/actions/%s' % cls.action_template['id'])
- cls.app.delete('/v1/actions/%s' % cls.action_decrypt['id'])
+ cls.app.delete("/v1/actions/%s" % cls.action1["id"])
+ cls.app.delete("/v1/actions/%s" % cls.action2["id"])
+ cls.app.delete("/v1/actions/%s" % cls.action3["id"])
+ cls.app.delete("/v1/actions/%s" % cls.action4["id"])
+ cls.app.delete("/v1/actions/%s" % cls.action_inquiry["id"])
+ cls.app.delete("/v1/actions/%s" % cls.action_template["id"])
+ cls.app.delete("/v1/actions/%s" % cls.action_decrypt["id"])
super(BaseActionExecutionControllerTestCase, cls).tearDownClass()
def test_get_one(self):
@@ -381,11 +343,11 @@ def test_get_one(self):
get_resp = self._do_get_one(actionexecution_id)
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id)
- self.assertIn('web_url', get_resp)
- if 'end_timestamp' in get_resp:
- self.assertIn('elapsed_seconds', get_resp)
+ self.assertIn("web_url", get_resp)
+ if "end_timestamp" in get_resp:
+ self.assertIn("elapsed_seconds", get_resp)
- get_resp = self._do_get_one('last')
+ get_resp = self._do_get_one("last")
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id)
@@ -396,13 +358,15 @@ def test_get_all_id_query_param_filtering_success(self):
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id)
- resp = self.app.get('/v1/executions?id=%s' % (actionexecution_id), expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions?id=%s" % (actionexecution_id), expect_errors=False
+ )
self.assertEqual(resp.status_int, 200)
def test_get_all_id_query_param_filtering_invalid_id(self):
- resp = self.app.get('/v1/executions?id=invalidid', expect_errors=True)
+ resp = self.app.get("/v1/executions?id=invalidid", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertIn('not a valid ObjectId', resp.json['faultstring'])
+ self.assertIn("not a valid ObjectId", resp.json["faultstring"])
def test_get_all_id_query_param_filtering_multiple_ids_provided(self):
post_resp = self._do_post(LIVE_ACTION_1)
@@ -413,94 +377,118 @@ def test_get_all_id_query_param_filtering_multiple_ids_provided(self):
self.assertEqual(post_resp.status_int, 201)
id_2 = self._get_actionexecution_id(post_resp)
- resp = self.app.get('/v1/executions?id=%s,%s' % (id_1, id_2), expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions?id=%s,%s" % (id_1, id_2), expect_errors=False
+ )
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 2)
def test_get_all(self):
self._get_actionexecution_id(self._do_post(LIVE_ACTION_1))
self._get_actionexecution_id(self._do_post(LIVE_ACTION_2))
- resp = self.app.get('/v1/executions')
+ resp = self.app.get("/v1/executions")
body = resp.json
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.headers['X-Total-Count'], "2")
- self.assertEqual(len(resp.json), 2,
- '/v1/executions did not return all '
- 'actionexecutions.')
+ self.assertEqual(resp.headers["X-Total-Count"], "2")
+ self.assertEqual(
+ len(resp.json), 2, "/v1/executions did not return all " "actionexecutions."
+ )
# Assert liveactions are sorted by timestamp.
for i in range(len(body) - 1):
- self.assertTrue(isotime.parse(body[i]['start_timestamp']) >=
- isotime.parse(body[i + 1]['start_timestamp']))
- self.assertIn('web_url', body[i])
- if 'end_timestamp' in body[i]:
- self.assertIn('elapsed_seconds', body[i])
+ self.assertTrue(
+ isotime.parse(body[i]["start_timestamp"])
+ >= isotime.parse(body[i + 1]["start_timestamp"])
+ )
+ self.assertIn("web_url", body[i])
+ if "end_timestamp" in body[i]:
+ self.assertIn("elapsed_seconds", body[i])
def test_get_all_invalid_offset_too_large(self):
- offset = '2141564789454123457895412237483648'
- resp = self.app.get('/v1/executions?offset=%s&limit=1' % (offset), expect_errors=True)
+ offset = "2141564789454123457895412237483648"
+ resp = self.app.get(
+ "/v1/executions?offset=%s&limit=1" % (offset), expect_errors=True
+ )
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Offset "%s" specified is more than 32-bit int' % (offset))
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Offset "%s" specified is more than 32-bit int' % (offset),
+ )
def test_get_query(self):
- actionexecution_1_id = self._get_actionexecution_id(self._do_post(LIVE_ACTION_1))
+ actionexecution_1_id = self._get_actionexecution_id(
+ self._do_post(LIVE_ACTION_1)
+ )
- resp = self.app.get('/v1/executions?action=%s' % LIVE_ACTION_1['action'])
+ resp = self.app.get("/v1/executions?action=%s" % LIVE_ACTION_1["action"])
self.assertEqual(resp.status_int, 200)
- matching_execution = filter(lambda ae: ae['id'] == actionexecution_1_id, resp.json)
- self.assertEqual(len(list(matching_execution)), 1,
- '/v1/executions did not return correct liveaction.')
+ matching_execution = filter(
+ lambda ae: ae["id"] == actionexecution_1_id, resp.json
+ )
+ self.assertEqual(
+ len(list(matching_execution)),
+ 1,
+ "/v1/executions did not return correct liveaction.",
+ )
def test_get_query_with_limit_and_offset(self):
self._get_actionexecution_id(self._do_post(LIVE_ACTION_1))
self._get_actionexecution_id(self._do_post(LIVE_ACTION_1))
- resp = self.app.get('/v1/executions')
+ resp = self.app.get("/v1/executions")
self.assertEqual(resp.status_int, 200)
self.assertTrue(len(resp.json) > 0)
- resp = self.app.get('/v1/executions?limit=1')
+ resp = self.app.get("/v1/executions?limit=1")
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 1)
- resp = self.app.get('/v1/executions?limit=0', expect_errors=True)
+ resp = self.app.get("/v1/executions?limit=0", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertTrue(resp.json['faultstring'],
- u'Limit, "0" specified, must be a positive number or -1 for full \
- result set.')
+ self.assertTrue(
+ resp.json["faultstring"],
+ 'Limit, "0" specified, must be a positive number or -1 for full \
+ result set.',
+ )
- resp = self.app.get('/v1/executions?limit=-1')
+ resp = self.app.get("/v1/executions?limit=-1")
self.assertEqual(resp.status_int, 200)
self.assertTrue(len(resp.json) > 1)
- resp = self.app.get('/v1/executions?limit=-22', expect_errors=True)
+ resp = self.app.get("/v1/executions?limit=-22", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
- resp = self.app.get('/v1/executions?action=%s' % LIVE_ACTION_1['action'])
+ resp = self.app.get("/v1/executions?action=%s" % LIVE_ACTION_1["action"])
self.assertEqual(resp.status_int, 200)
self.assertTrue(len(resp.json) > 1)
- resp = self.app.get('/v1/executions?action=%s&limit=0' %
- LIVE_ACTION_1['action'], expect_errors=True)
+ resp = self.app.get(
+ "/v1/executions?action=%s&limit=0" % LIVE_ACTION_1["action"],
+ expect_errors=True,
+ )
self.assertEqual(resp.status_int, 400)
- self.assertTrue(resp.json['faultstring'],
- u'Limit, "0" specified, must be a positive number or -1 for full \
- result set.')
-
- resp = self.app.get('/v1/executions?action=%s&limit=1' %
- LIVE_ACTION_1['action'])
+ self.assertTrue(
+ resp.json["faultstring"],
+ 'Limit, "0" specified, must be a positive number or -1 for full \
+ result set.',
+ )
+
+ resp = self.app.get(
+ "/v1/executions?action=%s&limit=1" % LIVE_ACTION_1["action"]
+ )
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 1)
- total_count = resp.headers['X-Total-Count']
+ total_count = resp.headers["X-Total-Count"]
- resp = self.app.get('/v1/executions?offset=%s&limit=1' % total_count)
+ resp = self.app.get("/v1/executions?offset=%s&limit=1" % total_count)
self.assertEqual(resp.status_int, 200)
self.assertTrue(len(resp.json), 0)
def test_get_one_fail(self):
- resp = self.app.get('/v1/executions/100', expect_errors=True)
+ resp = self.app.get("/v1/executions/100", expect_errors=True)
self.assertEqual(resp.status_int, 404)
def test_post_delete(self):
@@ -508,13 +496,13 @@ def test_post_delete(self):
self.assertEqual(post_resp.status_int, 201)
delete_resp = self._do_delete(self._get_actionexecution_id(post_resp))
self.assertEqual(delete_resp.status_int, 200)
- self.assertEqual(delete_resp.json['status'], 'canceled')
- expected_result = {'message': 'Action canceled by user.', 'user': 'stanley'}
- self.assertDictEqual(delete_resp.json['result'], expected_result)
+ self.assertEqual(delete_resp.json["status"], "canceled")
+ expected_result = {"message": "Action canceled by user.", "user": "stanley"}
+ self.assertDictEqual(delete_resp.json["result"], expected_result)
def test_post_delete_duplicate(self):
"""Cancels an execution twice, to ensure that a full execution object
- is returned instead of an error message
+ is returned instead of an error message
"""
post_resp = self._do_post(LIVE_ACTION_1)
@@ -524,59 +512,65 @@ def test_post_delete_duplicate(self):
for i in range(2):
delete_resp = self._do_delete(self._get_actionexecution_id(post_resp))
self.assertEqual(delete_resp.status_int, 200)
- self.assertEqual(delete_resp.json['status'], 'canceled')
- expected_result = {'message': 'Action canceled by user.', 'user': 'stanley'}
- self.assertDictEqual(delete_resp.json['result'], expected_result)
+ self.assertEqual(delete_resp.json["status"], "canceled")
+ expected_result = {"message": "Action canceled by user.", "user": "stanley"}
+ self.assertDictEqual(delete_resp.json["result"], expected_result)
def test_post_delete_trace(self):
LIVE_ACTION_TRACE = copy.copy(LIVE_ACTION_1)
- LIVE_ACTION_TRACE['context'] = {'trace_context': {'trace_tag': 'balleilaka'}}
+ LIVE_ACTION_TRACE["context"] = {"trace_context": {"trace_tag": "balleilaka"}}
post_resp = self._do_post(LIVE_ACTION_TRACE)
self.assertEqual(post_resp.status_int, 201)
delete_resp = self._do_delete(self._get_actionexecution_id(post_resp))
self.assertEqual(delete_resp.status_int, 200)
- self.assertEqual(delete_resp.json['status'], 'canceled')
+ self.assertEqual(delete_resp.json["status"], "canceled")
trace_id = str(Trace.get_all()[0].id)
- LIVE_ACTION_TRACE['context'] = {'trace_context': {'id_': trace_id}}
+ LIVE_ACTION_TRACE["context"] = {"trace_context": {"id_": trace_id}}
post_resp = self._do_post(LIVE_ACTION_TRACE)
self.assertEqual(post_resp.status_int, 201)
delete_resp = self._do_delete(self._get_actionexecution_id(post_resp))
self.assertEqual(delete_resp.status_int, 200)
- self.assertEqual(delete_resp.json['status'], 'canceled')
+ self.assertEqual(delete_resp.json["status"], "canceled")
def test_post_nonexistent_action(self):
live_action = copy.deepcopy(LIVE_ACTION_1)
- live_action['action'] = 'mock.foobar'
+ live_action["action"] = "mock.foobar"
post_resp = self._do_post(live_action, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
- expected_error = 'Action "%s" cannot be found.' % live_action['action']
- self.assertEqual(expected_error, post_resp.json['faultstring'])
+ expected_error = 'Action "%s" cannot be found.' % live_action["action"]
+ self.assertEqual(expected_error, post_resp.json["faultstring"])
def test_post_parameter_validation_failed(self):
execution = copy.deepcopy(LIVE_ACTION_1)
        # Runner type does not expect additional properties.
- execution['parameters']['foo'] = 'bar'
+ execution["parameters"]["foo"] = "bar"
post_resp = self._do_post(execution, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
- self.assertEqual(post_resp.json['faultstring'],
- "Additional properties are not allowed ('foo' was unexpected)")
+ self.assertEqual(
+ post_resp.json["faultstring"],
+ "Additional properties are not allowed ('foo' was unexpected)",
+ )
# Runner type expects parameter "hosts".
- execution['parameters'] = {}
+ execution["parameters"] = {}
post_resp = self._do_post(execution, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
- self.assertEqual(post_resp.json['faultstring'], "'hosts' is a required property")
+ self.assertEqual(
+ post_resp.json["faultstring"], "'hosts' is a required property"
+ )
        # Runner type expects parameter "cmd" to be a str.
- execution['parameters'] = {"hosts": "127.0.0.1", "cmd": 1000}
+ execution["parameters"] = {"hosts": "127.0.0.1", "cmd": 1000}
post_resp = self._do_post(execution, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
- self.assertIn('Value "1000" must either be a string or None. Got "int"',
- post_resp.json['faultstring'])
+ self.assertIn(
+ 'Value "1000" must either be a string or None. Got "int"',
+ post_resp.json["faultstring"],
+ )
        # Runner type expects parameter "cmd" to be a str.
- execution['parameters'] = {"hosts": "127.0.0.1", "cmd": "1000", "c": 1}
+ execution["parameters"] = {"hosts": "127.0.0.1", "cmd": "1000", "c": 1}
post_resp = self._do_post(execution, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
@@ -589,53 +583,55 @@ def test_post_parameter_render_failed(self):
execution = copy.deepcopy(LIVE_ACTION_1)
        # Runner type does not expect additional properties.
- execution['parameters']['hosts'] = '{{ABSENT}}'
+ execution["parameters"]["hosts"] = "{{ABSENT}}"
post_resp = self._do_post(execution, expect_errors=True)
self.assertEqual(post_resp.status_int, 400)
- self.assertEqual(post_resp.json['faultstring'],
- 'Dependency unsatisfied in variable "ABSENT"')
+ self.assertEqual(
+ post_resp.json["faultstring"], 'Dependency unsatisfied in variable "ABSENT"'
+ )
def test_post_parameter_validation_explicit_none(self):
execution = copy.deepcopy(LIVE_ACTION_1)
- execution['parameters']['a'] = None
+ execution["parameters"]["a"] = None
post_resp = self._do_post(execution)
self.assertEqual(post_resp.status_int, 201)
def test_post_with_st2_context_in_headers(self):
resp = self._do_post(copy.deepcopy(LIVE_ACTION_1))
self.assertEqual(resp.status_int, 201)
- parent_user = resp.json['context']['user']
- parent_exec_id = str(resp.json['id'])
+ parent_user = resp.json["context"]["user"]
+ parent_exec_id = str(resp.json["id"])
context = {
- 'parent': {
- 'execution_id': parent_exec_id,
- 'user': parent_user
- },
- 'user': None,
- 'other': {'k1': 'v1'}
+ "parent": {"execution_id": parent_exec_id, "user": parent_user},
+ "user": None,
+ "other": {"k1": "v1"},
+ }
+ headers = {
+ "content-type": "application/json",
+ "st2-context": json.dumps(context),
}
- headers = {'content-type': 'application/json', 'st2-context': json.dumps(context)}
resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers)
self.assertEqual(resp.status_int, 201)
- self.assertEqual(resp.json['context']['user'], parent_user, 'Should use parent\'s user.')
+ self.assertEqual(
+ resp.json["context"]["user"], parent_user, "Should use parent's user."
+ )
expected = {
- 'parent': {
- 'execution_id': parent_exec_id,
- 'user': parent_user
- },
- 'user': parent_user,
- 'pack': 'sixpack',
- 'other': {'k1': 'v1'}
+ "parent": {"execution_id": parent_exec_id, "user": parent_user},
+ "user": parent_user,
+ "pack": "sixpack",
+ "other": {"k1": "v1"},
}
- self.assertDictEqual(resp.json['context'], expected)
+ self.assertDictEqual(resp.json["context"], expected)
def test_post_with_st2_context_in_headers_failed(self):
resp = self._do_post(copy.deepcopy(LIVE_ACTION_1))
self.assertEqual(resp.status_int, 201)
- headers = {'content-type': 'application/json', 'st2-context': 'foobar'}
- resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers, expect_errors=True)
+ headers = {"content-type": "application/json", "st2-context": "foobar"}
+ resp = self._do_post(
+ copy.deepcopy(LIVE_ACTION_1), headers=headers, expect_errors=True
+ )
self.assertEqual(resp.status_int, 400)
- self.assertIn('Unable to convert st2-context', resp.json['faultstring'])
+ self.assertIn("Unable to convert st2-context", resp.json["faultstring"])
def test_re_run_success(self):
# Create a new execution
@@ -645,12 +641,16 @@ def test_re_run_success(self):
        # Re-run created execution (no parameter overrides)
data = {}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data)
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data
+ )
self.assertEqual(re_run_resp.status_int, 201)
        # Re-run created execution (with parameter overrides)
- data = {'parameters': {'a': 'val1'}}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data)
+ data = {"parameters": {"a": "val1"}}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data
+ )
self.assertEqual(re_run_resp.status_int, 201)
def test_re_run_with_delay(self):
@@ -659,21 +659,24 @@ def test_re_run_with_delay(self):
execution_id = self._get_actionexecution_id(post_resp)
delay_time = 100
- data = {'delay': delay_time}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data)
+ data = {"delay": delay_time}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data
+ )
self.assertEqual(re_run_resp.status_int, 201)
resp = json.loads(re_run_resp.body)
- self.assertEqual(resp['delay'], delay_time)
+ self.assertEqual(resp["delay"], delay_time)
def test_re_run_with_incorrect_delay(self):
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
execution_id = self._get_actionexecution_id(post_resp)
- delay_time = 'sudo apt -y upgrade winson'
- data = {'delay': delay_time}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ delay_time = "sudo apt -y upgrade winson"
+ data = {"delay": delay_time}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 400)
def test_re_run_with_very_large_delay(self):
@@ -682,8 +685,10 @@ def test_re_run_with_very_large_delay(self):
execution_id = self._get_actionexecution_id(post_resp)
delay_time = 10 ** 10
- data = {'delay': delay_time}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data)
+ data = {"delay": delay_time}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data
+ )
self.assertEqual(re_run_resp.status_int, 201)
def test_re_run_delayed_aciton_with_no_delay(self):
@@ -692,11 +697,13 @@ def test_re_run_delayed_aciton_with_no_delay(self):
execution_id = self._get_actionexecution_id(post_resp)
delay_time = 0
- data = {'delay': delay_time}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data)
+ data = {"delay": delay_time}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data
+ )
self.assertEqual(re_run_resp.status_int, 201)
resp = json.loads(re_run_resp.body)
- self.assertNotIn('delay', resp.keys())
+ self.assertNotIn("delay", resp.keys())
def test_re_run_failure_execution_doesnt_exist(self):
# Create a new execution
@@ -705,8 +712,9 @@ def test_re_run_failure_execution_doesnt_exist(self):
        # Re-run an execution which doesn't exist
data = {}
- re_run_resp = self.app.post_json('/v1/executions/doesntexist/re_run',
- data, expect_errors=True)
+ re_run_resp = self.app.post_json(
+ "/v1/executions/doesntexist/re_run", data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 404)
def test_re_run_failure_parameter_override_invalid_type(self):
@@ -716,12 +724,15 @@ def test_re_run_failure_parameter_override_invalid_type(self):
execution_id = self._get_actionexecution_id(post_resp)
# Re-run created execution (override parameter and task together)
- data = {'parameters': {'a': 1000}}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ data = {"parameters": {"a": 1000}}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 400)
- self.assertIn('Value "1000" must either be a string or None. Got "int"',
- re_run_resp.json['faultstring'])
+ self.assertIn(
+ 'Value "1000" must either be a string or None. Got "int"',
+ re_run_resp.json["faultstring"],
+ )
def test_template_param(self):
@@ -731,31 +742,46 @@ def test_template_param(self):
# Assert that the template in the parameter default value
# was rendered and st2kv was used
- self.assertEqual(post_resp.json['parameters']['intparam'], 0)
+ self.assertEqual(post_resp.json["parameters"]["intparam"], 0)
# Test with live param
live_int_param = 3
livaction_with_params = copy.deepcopy(LIVE_ACTION_DEFAULT_TEMPLATE)
- livaction_with_params['parameters'] = {
- "intparam": live_int_param
- }
+ livaction_with_params["parameters"] = {"intparam": live_int_param}
post_resp = self._do_post(livaction_with_params)
self.assertEqual(post_resp.status_int, 201)
# Assert that the template in the parameter default value
# was not rendered, and the provided parameter was used
- self.assertEqual(post_resp.json['parameters']['intparam'], live_int_param)
+ self.assertEqual(post_resp.json["parameters"]["intparam"], live_int_param)
def test_template_encrypted_params(self):
# register datastore values which are used in this test case
KeyValuePairAPI._setup_crypto()
register_items = [
- {'name': 'secret', 'secret': True,
- 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'foo')},
- {'name': 'stanley:secret', 'secret': True, 'scope': FULL_USER_SCOPE,
- 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'bar')},
- {'name': 'user1:secret', 'secret': True, 'scope': FULL_USER_SCOPE,
- 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'baz')},
+ {
+ "name": "secret",
+ "secret": True,
+ "value": crypto_utils.symmetric_encrypt(
+ KeyValuePairAPI.crypto_key, "foo"
+ ),
+ },
+ {
+ "name": "stanley:secret",
+ "secret": True,
+ "scope": FULL_USER_SCOPE,
+ "value": crypto_utils.symmetric_encrypt(
+ KeyValuePairAPI.crypto_key, "bar"
+ ),
+ },
+ {
+ "name": "user1:secret",
+ "secret": True,
+ "scope": FULL_USER_SCOPE,
+ "value": crypto_utils.symmetric_encrypt(
+ KeyValuePairAPI.crypto_key, "baz"
+ ),
+ },
]
kvps = [KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items]
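The register_items above are stored pre-encrypted so that the decrypt_kv Jinja filter can recover them at parameter-render time. A minimal round-trip sketch reusing this module's imports; it assumes symmetric_decrypt is the counterpart of the symmetric_encrypt call used above, and that _setup_crypto() has already populated crypto_key, as it has in this test:

    # Hypothetical round trip with the crypto helpers imported by this module.
    ciphertext = crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, "foo")
    plaintext = crypto_utils.symmetric_decrypt(KeyValuePairAPI.crypto_key, ciphertext)
    assert plaintext == "foo"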
@@ -763,43 +789,53 @@ def test_template_encrypted_params(self):
# 1. parameters are not marked as secret
resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT)
self.assertEqual(resp.status_int, 201)
- self.assertEqual(resp.json['context']['user'], 'stanley')
- self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo')
- self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar')
+ self.assertEqual(resp.json["context"]["user"], "stanley")
+ self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo")
+ self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar")
# 2. parameters are marked as secret
resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM)
self.assertEqual(resp.status_int, 201)
- self.assertEqual(resp.json['context']['user'], 'stanley')
- self.assertEqual(resp.json['parameters']['encrypted_param'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(resp.json['parameters']['encrypted_user_param'], MASKED_ATTRIBUTE_VALUE)
+ self.assertEqual(resp.json["context"]["user"], "stanley")
+ self.assertEqual(
+ resp.json["parameters"]["encrypted_param"], MASKED_ATTRIBUTE_VALUE
+ )
+ self.assertEqual(
+ resp.json["parameters"]["encrypted_user_param"], MASKED_ATTRIBUTE_VALUE
+ )
        # After switching to 'user1', the value will be read from that user's scope
- self.use_user(UserDB(name='user1'))
+ self.use_user(UserDB(name="user1"))
# 1. parameters are not marked as secret
resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT)
self.assertEqual(resp.status_int, 201)
- self.assertEqual(resp.json['context']['user'], 'user1')
- self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo')
- self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'baz')
+ self.assertEqual(resp.json["context"]["user"], "user1")
+ self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo")
+ self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "baz")
# 2. parameters are marked as secret
resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM)
self.assertEqual(resp.status_int, 201)
- self.assertEqual(resp.json['context']['user'], 'user1')
- self.assertEqual(resp.json['parameters']['encrypted_param'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(resp.json['parameters']['encrypted_user_param'], MASKED_ATTRIBUTE_VALUE)
+ self.assertEqual(resp.json["context"]["user"], "user1")
+ self.assertEqual(
+ resp.json["parameters"]["encrypted_param"], MASKED_ATTRIBUTE_VALUE
+ )
+ self.assertEqual(
+ resp.json["parameters"]["encrypted_user_param"], MASKED_ATTRIBUTE_VALUE
+ )
        # This switches to 'user2', who has no value in their scope. When a request
        # that tries to evaluate a Jinja expression decrypting an empty value is sent,
        # an HTTP response with a 4xx status code will be returned.
- self.use_user(UserDB(name='user2'))
+ self.use_user(UserDB(name="user2"))
resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- 'Failed to render parameter "encrypted_user_param": Referenced datastore '
- 'item "st2kv.user.secret" doesn\'t exist or it contains an empty string')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Failed to render parameter "encrypted_user_param": Referenced datastore '
+ 'item "st2kv.user.secret" doesn\'t exist or it contains an empty string',
+ )
# clean-up values that are registered at first
for kvp in kvps:
@@ -808,7 +844,9 @@ def test_template_encrypted_params(self):
def test_template_encrypted_params_without_registering(self):
resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'].index('Failed to render parameter'), 0)
+ self.assertEqual(
+ resp.json["faultstring"].index("Failed to render parameter"), 0
+ )
def test_re_run_workflow_success(self):
# Create a new execution
@@ -818,26 +856,25 @@ def test_re_run_workflow_success(self):
        # Re-run created execution (no parameter overrides)
data = {}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 201)
# Get the trace
- trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id)
+ trace = trace_service.get_trace_db_by_action_execution(
+ action_execution_id=execution_id
+ )
expected_context = {
- 'user': 'stanley',
- 'pack': 'starterpack',
- 're-run': {
- 'ref': execution_id
- },
- 'trace_context': {
- 'id_': str(trace.id)
- }
+ "user": "stanley",
+ "pack": "starterpack",
+ "re-run": {"ref": execution_id},
+ "trace_context": {"id_": str(trace.id)},
}
- self.assertDictEqual(re_run_resp.json['context'], expected_context)
+ self.assertDictEqual(re_run_resp.json["context"], expected_context)
def test_re_run_workflow_task_success(self):
# Create a new execution
@@ -846,28 +883,26 @@ def test_re_run_workflow_task_success(self):
execution_id = self._get_actionexecution_id(post_resp)
        # Re-run created execution (tasks option for a workflow)
- data = {'tasks': ['x']}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ data = {"tasks": ["x"]}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 201)
# Get the trace
- trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id)
+ trace = trace_service.get_trace_db_by_action_execution(
+ action_execution_id=execution_id
+ )
expected_context = {
- 'pack': 'starterpack',
- 'user': 'stanley',
- 're-run': {
- 'ref': execution_id,
- 'tasks': data['tasks']
- },
- 'trace_context': {
- 'id_': str(trace.id)
- }
+ "pack": "starterpack",
+ "user": "stanley",
+ "re-run": {"ref": execution_id, "tasks": data["tasks"]},
+ "trace_context": {"id_": str(trace.id)},
}
- self.assertDictEqual(re_run_resp.json['context'], expected_context)
+ self.assertDictEqual(re_run_resp.json["context"], expected_context)
def test_re_run_workflow_tasks_success(self):
# Create a new execution
@@ -876,28 +911,26 @@ def test_re_run_workflow_tasks_success(self):
execution_id = self._get_actionexecution_id(post_resp)
        # Re-run created execution (tasks option for a workflow)
- data = {'tasks': ['x', 'y']}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ data = {"tasks": ["x", "y"]}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 201)
# Get the trace
- trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id)
+ trace = trace_service.get_trace_db_by_action_execution(
+ action_execution_id=execution_id
+ )
expected_context = {
- 'pack': 'starterpack',
- 'user': 'stanley',
- 're-run': {
- 'ref': execution_id,
- 'tasks': data['tasks']
- },
- 'trace_context': {
- 'id_': str(trace.id)
- }
+ "pack": "starterpack",
+ "user": "stanley",
+ "re-run": {"ref": execution_id, "tasks": data["tasks"]},
+ "trace_context": {"id_": str(trace.id)},
}
- self.assertDictEqual(re_run_resp.json['context'], expected_context)
+ self.assertDictEqual(re_run_resp.json["context"], expected_context)
def test_re_run_workflow_tasks_reset_success(self):
# Create a new execution
@@ -906,29 +939,30 @@ def test_re_run_workflow_tasks_reset_success(self):
execution_id = self._get_actionexecution_id(post_resp)
        # Re-run created execution (tasks and reset options for a workflow)
- data = {'tasks': ['x', 'y'], 'reset': ['y']}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ data = {"tasks": ["x", "y"], "reset": ["y"]}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 201)
# Get the trace
- trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id)
+ trace = trace_service.get_trace_db_by_action_execution(
+ action_execution_id=execution_id
+ )
expected_context = {
- 'pack': 'starterpack',
- 'user': 'stanley',
- 're-run': {
- 'ref': execution_id,
- 'tasks': data['tasks'],
- 'reset': data['reset']
+ "pack": "starterpack",
+ "user": "stanley",
+ "re-run": {
+ "ref": execution_id,
+ "tasks": data["tasks"],
+ "reset": data["reset"],
},
- 'trace_context': {
- 'id_': str(trace.id)
- }
+ "trace_context": {"id_": str(trace.id)},
}
- self.assertDictEqual(re_run_resp.json['context'], expected_context)
+ self.assertDictEqual(re_run_resp.json["context"], expected_context)
def test_re_run_failure_tasks_option_for_non_workflow(self):
# Create a new execution
@@ -937,14 +971,15 @@ def test_re_run_failure_tasks_option_for_non_workflow(self):
execution_id = self._get_actionexecution_id(post_resp)
        # Re-run created execution (tasks option for a non-workflow action)
- data = {'tasks': ['x']}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ data = {"tasks": ["x"]}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 400)
- expected_substring = 'only supported for Orquesta workflows'
- self.assertIn(expected_substring, re_run_resp.json['faultstring'])
+ expected_substring = "only supported for Orquesta workflows"
+ self.assertIn(expected_substring, re_run_resp.json["faultstring"])
def test_re_run_workflow_failure_given_both_params_and_tasks(self):
# Create a new execution
@@ -953,13 +988,16 @@ def test_re_run_workflow_failure_given_both_params_and_tasks(self):
execution_id = self._get_actionexecution_id(post_resp)
        # Re-run created execution (parameter overrides and tasks together)
- data = {'parameters': {'a': 'xyz'}, 'tasks': ['x']}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ data = {"parameters": {"a": "xyz"}, "tasks": ["x"]}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 400)
- self.assertIn('not supported when re-running task(s) for a workflow',
- re_run_resp.json['faultstring'])
+ self.assertIn(
+ "not supported when re-running task(s) for a workflow",
+ re_run_resp.json["faultstring"],
+ )
def test_re_run_workflow_failure_given_both_params_and_reset_tasks(self):
# Create a new execution
@@ -968,13 +1006,16 @@ def test_re_run_workflow_failure_given_both_params_and_reset_tasks(self):
execution_id = self._get_actionexecution_id(post_resp)
        # Re-run created execution (parameter overrides and reset tasks together)
- data = {'parameters': {'a': 'xyz'}, 'reset': ['x']}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ data = {"parameters": {"a": "xyz"}, "reset": ["x"]}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 400)
- self.assertIn('not supported when re-running task(s) for a workflow',
- re_run_resp.json['faultstring'])
+ self.assertIn(
+ "not supported when re-running task(s) for a workflow",
+ re_run_resp.json["faultstring"],
+ )
def test_re_run_workflow_failure_invalid_reset_tasks(self):
# Create a new execution
@@ -983,13 +1024,16 @@ def test_re_run_workflow_failure_invalid_reset_tasks(self):
execution_id = self._get_actionexecution_id(post_resp)
        # Re-run created execution (reset task not among the tasks to re-run)
- data = {'tasks': ['x'], 'reset': ['y']}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id),
- data, expect_errors=True)
+ data = {"tasks": ["x"], "reset": ["y"]}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True
+ )
self.assertEqual(re_run_resp.status_int, 400)
- self.assertIn('tasks to reset does not match the tasks to rerun',
- re_run_resp.json['faultstring'])
+ self.assertIn(
+ "tasks to reset does not match the tasks to rerun",
+ re_run_resp.json["faultstring"],
+ )
def test_re_run_secret_parameter(self):
# Create a new execution
@@ -999,96 +1043,100 @@ def test_re_run_secret_parameter(self):
        # Re-run created execution (no parameter overrides)
data = {}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data)
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data
+ )
self.assertEqual(re_run_resp.status_int, 201)
execution_id = self._get_actionexecution_id(re_run_resp)
- re_run_result = self._do_get_one(execution_id,
- params={'show_secrets': True},
- expect_errors=True)
- self.assertEqual(re_run_result.json['parameters'], LIVE_ACTION_1['parameters'])
+ re_run_result = self._do_get_one(
+ execution_id, params={"show_secrets": True}, expect_errors=True
+ )
+ self.assertEqual(re_run_result.json["parameters"], LIVE_ACTION_1["parameters"])
        # Re-run created execution (with parameter overrides)
- data = {'parameters': {'a': 'val1', 'd': ANOTHER_SUPER_SECRET_PARAMETER}}
- re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data)
+ data = {"parameters": {"a": "val1", "d": ANOTHER_SUPER_SECRET_PARAMETER}}
+ re_run_resp = self.app.post_json(
+ "/v1/executions/%s/re_run" % (execution_id), data
+ )
self.assertEqual(re_run_resp.status_int, 201)
execution_id = self._get_actionexecution_id(re_run_resp)
- re_run_result = self._do_get_one(execution_id,
- params={'show_secrets': True},
- expect_errors=True)
- self.assertEqual(re_run_result.json['parameters']['d'], data['parameters']['d'])
+ re_run_result = self._do_get_one(
+ execution_id, params={"show_secrets": True}, expect_errors=True
+ )
+ self.assertEqual(re_run_result.json["parameters"]["d"], data["parameters"]["d"])
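Taken together, the re_run tests above spell out the request contract of POST /v1/executions/<id>/re_run. A compact summary as data, with every payload copied from the assertions above rather than invented:

    # Accepted (HTTP 201), per the tests above:
    VALID_RERUN_PAYLOADS = [
        {},                                     # plain re-run
        {"parameters": {"a": "val1"}},          # override parameters
        {"delay": 100},                         # re-run with a delay
        {"tasks": ["x", "y"], "reset": ["y"]},  # Orquesta only: re-run tasks
    ]

    # Rejected (HTTP 400), per the tests above:
    INVALID_RERUN_PAYLOADS = [
        {"tasks": ["x"]},                              # tasks on a non-workflow action
        {"parameters": {"a": "xyz"}, "tasks": ["x"]},  # params + tasks together
        {"tasks": ["x"], "reset": ["y"]},              # reset not a subset of tasks
    ]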
def test_put_status_and_result(self):
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}}
+ updates = {"status": "succeeded", "result": {"stdout": "foobar"}}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'succeeded')
- self.assertDictEqual(put_resp.json['result'], {'stdout': 'foobar'})
+ self.assertEqual(put_resp.json["status"], "succeeded")
+ self.assertDictEqual(put_resp.json["result"], {"stdout": "foobar"})
get_resp = self._do_get_one(execution_id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['status'], 'succeeded')
- self.assertDictEqual(get_resp.json['result'], {'stdout': 'foobar'})
+ self.assertEqual(get_resp.json["status"], "succeeded")
+ self.assertDictEqual(get_resp.json["result"], {"stdout": "foobar"})
def test_put_bad_state(self):
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'married'}
+ updates = {"status": "married"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- self.assertIn('\'married\' is not one of', put_resp.json['faultstring'])
+ self.assertIn("'married' is not one of", put_resp.json["faultstring"])
def test_put_bad_result(self):
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'result': 'foobar'}
+ updates = {"result": "foobar"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- self.assertIn('is not of type \'object\'', put_resp.json['faultstring'])
+ self.assertIn("is not of type 'object'", put_resp.json["faultstring"])
def test_put_bad_property(self):
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'abandoned', 'foo': 'bar'}
+ updates = {"status": "abandoned", "foo": "bar"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- self.assertIn('Additional properties are not allowed', put_resp.json['faultstring'])
+ self.assertIn(
+ "Additional properties are not allowed", put_resp.json["faultstring"]
+ )
def test_put_status_to_completed_execution(self):
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}}
+ updates = {"status": "succeeded", "result": {"stdout": "foobar"}}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'succeeded')
- self.assertDictEqual(put_resp.json['result'], {'stdout': 'foobar'})
+ self.assertEqual(put_resp.json["status"], "succeeded")
+ self.assertDictEqual(put_resp.json["result"], {"stdout": "foobar"})
- updates = {'status': 'abandoned'}
+ updates = {"status": "abandoned"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- @mock.patch.object(
- LiveAction, 'get_by_id',
- mock.MagicMock(return_value=None))
+ @mock.patch.object(LiveAction, "get_by_id", mock.MagicMock(return_value=None))
def test_put_execution_missing_liveaction(self):
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}}
+ updates = {"status": "succeeded", "result": {"stdout": "foobar"}}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 500)
@@ -1098,19 +1146,19 @@ def test_put_pause_unsupported(self):
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'pausing'}
+ updates = {"status": "pausing"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- self.assertIn('it is not supported', put_resp.json['faultstring'])
+ self.assertIn("it is not supported", put_resp.json["faultstring"])
- updates = {'status': 'paused'}
+ updates = {"status": "paused"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- self.assertIn('it is not supported', put_resp.json['faultstring'])
+ self.assertIn("it is not supported", put_resp.json["faultstring"])
def test_put_pause(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"])
try:
post_resp = self._do_post(LIVE_ACTION_1)
@@ -1118,50 +1166,50 @@ def test_put_pause(self):
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'running'}
+ updates = {"status": "running"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'running')
+ self.assertEqual(put_resp.json["status"], "running")
- updates = {'status': 'pausing'}
+ updates = {"status": "pausing"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'pausing')
- self.assertIsNone(put_resp.json.get('result'))
+ self.assertEqual(put_resp.json["status"], "pausing")
+ self.assertIsNone(put_resp.json.get("result"))
get_resp = self._do_get_one(execution_id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['status'], 'pausing')
- self.assertIsNone(get_resp.json.get('result'))
+ self.assertEqual(get_resp.json["status"], "pausing")
+ self.assertIsNone(get_resp.json.get("result"))
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"])
def test_put_pause_not_running(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"])
try:
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
- self.assertEqual(post_resp.json['status'], 'requested')
+ self.assertEqual(post_resp.json["status"], "requested")
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'pausing'}
+ updates = {"status": "pausing"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- self.assertIn('is not in a running state', put_resp.json['faultstring'])
+ self.assertIn("is not in a running state", put_resp.json["faultstring"])
get_resp = self._do_get_one(execution_id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['status'], 'requested')
- self.assertIsNone(get_resp.json.get('result'))
+ self.assertEqual(get_resp.json["status"], "requested")
+ self.assertIsNone(get_resp.json.get("result"))
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"])
def test_put_pause_already_pausing(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"])
try:
post_resp = self._do_post(LIVE_ACTION_1)
@@ -1169,44 +1217,46 @@ def test_put_pause_already_pausing(self):
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'running'}
+ updates = {"status": "running"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'running')
+ self.assertEqual(put_resp.json["status"], "running")
- updates = {'status': 'pausing'}
+ updates = {"status": "pausing"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'pausing')
- self.assertIsNone(put_resp.json.get('result'))
+ self.assertEqual(put_resp.json["status"], "pausing")
+ self.assertIsNone(put_resp.json.get("result"))
- with mock.patch.object(action_service, 'update_status', return_value=None) as mocked:
- updates = {'status': 'pausing'}
+ with mock.patch.object(
+ action_service, "update_status", return_value=None
+ ) as mocked:
+ updates = {"status": "pausing"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'pausing')
+ self.assertEqual(put_resp.json["status"], "pausing")
mocked.assert_not_called()
get_resp = self._do_get_one(execution_id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['status'], 'pausing')
- self.assertIsNone(get_resp.json.get('result'))
+ self.assertEqual(get_resp.json["status"], "pausing")
+ self.assertIsNone(get_resp.json.get("result"))
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"])
def test_put_resume_unsupported(self):
post_resp = self._do_post(LIVE_ACTION_1)
self.assertEqual(post_resp.status_int, 201)
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'resuming'}
+ updates = {"status": "resuming"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- self.assertIn('it is not supported', put_resp.json['faultstring'])
+ self.assertIn("it is not supported", put_resp.json["faultstring"])
def test_put_resume(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"])
try:
post_resp = self._do_post(LIVE_ACTION_1)
@@ -1214,44 +1264,46 @@ def test_put_resume(self):
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'running'}
+ updates = {"status": "running"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'running')
+ self.assertEqual(put_resp.json["status"], "running")
- updates = {'status': 'pausing'}
+ updates = {"status": "pausing"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'pausing')
- self.assertIsNone(put_resp.json.get('result'))
+ self.assertEqual(put_resp.json["status"], "pausing")
+ self.assertIsNone(put_resp.json.get("result"))
# Manually change the status to paused because only the runner pause method should
# set the paused status directly to the liveaction and execution database objects.
liveaction_id = self._get_liveaction_id(post_resp)
liveaction = action_db_util.get_liveaction_by_id(liveaction_id)
- action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED)
+ action_service.update_status(
+ liveaction, action_constants.LIVEACTION_STATUS_PAUSED
+ )
get_resp = self._do_get_one(execution_id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['status'], 'paused')
- self.assertIsNone(get_resp.json.get('result'))
+ self.assertEqual(get_resp.json["status"], "paused")
+ self.assertIsNone(get_resp.json.get("result"))
- updates = {'status': 'resuming'}
+ updates = {"status": "resuming"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'resuming')
- self.assertIsNone(put_resp.json.get('result'))
+ self.assertEqual(put_resp.json["status"], "resuming")
+ self.assertIsNone(put_resp.json.get("result"))
get_resp = self._do_get_one(execution_id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['status'], 'resuming')
- self.assertIsNone(get_resp.json.get('result'))
+ self.assertEqual(get_resp.json["status"], "resuming")
+ self.assertIsNone(get_resp.json.get("result"))
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"])
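As the comment in test_put_resume notes, the API only transitions an execution to "pausing"; landing on "paused" is the runner's job, which the test simulates by calling action_service.update_status directly. A sketch of that simulated transition, reusing the helpers this test module already references (import paths assumed):

    from st2common.constants import action as action_constants
    from st2common.services import action as action_service
    from st2common.util import action_db as action_db_util

    def simulate_runner_pause(liveaction_id):
        # Mimic what the runner's pause method would do: flip the liveaction
        # (and, via the service layer, the execution) to the "paused" status.
        liveaction = action_db_util.get_liveaction_by_id(liveaction_id)
        action_service.update_status(
            liveaction, action_constants.LIVEACTION_STATUS_PAUSED
        )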
def test_put_resume_not_paused(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"])
try:
post_resp = self._do_post(LIVE_ACTION_1)
@@ -1259,33 +1311,35 @@ def test_put_resume_not_paused(self):
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'running'}
+ updates = {"status": "running"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'running')
+ self.assertEqual(put_resp.json["status"], "running")
- updates = {'status': 'pausing'}
+ updates = {"status": "pausing"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'pausing')
- self.assertIsNone(put_resp.json.get('result'))
+ self.assertEqual(put_resp.json["status"], "pausing")
+ self.assertIsNone(put_resp.json.get("result"))
- updates = {'status': 'resuming'}
+ updates = {"status": "resuming"}
put_resp = self._do_put(execution_id, updates, expect_errors=True)
self.assertEqual(put_resp.status_int, 400)
- expected_error_message = 'it is in "pausing" state and not in "paused" state'
- self.assertIn(expected_error_message, put_resp.json['faultstring'])
+ expected_error_message = (
+ 'it is in "pausing" state and not in "paused" state'
+ )
+ self.assertIn(expected_error_message, put_resp.json["faultstring"])
get_resp = self._do_get_one(execution_id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['status'], 'pausing')
- self.assertIsNone(get_resp.json.get('result'))
+ self.assertEqual(get_resp.json["status"], "pausing")
+ self.assertIsNone(get_resp.json.get("result"))
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"])
def test_put_resume_already_running(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"])
try:
post_resp = self._do_post(LIVE_ACTION_1)
@@ -1293,24 +1347,26 @@ def test_put_resume_already_running(self):
execution_id = self._get_actionexecution_id(post_resp)
- updates = {'status': 'running'}
+ updates = {"status": "running"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'running')
+ self.assertEqual(put_resp.json["status"], "running")
- with mock.patch.object(action_service, 'update_status', return_value=None) as mocked:
- updates = {'status': 'resuming'}
+ with mock.patch.object(
+ action_service, "update_status", return_value=None
+ ) as mocked:
+ updates = {"status": "resuming"}
put_resp = self._do_put(execution_id, updates)
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['status'], 'running')
+ self.assertEqual(put_resp.json["status"], "running")
mocked.assert_not_called()
get_resp = self._do_get_one(execution_id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['status'], 'running')
- self.assertIsNone(get_resp.json.get('result'))
+ self.assertEqual(get_resp.json["status"], "running")
+ self.assertIsNone(get_resp.json.get("result"))
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"])
def test_get_inquiry_mask(self):
"""Ensure Inquiry responses are masked when retrieved via ActionExecution GET
@@ -1327,194 +1383,213 @@ def test_get_inquiry_mask(self):
self.assertEqual(get_resp.status_int, 200)
resp = json.loads(get_resp.body)
- self.assertEqual(resp['result']['response']['secondfactor'], MASKED_ATTRIBUTE_VALUE)
+ self.assertEqual(
+ resp["result"]["response"]["secondfactor"], MASKED_ATTRIBUTE_VALUE
+ )
post_resp = self._do_post(LIVE_ACTION_INQUIRY)
actionexecution_id = self._get_actionexecution_id(post_resp)
- get_resp = self._do_get_one(actionexecution_id, params={'show_secrets': True})
+ get_resp = self._do_get_one(actionexecution_id, params={"show_secrets": True})
self.assertEqual(get_resp.status_int, 200)
resp = json.loads(get_resp.body)
- self.assertEqual(resp['result']['response']['secondfactor'], "supersecretvalue")
+ self.assertEqual(resp["result"]["response"]["secondfactor"], "supersecretvalue")
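Both inquiry assertions above rely on the same masking rule: values whose schema marks them as secret come back as MASKED_ATTRIBUTE_VALUE unless the caller passes ?show_secrets=True. An illustrative sketch of that rule (not st2's actual implementation; the mask value is a stand-in):

    MASKED_ATTRIBUTE_VALUE = "********"  # assumed placeholder value

    def mask_secret_parameters(parameters, schema, show_secrets=False):
        # Walk the parameter schema and blank out anything flagged "secret",
        # unless the caller explicitly asked for (and may see) raw values.
        if show_secrets:
            return parameters
        masked = dict(parameters)
        for name, spec in schema.items():
            if spec.get("secret") and name in masked:
                masked[name] = MASKED_ATTRIBUTE_VALUE
        return masked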
def test_get_include_attributes_and_secret_parameters(self):
# Verify that secret parameters are correctly masked when using ?include_attributes filter
self._do_post(LIVE_ACTION_WITH_SECRET_PARAM)
urls = [
- '/v1/actionexecutions?include_attributes=parameters',
- '/v1/actionexecutions?include_attributes=parameters,action',
- '/v1/actionexecutions?include_attributes=parameters,runner',
- '/v1/actionexecutions?include_attributes=parameters,action,runner'
+ "/v1/actionexecutions?include_attributes=parameters",
+ "/v1/actionexecutions?include_attributes=parameters,action",
+ "/v1/actionexecutions?include_attributes=parameters,runner",
+ "/v1/actionexecutions?include_attributes=parameters,action,runner",
]
for url in urls:
- resp = self.app.get(url + '&limit=1')
+ resp = self.app.get(url + "&limit=1")
- self.assertIn('parameters', resp.json[0])
- self.assertEqual(resp.json[0]['parameters']['a'], 'param a')
- self.assertEqual(resp.json[0]['parameters']['d'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(resp.json[0]['parameters']['password'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(resp.json[0]['parameters']['hosts'], 'localhost')
+ self.assertIn("parameters", resp.json[0])
+ self.assertEqual(resp.json[0]["parameters"]["a"], "param a")
+ self.assertEqual(resp.json[0]["parameters"]["d"], MASKED_ATTRIBUTE_VALUE)
+ self.assertEqual(
+ resp.json[0]["parameters"]["password"], MASKED_ATTRIBUTE_VALUE
+ )
+ self.assertEqual(resp.json[0]["parameters"]["hosts"], "localhost")
# With ?show_secrets=True
urls = [
- ('/v1/actionexecutions?&include_attributes=parameters'),
- ('/v1/actionexecutions?include_attributes=parameters,action'),
- ('/v1/actionexecutions?include_attributes=parameters,runner'),
- ('/v1/actionexecutions?include_attributes=parameters,action,runner')
+ ("/v1/actionexecutions?&include_attributes=parameters"),
+ ("/v1/actionexecutions?include_attributes=parameters,action"),
+ ("/v1/actionexecutions?include_attributes=parameters,runner"),
+ ("/v1/actionexecutions?include_attributes=parameters,action,runner"),
]
for url in urls:
- resp = self.app.get(url + '&limit=1&show_secrets=True')
+ resp = self.app.get(url + "&limit=1&show_secrets=True")
- self.assertIn('parameters', resp.json[0])
- self.assertEqual(resp.json[0]['parameters']['a'], 'param a')
- self.assertEqual(resp.json[0]['parameters']['d'], 'secretpassword1')
- self.assertEqual(resp.json[0]['parameters']['password'], 'secretpassword2')
- self.assertEqual(resp.json[0]['parameters']['hosts'], 'localhost')
+ self.assertIn("parameters", resp.json[0])
+ self.assertEqual(resp.json[0]["parameters"]["a"], "param a")
+ self.assertEqual(resp.json[0]["parameters"]["d"], "secretpassword1")
+ self.assertEqual(resp.json[0]["parameters"]["password"], "secretpassword2")
+ self.assertEqual(resp.json[0]["parameters"]["hosts"], "localhost")
# NOTE: We don't allow exclusion of attributes such as "action" and "runner" because
# that would break secrets masking
urls = [
- '/v1/actionexecutions?limit=1&exclude_attributes=action',
- '/v1/actionexecutions?limit=1&exclude_attributes=runner',
- '/v1/actionexecutions?limit=1&exclude_attributes=action,runner',
+ "/v1/actionexecutions?limit=1&exclude_attributes=action",
+ "/v1/actionexecutions?limit=1&exclude_attributes=runner",
+ "/v1/actionexecutions?limit=1&exclude_attributes=action,runner",
]
for url in urls:
- resp = self.app.get(url + '&limit=1', expect_errors=True)
+ resp = self.app.get(url + "&limit=1", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertTrue('Invalid or unsupported exclude attribute specified:' in
- resp.json['faultstring'])
+ self.assertTrue(
+ "Invalid or unsupported exclude attribute specified:"
+ in resp.json["faultstring"]
+ )
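The NOTE above is why the loop expects a 400: the "action" and "runner" sub-documents carry the secret-parameter schemas, so excluding them would leave the API unable to mask anything. A hedged sketch of such a guard (names illustrative, not st2's actual code):

    ALLOWED_EXCLUDE_ATTRIBUTES = {"result", "trigger_instance"}  # illustrative

    def validate_exclude_attributes(exclude_attributes):
        # "action" and "runner" hold the secret-parameter schemas, so
        # excluding them would break masking; reject anything not allowed.
        invalid = [a for a in exclude_attributes if a not in ALLOWED_EXCLUDE_ATTRIBUTES]
        if invalid:
            raise ValueError(
                "Invalid or unsupported exclude attribute specified: %s"
                % ", ".join(invalid)
            )
        return exclude_attributes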
def test_get_single_attribute_success(self):
- exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id']
+ exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"]
- resp = self.app.get('/v1/executions/%s/attribute/status' % (exec_id))
+ resp = self.app.get("/v1/executions/%s/attribute/status" % (exec_id))
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, 'requested')
+ self.assertEqual(resp.json, "requested")
- resp = self.app.get('/v1/executions/%s/attribute/result' % (exec_id))
+ resp = self.app.get("/v1/executions/%s/attribute/result" % (exec_id))
self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.json, None)
- resp = self.app.get('/v1/executions/%s/attribute/trigger_instance' % (exec_id))
+ resp = self.app.get("/v1/executions/%s/attribute/trigger_instance" % (exec_id))
self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.json, None)
data = {}
- data['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED
- data['result'] = {'foo': 'bar'}
+ data["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED
+ data["result"] = {"foo": "bar"}
- resp = self.app.put_json('/v1/executions/%s' % (exec_id), data)
+ resp = self.app.put_json("/v1/executions/%s" % (exec_id), data)
self.assertEqual(resp.status_int, 200)
- resp = self.app.get('/v1/executions/%s/attribute/result' % (exec_id))
+ resp = self.app.get("/v1/executions/%s/attribute/result" % (exec_id))
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, data['result'])
+ self.assertEqual(resp.json, data["result"])
def test_get_single_attribute_failure_invalid_attribute(self):
- exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id']
+ exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"]
- resp = self.app.get('/v1/executions/%s/attribute/start_timestamp' % (exec_id),
- expect_errors=True)
+ resp = self.app.get(
+ "/v1/executions/%s/attribute/start_timestamp" % (exec_id),
+ expect_errors=True,
+ )
self.assertEqual(resp.status_int, 400)
- self.assertTrue('Invalid attribute "start_timestamp" specified.' in
- resp.json['faultstring'])
+ self.assertTrue(
+ 'Invalid attribute "start_timestamp" specified.' in resp.json["faultstring"]
+ )
def test_get_single_include_attributes_and_secret_parameters(self):
# Verify that secret parameters are correctly masked when using ?include_attributes filter
self._do_post(LIVE_ACTION_WITH_SECRET_PARAM)
- exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id']
+ exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"]
# FYI, the response always contains the 'id' parameter
urls = [
{
- 'url': '/v1/executions/%s?include_attributes=parameters' % (exec_id),
- 'expected_parameters': ['id', 'parameters'],
+ "url": "/v1/executions/%s?include_attributes=parameters" % (exec_id),
+ "expected_parameters": ["id", "parameters"],
},
{
- 'url': '/v1/executions/%s?include_attributes=parameters,action' % (exec_id),
- 'expected_parameters': ['id', 'parameters', 'action'],
+ "url": "/v1/executions/%s?include_attributes=parameters,action"
+ % (exec_id),
+ "expected_parameters": ["id", "parameters", "action"],
},
{
- 'url': '/v1/executions/%s?include_attributes=parameters,runner' % (exec_id),
- 'expected_parameters': ['id', 'parameters', 'runner'],
+ "url": "/v1/executions/%s?include_attributes=parameters,runner"
+ % (exec_id),
+ "expected_parameters": ["id", "parameters", "runner"],
},
{
- 'url': '/v1/executions/%s?include_attributes=parameters,action,runner' % (exec_id),
- 'expected_parameters': ['id', 'parameters', 'action', 'runner'],
- }
+ "url": "/v1/executions/%s?include_attributes=parameters,action,runner"
+ % (exec_id),
+ "expected_parameters": ["id", "parameters", "action", "runner"],
+ },
]
for item in urls:
- url = item['url']
+ url = item["url"]
resp = self.app.get(url)
- self.assertIn('parameters', resp.json)
- self.assertEqual(resp.json['parameters']['a'], 'param a')
- self.assertEqual(resp.json['parameters']['d'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(resp.json['parameters']['password'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(resp.json['parameters']['hosts'], 'localhost')
+ self.assertIn("parameters", resp.json)
+ self.assertEqual(resp.json["parameters"]["a"], "param a")
+ self.assertEqual(resp.json["parameters"]["d"], MASKED_ATTRIBUTE_VALUE)
+ self.assertEqual(
+ resp.json["parameters"]["password"], MASKED_ATTRIBUTE_VALUE
+ )
+ self.assertEqual(resp.json["parameters"]["hosts"], "localhost")
             # ensure that the response has only the keys we expect, no more, no less
resp_keys = set(resp.json.keys())
- expected_params = set(item['expected_parameters'])
+ expected_params = set(item["expected_parameters"])
diff = resp_keys.symmetric_difference(expected_params)
self.assertEqual(diff, set())
# With ?show_secrets=True
urls = [
{
- 'url': '/v1/executions/%s?&include_attributes=parameters' % (exec_id),
- 'expected_parameters': ['id', 'parameters'],
+ "url": "/v1/executions/%s?&include_attributes=parameters" % (exec_id),
+ "expected_parameters": ["id", "parameters"],
},
{
- 'url': '/v1/executions/%s?include_attributes=parameters,action' % (exec_id),
- 'expected_parameters': ['id', 'parameters', 'action'],
+ "url": "/v1/executions/%s?include_attributes=parameters,action"
+ % (exec_id),
+ "expected_parameters": ["id", "parameters", "action"],
},
{
- 'url': '/v1/executions/%s?include_attributes=parameters,runner' % (exec_id),
- 'expected_parameters': ['id', 'parameters', 'runner'],
+ "url": "/v1/executions/%s?include_attributes=parameters,runner"
+ % (exec_id),
+ "expected_parameters": ["id", "parameters", "runner"],
},
{
- 'url': '/v1/executions/%s?include_attributes=parameters,action,runner' % (exec_id),
- 'expected_parameters': ['id', 'parameters', 'action', 'runner'],
+ "url": "/v1/executions/%s?include_attributes=parameters,action,runner"
+ % (exec_id),
+ "expected_parameters": ["id", "parameters", "action", "runner"],
},
]
for item in urls:
- url = item['url']
- resp = self.app.get(url + '&show_secrets=True')
+ url = item["url"]
+ resp = self.app.get(url + "&show_secrets=True")
- self.assertIn('parameters', resp.json)
- self.assertEqual(resp.json['parameters']['a'], 'param a')
- self.assertEqual(resp.json['parameters']['d'], 'secretpassword1')
- self.assertEqual(resp.json['parameters']['password'], 'secretpassword2')
- self.assertEqual(resp.json['parameters']['hosts'], 'localhost')
+ self.assertIn("parameters", resp.json)
+ self.assertEqual(resp.json["parameters"]["a"], "param a")
+ self.assertEqual(resp.json["parameters"]["d"], "secretpassword1")
+ self.assertEqual(resp.json["parameters"]["password"], "secretpassword2")
+ self.assertEqual(resp.json["parameters"]["hosts"], "localhost")
             # ensure that the response has only the keys we expect, no more, no less
resp_keys = set(resp.json.keys())
- expected_params = set(item['expected_parameters'])
+ expected_params = set(item["expected_parameters"])
diff = resp_keys.symmetric_difference(expected_params)
self.assertEqual(diff, set())
# NOTE: We don't allow exclusion of attributes such as "action" and "runner" because
# that would break secrets masking
urls = [
- '/v1/executions/%s?limit=1&exclude_attributes=action',
- '/v1/executions/%s?limit=1&exclude_attributes=runner',
- '/v1/executions/%s?limit=1&exclude_attributes=action,runner',
+ "/v1/executions/%s?limit=1&exclude_attributes=action",
+ "/v1/executions/%s?limit=1&exclude_attributes=runner",
+ "/v1/executions/%s?limit=1&exclude_attributes=action,runner",
]
for url in urls:
resp = self.app.get(url, expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertTrue('Invalid or unsupported exclude attribute specified:' in
- resp.json['faultstring'])
+ self.assertTrue(
+ "Invalid or unsupported exclude attribute specified:"
+ in resp.json["faultstring"]
+ )
def _insert_mock_models(self):
execution_1_id = self._get_actionexecution_id(self._do_post(LIVE_ACTION_1))
@@ -1522,37 +1597,44 @@ def _insert_mock_models(self):
return [execution_1_id, execution_2_id]
-class ActionExecutionOutputControllerTestCase(BaseActionExecutionControllerTestCase,
- FunctionalTest):
+class ActionExecutionOutputControllerTestCase(
+ BaseActionExecutionControllerTestCase, FunctionalTest
+):
def test_get_output_id_last_no_executions_in_the_database(self):
ActionExecution.query().delete()
- resp = self.app.get('/v1/executions/last/output', expect_errors=True)
+ resp = self.app.get("/v1/executions/last/output", expect_errors=True)
self.assertEqual(resp.status_int, http_client.BAD_REQUEST)
- self.assertEqual(resp.json['faultstring'], 'No executions found in the database')
+ self.assertEqual(
+ resp.json["faultstring"], "No executions found in the database"
+ )
def test_get_output_running_execution(self):
# Only the output produced so far should be returned
         # Test the execution output API endpoint for an execution which is running (blocking)
status = action_constants.LIVEACTION_STATUS_RUNNING
timestamp = date_utils.get_datetime_utc_now()
- action_execution_db = ActionExecutionDB(start_timestamp=timestamp,
- end_timestamp=timestamp,
- status=status,
- action={'ref': 'core.local'},
- runner={'name': 'local-shell-cmd'},
- liveaction={'ref': 'foo'})
+ action_execution_db = ActionExecutionDB(
+ start_timestamp=timestamp,
+ end_timestamp=timestamp,
+ status=status,
+ action={"ref": "core.local"},
+ runner={"name": "local-shell-cmd"},
+ liveaction={"ref": "foo"},
+ )
action_execution_db = ActionExecution.add_or_update(action_execution_db)
- output_params = dict(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stdout',
- data='stdout before start\n')
+ output_params = dict(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stdout",
+ data="stdout before start\n",
+ )
def insert_mock_data(data):
- output_params['data'] = data
+ output_params["data"] = data
output_db = ActionExecutionOutputDB(**output_params)
ActionExecutionOutput.add_or_update(output_db)
@@ -1561,45 +1643,51 @@ def insert_mock_data(data):
ActionExecutionOutput.add_or_update(output_db, publish=False)
# Retrieve data while execution is running - data produced so far should be retrieved
- resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)),
- expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions/%s/output" % (str(action_execution_db.id)),
+ expect_errors=False,
+ )
self.assertEqual(resp.status_int, 200)
- lines = resp.text.strip().split('\n')
+ lines = resp.text.strip().split("\n")
lines = [line for line in lines if line.strip()]
self.assertEqual(len(lines), 1)
- self.assertEqual(lines[0], 'stdout before start')
+ self.assertEqual(lines[0], "stdout before start")
# Insert more data
- insert_mock_data('stdout mid 1\n')
+ insert_mock_data("stdout mid 1\n")
# Retrieve data while execution is running - data produced so far should be retrieved
- resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)),
- expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions/%s/output" % (str(action_execution_db.id)),
+ expect_errors=False,
+ )
self.assertEqual(resp.status_int, 200)
- lines = resp.text.strip().split('\n')
+ lines = resp.text.strip().split("\n")
lines = [line for line in lines if line.strip()]
self.assertEqual(len(lines), 2)
- self.assertEqual(lines[0], 'stdout before start')
- self.assertEqual(lines[1], 'stdout mid 1')
+ self.assertEqual(lines[0], "stdout before start")
+ self.assertEqual(lines[1], "stdout mid 1")
# Insert more data
- insert_mock_data('stdout pre finish 1\n')
+ insert_mock_data("stdout pre finish 1\n")
# Transition execution to completed state
action_execution_db.status = action_constants.LIVEACTION_STATUS_SUCCEEDED
action_execution_db = ActionExecution.add_or_update(action_execution_db)
# Execution has finished
- resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)),
- expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions/%s/output" % (str(action_execution_db.id)),
+ expect_errors=False,
+ )
self.assertEqual(resp.status_int, 200)
- lines = resp.text.strip().split('\n')
+ lines = resp.text.strip().split("\n")
lines = [line for line in lines if line.strip()]
self.assertEqual(len(lines), 3)
- self.assertEqual(lines[0], 'stdout before start')
- self.assertEqual(lines[1], 'stdout mid 1')
- self.assertEqual(lines[2], 'stdout pre finish 1')
+ self.assertEqual(lines[0], "stdout before start")
+ self.assertEqual(lines[1], "stdout mid 1")
+ self.assertEqual(lines[2], "stdout pre finish 1")
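The test above shows the /output endpoint is incremental: while an execution runs it returns only the output produced so far, and the full transcript once it completes. A client could therefore poll it, printing just the new suffix each round; a sketch assuming the same local API placeholders as before:

    import time
    import requests

    TERMINAL = ("succeeded", "failed", "timeout", "canceled", "abandoned")

    def tail_output(api, execution_id, headers, interval=1.0):
        seen = 0
        while True:
            resp = requests.get(
                "%s/v1/executions/%s/output" % (api, execution_id), headers=headers
            )
            resp.raise_for_status()
            print(resp.text[seen:], end="")  # only the newly produced chunk
            seen = len(resp.text)
            status = requests.get(
                "%s/v1/executions/%s/attribute/status" % (api, execution_id),
                headers=headers,
            ).json()
            if status in TERMINAL:
                break
            time.sleep(interval)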
def test_get_output_finished_execution(self):
         # Test the execution output API endpoint for an execution which has finished
@@ -1607,42 +1695,50 @@ def test_get_output_finished_execution(self):
# Insert mock execution and output objects
status = action_constants.LIVEACTION_STATUS_SUCCEEDED
timestamp = date_utils.get_datetime_utc_now()
- action_execution_db = ActionExecutionDB(start_timestamp=timestamp,
- end_timestamp=timestamp,
- status=status,
- action={'ref': 'core.local'},
- runner={'name': 'local-shell-cmd'},
- liveaction={'ref': 'foo'})
+ action_execution_db = ActionExecutionDB(
+ start_timestamp=timestamp,
+ end_timestamp=timestamp,
+ status=status,
+ action={"ref": "core.local"},
+ runner={"name": "local-shell-cmd"},
+ liveaction={"ref": "foo"},
+ )
action_execution_db = ActionExecution.add_or_update(action_execution_db)
for i in range(1, 6):
- stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stdout',
- data='stdout %s\n' % (i))
+ stdout_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stdout",
+ data="stdout %s\n" % (i),
+ )
ActionExecutionOutput.add_or_update(stdout_db)
for i in range(10, 15):
- stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stderr',
- data='stderr %s\n' % (i))
+ stderr_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stderr",
+ data="stderr %s\n" % (i),
+ )
ActionExecutionOutput.add_or_update(stderr_db)
- resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)),
- expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions/%s/output" % (str(action_execution_db.id)),
+ expect_errors=False,
+ )
self.assertEqual(resp.status_int, 200)
- lines = resp.text.strip().split('\n')
+ lines = resp.text.strip().split("\n")
self.assertEqual(len(lines), 10)
- self.assertEqual(lines[0], 'stdout 1')
- self.assertEqual(lines[9], 'stderr 14')
+ self.assertEqual(lines[0], "stdout 1")
+ self.assertEqual(lines[9], "stderr 14")
# Verify "last" short-hand id works
- resp = self.app.get('/v1/executions/last/output', expect_errors=False)
+ resp = self.app.get("/v1/executions/last/output", expect_errors=False)
self.assertEqual(resp.status_int, 200)
- lines = resp.text.strip().split('\n')
+ lines = resp.text.strip().split("\n")
self.assertEqual(len(lines), 10)
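The "last" shorthand asserted at the end resolves server-side to the most recently created execution, so a client can tail the newest run without listing executions first; for example (assumed host and token):

    import requests

    resp = requests.get(
        "http://127.0.0.1:9101/v1/executions/last/output",  # assumed endpoint
        headers={"X-Auth-Token": "<token>"},  # placeholder token
    )
    print(resp.text)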
diff --git a/st2api/tests/unit/controllers/v1/test_executions_auth.py b/st2api/tests/unit/controllers/v1/test_executions_auth.py
index e408d053dc..f1045a7d54 100644
--- a/st2api/tests/unit/controllers/v1/test_executions_auth.py
+++ b/st2api/tests/unit/controllers/v1/test_executions_auth.py
@@ -44,61 +44,48 @@
ACTION_1 = {
- 'name': 'st2.dummy.action1',
- 'description': 'test description',
- 'enabled': True,
- 'entry_point': '/tmp/test/action1.sh',
- 'pack': 'sixpack',
- 'runner_type': 'remote-shell-cmd',
- 'parameters': {
- 'a': {
- 'type': 'string',
- 'default': 'abc'
- },
- 'b': {
- 'type': 'number',
- 'default': 123
- },
- 'c': {
- 'type': 'number',
- 'default': 123,
- 'immutable': True
- },
- 'd': {
- 'type': 'string',
- 'secret': True
- }
- }
+ "name": "st2.dummy.action1",
+ "description": "test description",
+ "enabled": True,
+ "entry_point": "/tmp/test/action1.sh",
+ "pack": "sixpack",
+ "runner_type": "remote-shell-cmd",
+ "parameters": {
+ "a": {"type": "string", "default": "abc"},
+ "b": {"type": "number", "default": 123},
+ "c": {"type": "number", "default": 123, "immutable": True},
+ "d": {"type": "string", "secret": True},
+ },
}
ACTION_DEFAULT_ENCRYPT = {
- 'name': 'st2.dummy.default_encrypted_value',
- 'description': 'An action that uses a jinja template with decrypt_kv filter '
- 'in default parameter',
- 'enabled': True,
- 'pack': 'starterpack',
- 'runner_type': 'local-shell-cmd',
- 'parameters': {
- 'encrypted_param': {
- 'type': 'string',
- 'default': '{{ st2kv.system.secret | decrypt_kv }}'
+ "name": "st2.dummy.default_encrypted_value",
+ "description": "An action that uses a jinja template with decrypt_kv filter "
+ "in default parameter",
+ "enabled": True,
+ "pack": "starterpack",
+ "runner_type": "local-shell-cmd",
+ "parameters": {
+ "encrypted_param": {
+ "type": "string",
+ "default": "{{ st2kv.system.secret | decrypt_kv }}",
},
- 'encrypted_user_param': {
- 'type': 'string',
- 'default': '{{ st2kv.user.secret | decrypt_kv }}'
- }
- }
+ "encrypted_user_param": {
+ "type": "string",
+ "default": "{{ st2kv.user.secret | decrypt_kv }}",
+ },
+ },
}
LIVE_ACTION_1 = {
- 'action': 'sixpack.st2.dummy.action1',
- 'parameters': {
- 'hosts': 'localhost',
- 'cmd': 'uname -a',
- 'd': SUPER_SECRET_PARAMETER
- }
+ "action": "sixpack.st2.dummy.action1",
+ "parameters": {
+ "hosts": "localhost",
+ "cmd": "uname -a",
+ "d": SUPER_SECRET_PARAMETER,
+ },
}
LIVE_ACTION_DEFAULT_ENCRYPT = {
- 'action': 'starterpack.st2.dummy.default_encrypted_value',
+ "action": "starterpack.st2.dummy.default_encrypted_value",
}
# NOTE: We use a longer expiry time because this variable is initialized on module import (aka
@@ -107,19 +94,23 @@
# by that time and the tests would fail.
NOW = date_utils.get_datetime_utc_now()
EXPIRY = NOW + datetime.timedelta(seconds=1000)
-SYS_TOKEN = TokenDB(id=bson.ObjectId(), user='system', token=uuid.uuid4().hex, expiry=EXPIRY)
-USR_TOKEN = TokenDB(id=bson.ObjectId(), user='tokenuser', token=uuid.uuid4().hex, expiry=EXPIRY)
+SYS_TOKEN = TokenDB(
+ id=bson.ObjectId(), user="system", token=uuid.uuid4().hex, expiry=EXPIRY
+)
+USR_TOKEN = TokenDB(
+ id=bson.ObjectId(), user="tokenuser", token=uuid.uuid4().hex, expiry=EXPIRY
+)
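Per the NOTE above, these token objects are built once at module import, so with a short TTL a slow test collection or run could reach the auth checks only after expiry and fail spuriously; hence the generous 1000-second window. The validity check boils down to a comparison like the following sketch (import path assumed):

    from st2common.util import date as date_utils  # assumed import path

    def token_is_valid(token_db):
        # A token is usable only while its expiry lies in the future.
        return token_db.expiry > date_utils.get_datetime_utc_now()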
-FIXTURES_PACK = 'generic'
-FIXTURES = {
- 'users': ['system_user.yaml', 'token_user.yaml']
-}
+FIXTURES_PACK = "generic"
+FIXTURES = {"users": ["system_user.yaml", "token_user.yaml"]}
# These parameters are used for the tests of getting value from datastore and decrypting it at
# Jinja expression in a action metadata definition.
-TEST_USER = UserDB(name='user1')
-TEST_TOKEN = TokenDB(id=bson.ObjectId(), user=TEST_USER, token=uuid.uuid4().hex, expiry=EXPIRY)
-TEST_APIKEY = ApiKeyDB(user=TEST_USER, key_hash='secret_key', enabled=True)
+TEST_USER = UserDB(name="user1")
+TEST_TOKEN = TokenDB(
+ id=bson.ObjectId(), user=TEST_USER, token=uuid.uuid4().hex, expiry=EXPIRY
+)
+TEST_APIKEY = ApiKeyDB(user=TEST_USER, key_hash="secret_key", enabled=True)
def mock_get_token(*args, **kwargs):
@@ -128,50 +119,69 @@ def mock_get_token(*args, **kwargs):
return USR_TOKEN
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class ActionExecutionControllerTestCaseAuthEnabled(FunctionalTest):
enable_auth = True
@classmethod
+ @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token))
+ @mock.patch.object(User, "get_by_name", mock.MagicMock(side_effect=UserDB))
@mock.patch.object(
- Token, 'get',
- mock.MagicMock(side_effect=mock_get_token))
- @mock.patch.object(User, 'get_by_name', mock.MagicMock(side_effect=UserDB))
- @mock.patch.object(action_validator, 'validate_action', mock.MagicMock(
- return_value=True))
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def setUpClass(cls):
super(ActionExecutionControllerTestCaseAuthEnabled, cls).setUpClass()
cls.action = copy.deepcopy(ACTION_1)
- headers = {'content-type': 'application/json', 'X-Auth-Token': str(SYS_TOKEN.token)}
- post_resp = cls.app.post_json('/v1/actions', cls.action, headers=headers)
- cls.action['id'] = post_resp.json['id']
+ headers = {
+ "content-type": "application/json",
+ "X-Auth-Token": str(SYS_TOKEN.token),
+ }
+ post_resp = cls.app.post_json("/v1/actions", cls.action, headers=headers)
+ cls.action["id"] = post_resp.json["id"]
cls.action_encrypt = copy.deepcopy(ACTION_DEFAULT_ENCRYPT)
- post_resp = cls.app.post_json('/v1/actions', cls.action_encrypt, headers=headers)
- cls.action_encrypt['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json(
+ "/v1/actions", cls.action_encrypt, headers=headers
+ )
+ cls.action_encrypt["id"] = post_resp.json["id"]
- FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=FIXTURES)
+ FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=FIXTURES
+ )
         # register datastore values which are used in these tests
KeyValuePairAPI._setup_crypto()
register_items = [
- {'name': 'secret', 'secret': True,
- 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'foo')},
- {'name': 'user1:secret', 'secret': True, 'scope': FULL_USER_SCOPE,
- 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'bar')},
+ {
+ "name": "secret",
+ "secret": True,
+ "value": crypto_utils.symmetric_encrypt(
+ KeyValuePairAPI.crypto_key, "foo"
+ ),
+ },
+ {
+ "name": "user1:secret",
+ "secret": True,
+ "scope": FULL_USER_SCOPE,
+ "value": crypto_utils.symmetric_encrypt(
+ KeyValuePairAPI.crypto_key, "bar"
+ ),
+ },
+ ]
+ cls.kvps = [
+ KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items
]
- cls.kvps = [KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items]
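The fixture setup above pre-encrypts values with the configured crypto key so the decrypt_kv Jinja filter in ACTION_DEFAULT_ENCRYPT's defaults can resolve them, once system-scoped ("secret") and once user-scoped ("user1:secret"). The same registration as a standalone helper, a sketch with assumed import paths:

    from st2common.models.api.keyvalue import KeyValuePairAPI
    from st2common.models.db.keyvalue import KeyValuePairDB
    from st2common.persistence.keyvalue import KeyValuePair
    from st2common.util import crypto as crypto_utils

    def register_secret(name, plaintext, scope=None):
        # Encrypt with the key loaded from config, then persist the pair as
        # a secret so reads go through the masking/decryption machinery.
        KeyValuePairAPI._setup_crypto()
        item = {
            "name": name,
            "secret": True,
            "value": crypto_utils.symmetric_encrypt(
                KeyValuePairAPI.crypto_key, plaintext
            ),
        }
        if scope:
            item["scope"] = scope
        return KeyValuePair.add_or_update(KeyValuePairDB(**item))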
@classmethod
- @mock.patch.object(
- Token, 'get',
- mock.MagicMock(side_effect=mock_get_token))
+ @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token))
def tearDownClass(cls):
- headers = {'content-type': 'application/json', 'X-Auth-Token': str(SYS_TOKEN.token)}
- cls.app.delete('/v1/actions/%s' % cls.action['id'], headers=headers)
- cls.app.delete('/v1/actions/%s' % cls.action_encrypt['id'], headers=headers)
+ headers = {
+ "content-type": "application/json",
+ "X-Auth-Token": str(SYS_TOKEN.token),
+ }
+ cls.app.delete("/v1/actions/%s" % cls.action["id"], headers=headers)
+ cls.app.delete("/v1/actions/%s" % cls.action_encrypt["id"], headers=headers)
# unregister key-value pairs for tests
[KeyValuePair.delete(x) for x in cls.kvps]
@@ -179,49 +189,53 @@ def tearDownClass(cls):
super(ActionExecutionControllerTestCaseAuthEnabled, cls).tearDownClass()
def _do_post(self, liveaction, *args, **kwargs):
- return self.app.post_json('/v1/executions', liveaction, *args, **kwargs)
+ return self.app.post_json("/v1/executions", liveaction, *args, **kwargs)
- @mock.patch.object(
- Token, 'get',
- mock.MagicMock(side_effect=mock_get_token))
+ @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token))
def test_post_with_st2_context_in_headers(self):
- headers = {'content-type': 'application/json', 'X-Auth-Token': str(USR_TOKEN.token)}
+ headers = {
+ "content-type": "application/json",
+ "X-Auth-Token": str(USR_TOKEN.token),
+ }
resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers)
self.assertEqual(resp.status_int, 201)
- token_user = resp.json['context']['user']
- self.assertEqual(token_user, 'tokenuser')
- context = {'parent': {'execution_id': str(resp.json['id']), 'user': token_user}}
- headers = {'content-type': 'application/json',
- 'X-Auth-Token': str(SYS_TOKEN.token),
- 'st2-context': json.dumps(context)}
+ token_user = resp.json["context"]["user"]
+ self.assertEqual(token_user, "tokenuser")
+ context = {"parent": {"execution_id": str(resp.json["id"]), "user": token_user}}
+ headers = {
+ "content-type": "application/json",
+ "X-Auth-Token": str(SYS_TOKEN.token),
+ "st2-context": json.dumps(context),
+ }
resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers)
self.assertEqual(resp.status_int, 201)
- self.assertEqual(resp.json['context']['user'], 'tokenuser')
- self.assertEqual(resp.json['context']['parent'], context['parent'])
+ self.assertEqual(resp.json["context"]["user"], "tokenuser")
+ self.assertEqual(resp.json["context"]["parent"], context["parent"])
- @mock.patch.object(ApiKey, 'get', mock.Mock(return_value=TEST_APIKEY))
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=TEST_USER))
+ @mock.patch.object(ApiKey, "get", mock.Mock(return_value=TEST_APIKEY))
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=TEST_USER))
def test_template_encrypted_params_with_apikey(self):
- resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, headers={
- 'St2-Api-key': 'secret_key'
- })
+ resp = self._do_post(
+ LIVE_ACTION_DEFAULT_ENCRYPT, headers={"St2-Api-key": "secret_key"}
+ )
self.assertEqual(resp.status_int, 201)
- self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo')
- self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar')
+ self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo")
+ self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar")
- @mock.patch.object(Token, 'get', mock.Mock(return_value=TEST_TOKEN))
- @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=TEST_USER))
+ @mock.patch.object(Token, "get", mock.Mock(return_value=TEST_TOKEN))
+ @mock.patch.object(User, "get_by_name", mock.Mock(return_value=TEST_USER))
def test_template_encrypted_params_with_access_token(self):
- resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, headers={
- 'X-Auth-Token': str(TEST_TOKEN.token)
- })
+ resp = self._do_post(
+ LIVE_ACTION_DEFAULT_ENCRYPT, headers={"X-Auth-Token": str(TEST_TOKEN.token)}
+ )
self.assertEqual(resp.status_int, 201)
- self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo')
- self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar')
+ self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo")
+ self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar")
def test_template_encrypted_params_without_auth(self):
resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True)
self.assertEqual(resp.status_int, 401)
- self.assertEqual(resp.json['faultstring'],
- 'Unauthorized - One of Token or API key required.')
+ self.assertEqual(
+ resp.json["faultstring"], "Unauthorized - One of Token or API key required."
+ )
diff --git a/st2api/tests/unit/controllers/v1/test_executions_descendants.py b/st2api/tests/unit/controllers/v1/test_executions_descendants.py
index 1afbcdde2f..945e03feeb 100644
--- a/st2api/tests/unit/controllers/v1/test_executions_descendants.py
+++ b/st2api/tests/unit/controllers/v1/test_executions_descendants.py
@@ -19,64 +19,85 @@
from st2tests.api import FunctionalTest
-DESCENDANTS_PACK = 'descendants'
+DESCENDANTS_PACK = "descendants"
DESCENDANTS_FIXTURES = {
- 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml',
- 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml',
- 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml']
+ "executions": [
+ "root_execution.yaml",
+ "child1_level1.yaml",
+ "child2_level1.yaml",
+ "child1_level2.yaml",
+ "child2_level2.yaml",
+ "child3_level2.yaml",
+ "child1_level3.yaml",
+ "child2_level3.yaml",
+ "child3_level3.yaml",
+ ]
}
class ActionExecutionControllerTestCaseDescendantsTest(FunctionalTest):
-
@classmethod
def setUpClass(cls):
super(ActionExecutionControllerTestCaseDescendantsTest, cls).setUpClass()
- cls.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK,
- fixtures_dict=DESCENDANTS_FIXTURES)
+ cls.MODELS = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES
+ )
def test_get_all_descendants(self):
- root_execution = self.MODELS['executions']['root_execution.yaml']
- resp = self.app.get('/v1/executions/%s/children' % str(root_execution.id))
+ root_execution = self.MODELS["executions"]["root_execution.yaml"]
+ resp = self.app.get("/v1/executions/%s/children" % str(root_execution.id))
self.assertEqual(resp.status_int, 200)
- all_descendants_ids = [descendant['id'] for descendant in resp.json]
+ all_descendants_ids = [descendant["id"] for descendant in resp.json]
all_descendants_ids.sort()
# everything except the root_execution
- expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions'])
- if v.id != root_execution.id]
+ expected_ids = [
+ str(v.id)
+ for _, v in six.iteritems(self.MODELS["executions"])
+ if v.id != root_execution.id
+ ]
expected_ids.sort()
self.assertListEqual(all_descendants_ids, expected_ids)
def test_get_all_descendants_depth_neg_1(self):
- root_execution = self.MODELS['executions']['root_execution.yaml']
- resp = self.app.get('/v1/executions/%s/children?depth=-1' % str(root_execution.id))
+ root_execution = self.MODELS["executions"]["root_execution.yaml"]
+ resp = self.app.get(
+ "/v1/executions/%s/children?depth=-1" % str(root_execution.id)
+ )
self.assertEqual(resp.status_int, 200)
- all_descendants_ids = [descendant['id'] for descendant in resp.json]
+ all_descendants_ids = [descendant["id"] for descendant in resp.json]
all_descendants_ids.sort()
# everything except the root_execution
- expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions'])
- if v.id != root_execution.id]
+ expected_ids = [
+ str(v.id)
+ for _, v in six.iteritems(self.MODELS["executions"])
+ if v.id != root_execution.id
+ ]
expected_ids.sort()
self.assertListEqual(all_descendants_ids, expected_ids)
def test_get_1_level_descendants(self):
- root_execution = self.MODELS['executions']['root_execution.yaml']
- resp = self.app.get('/v1/executions/%s/children?depth=1' % str(root_execution.id))
+ root_execution = self.MODELS["executions"]["root_execution.yaml"]
+ resp = self.app.get(
+ "/v1/executions/%s/children?depth=1" % str(root_execution.id)
+ )
self.assertEqual(resp.status_int, 200)
- all_descendants_ids = [descendant['id'] for descendant in resp.json]
+ all_descendants_ids = [descendant["id"] for descendant in resp.json]
all_descendants_ids.sort()
# All children of root_execution
- expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions'])
- if v.parent == str(root_execution.id)]
+ expected_ids = [
+ str(v.id)
+ for _, v in six.iteritems(self.MODELS["executions"])
+ if v.parent == str(root_execution.id)
+ ]
expected_ids.sort()
self.assertListEqual(all_descendants_ids, expected_ids)
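The three tests above pin down the depth semantics of the children endpoint: no depth (or depth=-1) returns the whole descendant subtree, while depth=1 returns direct children only. A recursive sketch of those semantics over a parent-linked mapping like the fixtures (dict-shaped records assumed for illustration):

    def collect_descendants(executions_by_id, root_id, depth=-1):
        # depth < 0 means unbounded recursion; depth == 1 stops at direct
        # children; larger depths peel off one level per recursive call.
        children = [
            e for e in executions_by_id.values() if e.get("parent") == root_id
        ]
        result = list(children)
        if depth == 1:
            return result
        next_depth = -1 if depth < 0 else depth - 1
        for child in children:
            result.extend(
                collect_descendants(executions_by_id, child["id"], next_depth)
            )
        return result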
diff --git a/st2api/tests/unit/controllers/v1/test_executions_filters.py b/st2api/tests/unit/controllers/v1/test_executions_filters.py
index e33e8bf87d..af451ca519 100644
--- a/st2api/tests/unit/controllers/v1/test_executions_filters.py
+++ b/st2api/tests/unit/controllers/v1/test_executions_filters.py
@@ -22,6 +22,7 @@
from six.moves import http_client
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2tests.api import FunctionalTest
@@ -36,7 +37,6 @@
class TestActionExecutionFilters(FunctionalTest):
-
@classmethod
def testDownClass(cls):
pass
@@ -52,29 +52,33 @@ def setUpClass(cls):
cls.start_timestamps = []
cls.fake_types = [
{
- 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']),
- 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']),
- 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance']),
- 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']),
- 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['chain']),
- 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['action-chain']),
- 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['workflow']),
- 'context': copy.deepcopy(fixture.ARTIFACTS['context']),
- 'children': []
+ "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]),
+ "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]),
+ "trigger_instance": copy.deepcopy(
+ fixture.ARTIFACTS["trigger_instance"]
+ ),
+ "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]),
+ "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"]),
+ "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["action-chain"]),
+ "liveaction": copy.deepcopy(
+ fixture.ARTIFACTS["liveactions"]["workflow"]
+ ),
+ "context": copy.deepcopy(fixture.ARTIFACTS["context"]),
+ "children": [],
},
{
- 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']),
- 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']),
- 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task1'])
- }
+ "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]),
+ "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]),
+ "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task1"]),
+ },
]
def assign_parent(child):
- candidates = [v for k, v in cls.refs.items() if v.action['name'] == 'chain']
+ candidates = [v for k, v in cls.refs.items() if v.action["name"] == "chain"]
if candidates:
parent = random.choice(candidates)
- child['parent'] = str(parent.id)
- parent.children.append(child['id'])
+ child["parent"] = str(parent.id)
+ parent.children.append(child["id"])
cls.refs[str(parent.id)] = ActionExecution.add_or_update(parent)
for i in range(cls.num_records):
@@ -82,12 +86,12 @@ def assign_parent(child):
timestamp = cls.dt_base + datetime.timedelta(seconds=i)
fake_type = random.choice(cls.fake_types)
data = copy.deepcopy(fake_type)
- data['id'] = obj_id
- data['start_timestamp'] = isotime.format(timestamp, offset=False)
- data['end_timestamp'] = isotime.format(timestamp, offset=False)
- data['status'] = data['liveaction']['status']
- data['result'] = data['liveaction']['result']
- if fake_type['action']['name'] == 'local' and random.choice([True, False]):
+ data["id"] = obj_id
+ data["start_timestamp"] = isotime.format(timestamp, offset=False)
+ data["end_timestamp"] = isotime.format(timestamp, offset=False)
+ data["status"] = data["liveaction"]["status"]
+ data["result"] = data["liveaction"]["result"]
+ if fake_type["action"]["name"] == "local" and random.choice([True, False]):
assign_parent(data)
wb_obj = ActionExecutionAPI(**data)
db_obj = ActionExecutionAPI.to_model(wb_obj)
@@ -97,154 +101,185 @@ def assign_parent(child):
cls.start_timestamps = sorted(cls.start_timestamps)
def test_get_all(self):
- response = self.app.get('/v1/executions')
+ response = self.app.get("/v1/executions")
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), self.num_records)
- self.assertEqual(response.headers['X-Total-Count'], str(self.num_records))
- ids = [item['id'] for item in response.json]
+ self.assertEqual(response.headers["X-Total-Count"], str(self.num_records))
+ ids = [item["id"] for item in response.json]
self.assertListEqual(sorted(ids), sorted(self.refs.keys()))
def test_get_all_exclude_attributes(self):
# No attributes excluded
- response = self.app.get('/v1/executions?action=executions.local&limit=1')
+ response = self.app.get("/v1/executions?action=executions.local&limit=1")
self.assertEqual(response.status_int, 200)
- self.assertIn('result', response.json[0])
+ self.assertIn("result", response.json[0])
# Exclude "result" attribute
- path = '/v1/executions?action=executions.local&limit=1&exclude_attributes=result'
+ path = (
+ "/v1/executions?action=executions.local&limit=1&exclude_attributes=result"
+ )
response = self.app.get(path)
self.assertEqual(response.status_int, 200)
- self.assertNotIn('result', response.json[0])
+ self.assertNotIn("result", response.json[0])
def test_get_one(self):
obj_id = random.choice(list(self.refs.keys()))
- response = self.app.get('/v1/executions/%s' % obj_id)
+ response = self.app.get("/v1/executions/%s" % obj_id)
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, dict)
record = response.json
fake_record = ActionExecutionAPI.from_model(self.refs[obj_id])
- self.assertEqual(record['id'], obj_id)
- self.assertDictEqual(record['action'], fake_record.action)
- self.assertDictEqual(record['runner'], fake_record.runner)
- self.assertDictEqual(record['liveaction'], fake_record.liveaction)
+ self.assertEqual(record["id"], obj_id)
+ self.assertDictEqual(record["action"], fake_record.action)
+ self.assertDictEqual(record["runner"], fake_record.runner)
+ self.assertDictEqual(record["liveaction"], fake_record.liveaction)
def test_get_one_failed(self):
- response = self.app.get('/v1/executions/%s' % str(bson.ObjectId()),
- expect_errors=True)
+ response = self.app.get(
+ "/v1/executions/%s" % str(bson.ObjectId()), expect_errors=True
+ )
self.assertEqual(response.status_int, http_client.NOT_FOUND)
def test_limit(self):
limit = 10
- refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain']
- response = self.app.get('/v1/executions?action=executions.chain&limit=%s' %
- limit)
+ refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"]
+ response = self.app.get(
+ "/v1/executions?action=executions.chain&limit=%s" % limit
+ )
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), limit)
- self.assertEqual(response.headers['X-Limit'], str(limit))
- self.assertEqual(response.headers['X-Total-Count'], str(len(refs)), response.json)
- ids = [item['id'] for item in response.json]
+ self.assertEqual(response.headers["X-Limit"], str(limit))
+ self.assertEqual(
+ response.headers["X-Total-Count"], str(len(refs)), response.json
+ )
+ ids = [item["id"] for item in response.json]
self.assertListEqual(list(set(ids) - set(refs)), [])
def test_limit_minus_one(self):
limit = -1
- refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain']
- response = self.app.get('/v1/executions?action=executions.chain&limit=%s' % limit)
+ refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"]
+ response = self.app.get(
+ "/v1/executions?action=executions.chain&limit=%s" % limit
+ )
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), len(refs))
- self.assertEqual(response.headers['X-Total-Count'], str(len(refs)), response.json)
- ids = [item['id'] for item in response.json]
+ self.assertEqual(
+ response.headers["X-Total-Count"], str(len(refs)), response.json
+ )
+ ids = [item["id"] for item in response.json]
self.assertListEqual(list(set(ids) - set(refs)), [])
def test_limit_negative(self):
limit = -22
- response = self.app.get('/v1/executions?action=executions.chain&limit=%s' % limit,
- expect_errors=True)
+ response = self.app.get(
+ "/v1/executions?action=executions.chain&limit=%s" % limit,
+ expect_errors=True,
+ )
self.assertEqual(response.status_int, 400)
- self.assertEqual(response.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ response.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
def test_query(self):
- refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain']
- response = self.app.get('/v1/executions?action=executions.chain')
+ refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"]
+ response = self.app.get("/v1/executions?action=executions.chain")
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), len(refs))
- self.assertEqual(response.headers['X-Total-Count'], str(len(refs)))
- ids = [item['id'] for item in response.json]
+ self.assertEqual(response.headers["X-Total-Count"], str(len(refs)))
+ ids = [item["id"] for item in response.json]
self.assertListEqual(sorted(ids), sorted(refs))
def test_filters(self):
- excludes = ['parent', 'timestamp', 'action', 'liveaction', 'timestamp_gt',
- 'timestamp_lt', 'status']
+ excludes = [
+ "parent",
+ "timestamp",
+ "action",
+ "liveaction",
+ "timestamp_gt",
+ "timestamp_lt",
+ "status",
+ ]
for param, field in six.iteritems(ActionExecutionsController.supported_filters):
if param in excludes:
continue
value = self.fake_types[0]
- for item in field.split('.'):
+ for item in field.split("."):
value = value[item]
- response = self.app.get('/v1/executions?%s=%s' % (param, value))
+ response = self.app.get("/v1/executions?%s=%s" % (param, value))
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertGreater(len(response.json), 0)
- self.assertGreater(int(response.headers['X-Total-Count']), 0)
+ self.assertGreater(int(response.headers["X-Total-Count"]), 0)
def test_advanced_filters(self):
- excludes = ['parent', 'timestamp', 'action', 'liveaction', 'timestamp_gt',
- 'timestamp_lt', 'status']
+ excludes = [
+ "parent",
+ "timestamp",
+ "action",
+ "liveaction",
+ "timestamp_gt",
+ "timestamp_lt",
+ "status",
+ ]
for param, field in six.iteritems(ActionExecutionsController.supported_filters):
if param in excludes:
continue
value = self.fake_types[0]
- for item in field.split('.'):
+ for item in field.split("."):
value = value[item]
- response = self.app.get('/v1/executions?filter=%s:%s' % (field, value))
+ response = self.app.get("/v1/executions?filter=%s:%s" % (field, value))
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertGreater(len(response.json), 0)
- self.assertGreater(int(response.headers['X-Total-Count']), 0)
+ self.assertGreater(int(response.headers["X-Total-Count"]), 0)
def test_advanced_filters_malformed(self):
- response = self.app.get('/v1/executions?filter=a:b,c:d', expect_errors=True)
+ response = self.app.get("/v1/executions?filter=a:b,c:d", expect_errors=True)
self.assertEqual(response.status_int, 400)
- self.assertEqual(response.json, {
- "faultstring": "Cannot resolve field \"a\""
- })
- response = self.app.get('/v1/executions?filter=action.ref', expect_errors=True)
+ self.assertEqual(response.json, {"faultstring": 'Cannot resolve field "a"'})
+ response = self.app.get("/v1/executions?filter=action.ref", expect_errors=True)
self.assertEqual(response.status_int, 400)
- self.assertEqual(response.json, {
- "faultstring": "invalid format for filter \"action.ref\""
- })
+ self.assertEqual(
+ response.json, {"faultstring": 'invalid format for filter "action.ref"'}
+ )
def test_parent(self):
- refs = [v for k, v in six.iteritems(self.refs)
- if v.action['name'] == 'chain' and v.children]
+ refs = [
+ v
+ for k, v in six.iteritems(self.refs)
+ if v.action["name"] == "chain" and v.children
+ ]
self.assertTrue(refs)
ref = random.choice(refs)
- response = self.app.get('/v1/executions?parent=%s' % str(ref.id))
+ response = self.app.get("/v1/executions?parent=%s" % str(ref.id))
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), len(ref.children))
- self.assertEqual(response.headers['X-Total-Count'], str(len(ref.children)))
- ids = [item['id'] for item in response.json]
+ self.assertEqual(response.headers["X-Total-Count"], str(len(ref.children)))
+ ids = [item["id"] for item in response.json]
self.assertListEqual(sorted(ids), sorted(ref.children))
def test_parentless(self):
- refs = {k: v for k, v in six.iteritems(self.refs) if not getattr(v, 'parent', None)}
+ refs = {
+ k: v for k, v in six.iteritems(self.refs) if not getattr(v, "parent", None)
+ }
self.assertTrue(refs)
self.assertNotEqual(len(refs), self.num_records)
- response = self.app.get('/v1/executions?parent=null')
+ response = self.app.get("/v1/executions?parent=null")
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), len(refs))
- self.assertEqual(response.headers['X-Total-Count'], str(len(refs)))
- ids = [item['id'] for item in response.json]
+ self.assertEqual(response.headers["X-Total-Count"], str(len(refs)))
+ ids = [item["id"] for item in response.json]
self.assertListEqual(sorted(ids), sorted(refs.keys()))
def test_pagination(self):
@@ -253,14 +288,15 @@ def test_pagination(self):
page_count = int(self.num_records / page_size)
for i in range(page_count):
offset = i * page_size
- response = self.app.get('/v1/executions?offset=%s&limit=%s' % (
- offset, page_size))
+ response = self.app.get(
+ "/v1/executions?offset=%s&limit=%s" % (offset, page_size)
+ )
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), page_size)
- self.assertEqual(response.headers['X-Limit'], str(page_size))
- self.assertEqual(response.headers['X-Total-Count'], str(self.num_records))
- ids = [item['id'] for item in response.json]
+ self.assertEqual(response.headers["X-Limit"], str(page_size))
+ self.assertEqual(response.headers["X-Total-Count"], str(self.num_records))
+ ids = [item["id"] for item in response.json]
self.assertListEqual(list(set(ids) - set(self.refs.keys())), [])
self.assertListEqual(sorted(list(set(ids) - set(retrieved))), sorted(ids))
retrieved += ids
@@ -270,60 +306,62 @@ def test_ui_history_query(self):
# In this test we only care about making sure this exact query works. This query is used
# by the webui for the history page, so it is special and must not break.
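# (exclude_attributes=result%2Ctrigger_instance is simply the URL-encoded form of
# "result,trigger_instance".)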
limit = 50
- history_query = '/v1/executions?limit={}&parent=null&exclude_attributes=' \
- 'result%2Ctrigger_instance&status=&action=&trigger_type=&rule=&' \
- 'offset=0'.format(limit)
+ history_query = (
+ "/v1/executions?limit={}&parent=null&exclude_attributes="
+ "result%2Ctrigger_instance&status=&action=&trigger_type=&rule=&"
+ "offset=0".format(limit)
+ )
response = self.app.get(history_query)
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), limit)
- self.assertTrue(int(response.headers['X-Total-Count']) > limit)
+ self.assertTrue(int(response.headers["X-Total-Count"]) > limit)
def test_datetime_range(self):
- dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z'
- response = self.app.get('/v1/executions?timestamp=%s' % dt_range)
+ dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z"
+ response = self.app.get("/v1/executions?timestamp=%s" % dt_range)
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), 10)
- self.assertEqual(response.headers['X-Total-Count'], '10')
+ self.assertEqual(response.headers["X-Total-Count"], "10")
- dt1 = response.json[0]['start_timestamp']
- dt2 = response.json[9]['start_timestamp']
+ dt1 = response.json[0]["start_timestamp"]
+ dt2 = response.json[9]["start_timestamp"]
self.assertLess(isotime.parse(dt1), isotime.parse(dt2))
- dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z'
- response = self.app.get('/v1/executions?timestamp=%s' % dt_range)
+ dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z"
+ response = self.app.get("/v1/executions?timestamp=%s" % dt_range)
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
self.assertEqual(len(response.json), 10)
- self.assertEqual(response.headers['X-Total-Count'], '10')
- dt1 = response.json[0]['start_timestamp']
- dt2 = response.json[9]['start_timestamp']
+ self.assertEqual(response.headers["X-Total-Count"], "10")
+ dt1 = response.json[0]["start_timestamp"]
+ dt2 = response.json[9]["start_timestamp"]
self.assertLess(isotime.parse(dt2), isotime.parse(dt1))
def test_default_sort(self):
- response = self.app.get('/v1/executions')
+ response = self.app.get("/v1/executions")
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
- dt1 = response.json[0]['start_timestamp']
- dt2 = response.json[len(response.json) - 1]['start_timestamp']
+ dt1 = response.json[0]["start_timestamp"]
+ dt2 = response.json[len(response.json) - 1]["start_timestamp"]
self.assertLess(isotime.parse(dt2), isotime.parse(dt1))
def test_ascending_sort(self):
- response = self.app.get('/v1/executions?sort_asc=True')
+ response = self.app.get("/v1/executions?sort_asc=True")
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
- dt1 = response.json[0]['start_timestamp']
- dt2 = response.json[len(response.json) - 1]['start_timestamp']
+ dt1 = response.json[0]["start_timestamp"]
+ dt2 = response.json[len(response.json) - 1]["start_timestamp"]
self.assertLess(isotime.parse(dt1), isotime.parse(dt2))
def test_descending_sort(self):
- response = self.app.get('/v1/executions?sort_desc=True')
+ response = self.app.get("/v1/executions?sort_desc=True")
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, list)
- dt1 = response.json[0]['start_timestamp']
- dt2 = response.json[len(response.json) - 1]['start_timestamp']
+ dt1 = response.json[0]["start_timestamp"]
+ dt2 = response.json[len(response.json) - 1]["start_timestamp"]
self.assertLess(isotime.parse(dt2), isotime.parse(dt1))
def test_timestamp_lt_and_gt_filter(self):
@@ -335,57 +373,81 @@ def isoformat(timestamp):
# Last (largest) timestamp: there are no executions with a greater timestamp
timestamp = self.start_timestamps[-1]
- response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp)))
+ response = self.app.get(
+ "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp))
+ )
self.assertEqual(len(response.json), 0)
# First (smallest) timestamp: there are no executions with a smaller timestamp
timestamp = self.start_timestamps[0]
- response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp)))
+ response = self.app.get(
+ "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp))
+ )
self.assertEqual(len(response.json), 0)
# Second to last: there should be exactly one timestamp greater than it
timestamp = self.start_timestamps[-2]
- response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp)))
+ response = self.app.get(
+ "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp))
+ )
self.assertEqual(len(response.json), 1)
- self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp)
+ self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) > timestamp)
# Second one: there should be exactly one timestamp smaller than it
timestamp = self.start_timestamps[1]
- response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp)))
+ response = self.app.get(
+ "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp))
+ )
self.assertEqual(len(response.json), 1)
- self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp)
+ self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) < timestamp)
# Half of the timestamps should be smaller
index = (len(self.start_timestamps) - 1) // 2
timestamp = self.start_timestamps[index]
- response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp)))
+ response = self.app.get(
+ "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp))
+ )
self.assertEqual(len(response.json), index)
- self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp)
+ self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) < timestamp)
# Half of the timestamps should be greater
index = (len(self.start_timestamps) - 1) // 2
timestamp = self.start_timestamps[-index]
- response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp)))
+ response = self.app.get(
+ "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp))
+ )
self.assertEqual(len(response.json), (index - 1))
- self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp)
+ self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) > timestamp)
# Both lt and gt filters combined should return exactly two results (indices 11 and 12)
timestamp_gt = self.start_timestamps[10]
timestamp_lt = self.start_timestamps[13]
- response = self.app.get('/v1/executions?timestamp_gt=%s&timestamp_lt=%s' %
- (isoformat(timestamp_gt), isoformat(timestamp_lt)))
+ response = self.app.get(
+ "/v1/executions?timestamp_gt=%s×tamp_lt=%s"
+ % (isoformat(timestamp_gt), isoformat(timestamp_lt))
+ )
self.assertEqual(len(response.json), 2)
- self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp_gt)
- self.assertTrue(isotime.parse(response.json[1]['start_timestamp']) > timestamp_gt)
- self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp_lt)
- self.assertTrue(isotime.parse(response.json[1]['start_timestamp']) < timestamp_lt)
+ self.assertTrue(
+ isotime.parse(response.json[0]["start_timestamp"]) > timestamp_gt
+ )
+ self.assertTrue(
+ isotime.parse(response.json[1]["start_timestamp"]) > timestamp_gt
+ )
+ self.assertTrue(
+ isotime.parse(response.json[0]["start_timestamp"]) < timestamp_lt
+ )
+ self.assertTrue(
+ isotime.parse(response.json[1]["start_timestamp"]) < timestamp_lt
+ )
def test_filters_view(self):
- response = self.app.get('/v1/executions/views/filters')
+ response = self.app.get("/v1/executions/views/filters")
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, dict)
- self.assertEqual(len(response.json), len(history_views.ARTIFACTS['filters']['default']))
- for key, value in six.iteritems(history_views.ARTIFACTS['filters']['default']):
+ self.assertEqual(
+ len(response.json), len(history_views.ARTIFACTS["filters"]["default"])
+ )
+ for key, value in six.iteritems(history_views.ARTIFACTS["filters"]["default"]):
filter_values = response.json[key]
# Verify empty (None / null) filters are excluded
@@ -399,9 +461,13 @@ def test_filters_view(self):
self.assertEqual(set(filter_values), set(value))
def test_filters_view_specific_types(self):
- response = self.app.get('/v1/executions/views/filters?types=action,user,nonexistent')
+ response = self.app.get(
+ "/v1/executions/views/filters?types=action,user,nonexistent"
+ )
self.assertEqual(response.status_int, 200)
self.assertIsInstance(response.json, dict)
- self.assertEqual(len(response.json), len(history_views.ARTIFACTS['filters']['specific']))
- for key, value in six.iteritems(history_views.ARTIFACTS['filters']['specific']):
+ self.assertEqual(
+ len(response.json), len(history_views.ARTIFACTS["filters"]["specific"])
+ )
+ for key, value in six.iteritems(history_views.ARTIFACTS["filters"]["specific"]):
self.assertEqual(set(response.json[key]), set(value))
diff --git a/st2api/tests/unit/controllers/v1/test_inquiries.py b/st2api/tests/unit/controllers/v1/test_inquiries.py
index 469b1c8816..173acbe405 100644
--- a/st2api/tests/unit/controllers/v1/test_inquiries.py
+++ b/st2api/tests/unit/controllers/v1/test_inquiries.py
@@ -36,58 +36,50 @@
ACTION_1 = {
- 'name': 'st2.dummy.action1',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'testpack',
- 'runner_type': 'local-shell-cmd',
+ "name": "st2.dummy.action1",
+ "description": "test description",
+ "enabled": True,
+ "pack": "testpack",
+ "runner_type": "local-shell-cmd",
}
LIVE_ACTION_1 = {
- 'action': 'testpack.st2.dummy.action1',
- 'parameters': {
- 'cmd': 'uname -a'
- }
+ "action": "testpack.st2.dummy.action1",
+ "parameters": {"cmd": "uname -a"},
}
INQUIRY_ACTION = {
- 'name': 'st2.dummy.ask',
- 'description': 'test description',
- 'enabled': True,
- 'pack': 'testpack',
- 'runner_type': 'inquirer',
+ "name": "st2.dummy.ask",
+ "description": "test description",
+ "enabled": True,
+ "pack": "testpack",
+ "runner_type": "inquirer",
}
INQUIRY_1 = {
- 'action': 'testpack.st2.dummy.ask',
- 'status': 'pending',
- 'parameters': {},
- 'context': {
- 'parent': {
- 'user': 'testu',
- 'execution_id': '59b845e132ed350d396a798f',
- 'pack': 'examples'
+ "action": "testpack.st2.dummy.ask",
+ "status": "pending",
+ "parameters": {},
+ "context": {
+ "parent": {
+ "user": "testu",
+ "execution_id": "59b845e132ed350d396a798f",
+ "pack": "examples",
},
- 'trace_context': {'trace_tag': 'balleilaka'}
- }
+ "trace_context": {"trace_tag": "balleilaka"},
+ },
}
INQUIRY_2 = {
- 'action': 'testpack.st2.dummy.ask',
- 'status': 'pending',
- 'parameters': {
- 'route': 'superlative',
- 'users': ['foo', 'bar']
- }
+ "action": "testpack.st2.dummy.ask",
+ "status": "pending",
+ "parameters": {"route": "superlative", "users": ["foo", "bar"]},
}
INQUIRY_TIMEOUT = {
- 'action': 'testpack.st2.dummy.ask',
- 'status': 'timeout',
- 'parameters': {
- 'route': 'superlative',
- 'users': ['foo', 'bar']
- }
+ "action": "testpack.st2.dummy.ask",
+ "status": "timeout",
+ "parameters": {"route": "superlative", "users": ["foo", "bar"]},
}
SCHEMA_DEFAULT = {
@@ -97,7 +89,7 @@
"continue": {
"type": "boolean",
"description": "Would you like to continue the workflow?",
- "required": True
+ "required": True,
}
},
}
@@ -109,18 +101,18 @@
"name": {
"type": "string",
"description": "What is your name?",
- "required": True
+ "required": True,
},
"pin": {
"type": "integer",
"description": "What is your PIN?",
- "required": True
+ "required": True,
},
"paradox": {
"type": "boolean",
"description": "This statement is False.",
- "required": True
- }
+ "required": True,
+ },
},
}
@@ -132,7 +124,7 @@
"roles": [],
"users": [],
"route": "",
- "ttl": 1440
+ "ttl": 1440,
}
RESULT_2 = {
@@ -140,7 +132,7 @@
"roles": [],
"users": ["foo", "bar"],
"route": "superlative",
- "ttl": 1440
+ "ttl": 1440,
}
RESULT_MULTIPLE = {
@@ -148,58 +140,51 @@
"roles": [],
"users": [],
"route": "",
- "ttl": 1440
+ "ttl": 1440,
}
-RESPONSE_MULTIPLE = {
- "name": "matt",
- "pin": 1234,
- "paradox": True
-}
+RESPONSE_MULTIPLE = {"name": "matt", "pin": 1234, "paradox": True}
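# Illustrative sketch, not part of this change set: the API validates an inquiry
# response against the inquiry's JSON schema (hence the "did not pass schema
# validation" fault asserted in test_respond_fail). Plain jsonschema is used here
# only as an approximation of that check, and SCHEMA_MULTIPLE (the multi-question
# schema above) is an assumed constant name:
#
#   import jsonschema
#   jsonschema.validate(RESPONSE_MULTIPLE, SCHEMA_MULTIPLE)  # passes
#   jsonschema.validate({"continue": 123}, SCHEMA_DEFAULT)   # ValidationError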
ROOT_LIVEACTION_DB = lv_db_models.LiveActionDB(
- id=uuid.uuid4().hex,
- status=action_constants.LIVEACTION_STATUS_PAUSED
+ id=uuid.uuid4().hex, status=action_constants.LIVEACTION_STATUS_PAUSED
)
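# The mocked root liveaction is "paused" because an Inquiry pauses its parent
# workflow; responding to the Inquiry is then expected to request a resume of
# that parent (asserted via request_resume in the respond tests below).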
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
-class InquiryControllerTestCase(BaseInquiryControllerTestCase,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/inquiries'
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
+class InquiryControllerTestCase(
+ BaseInquiryControllerTestCase, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/inquiries"
controller_cls = InquiriesController
- include_attribute_field_name = 'ttl'
- exclude_attribute_field_name = 'ttl'
+ include_attribute_field_name = "ttl"
+ exclude_attribute_field_name = "ttl"
@mock.patch.object(
- action_validator,
- 'validate_action',
- mock.MagicMock(return_value=True))
+ action_validator, "validate_action", mock.MagicMock(return_value=True)
+ )
def setUp(cls):
super(BaseInquiryControllerTestCase, cls).setUpClass()
cls.inquiry1 = copy.deepcopy(INQUIRY_ACTION)
- post_resp = cls.app.post_json('/v1/actions', cls.inquiry1)
- cls.inquiry1['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.inquiry1)
+ cls.inquiry1["id"] = post_resp.json["id"]
cls.action1 = copy.deepcopy(ACTION_1)
- post_resp = cls.app.post_json('/v1/actions', cls.action1)
- cls.action1['id'] = post_resp.json['id']
+ post_resp = cls.app.post_json("/v1/actions", cls.action1)
+ cls.action1["id"] = post_resp.json["id"]
def test_get_all(self):
- """Test retrieval of a list of Inquiries
- """
+ """Test retrieval of a list of Inquiries"""
inquiry_count = 5
for i in range(inquiry_count):
self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT)
get_all_resp = self._do_get_all()
inquiries = get_all_resp.json
- self.assertEqual(get_all_resp.headers['X-Total-Count'], str(len(inquiries)))
+ self.assertEqual(get_all_resp.headers["X-Total-Count"], str(len(inquiries)))
self.assertIsInstance(inquiries, list)
self.assertEqual(len(inquiries), inquiry_count)
def test_get_all_empty(self):
- """Test retrieval of a list of Inquiries when there are none
- """
+ """Test retrieval of a list of Inquiries when there are none"""
inquiry_count = 0
get_all_resp = self._do_get_all()
inquiries = get_all_resp.json
@@ -207,8 +192,7 @@ def test_get_all_empty(self):
self.assertEqual(len(inquiries), inquiry_count)
def test_get_all_decrease_after_respond(self):
- """Test that the inquiry list decreases when we respond to one of them
- """
+ """Test that the inquiry list decreases when we respond to one of them"""
# Create inquiries
inquiry_count = 5
@@ -221,7 +205,7 @@ def test_get_all_decrease_after_respond(self):
# Respond to one of them
response = {"continue": True}
- self._do_respond(inquiries[0].get('id'), response)
+ self._do_respond(inquiries[0].get("id"), response)
# Ensure the list is one smaller
get_all_resp = self._do_get_all()
@@ -230,8 +214,7 @@ def test_get_all_decrease_after_respond(self):
self.assertEqual(len(inquiries), inquiry_count - 1)
def test_get_all_limit(self):
- """Test that the limit parameter works correctly
- """
+ """Test that the limit parameter works correctly"""
# Create inquiries
inquiry_count = 5
@@ -241,12 +224,11 @@ def test_get_all_limit(self):
get_all_resp = self._do_get_all(limit=limit)
inquiries = get_all_resp.json
self.assertIsInstance(inquiries, list)
- self.assertEqual(inquiry_count, int(get_all_resp.headers['X-Total-Count']))
+ self.assertEqual(inquiry_count, int(get_all_resp.headers["X-Total-Count"]))
self.assertEqual(len(inquiries), limit)
def test_get_one(self):
- """Test retrieval of a single Inquiry
- """
+ """Test retrieval of a single Inquiry"""
post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT)
inquiry_id = self._get_inquiry_id(post_resp)
get_resp = self._do_get_one(inquiry_id)
@@ -254,24 +236,21 @@ def test_get_one(self):
self.assertEqual(self._get_inquiry_id(get_resp), inquiry_id)
def test_get_one_failed(self):
- """Test failed retrieval of an Inquiry
- """
- inquiry_id = 'asdfeoijasdf'
+ """Test failed retrieval of an Inquiry"""
+ inquiry_id = "asdfeoijasdf"
get_resp = self._do_get_one(inquiry_id, expect_errors=True)
self.assertEqual(get_resp.status_int, http_client.NOT_FOUND)
- self.assertIn('resource could not be found', get_resp.json['faultstring'])
+ self.assertIn("resource could not be found", get_resp.json["faultstring"])
def test_get_one_not_an_inquiry(self):
- """Test that an attempt to retrieve a valid execution that isn't an Inquiry fails
- """
- test_exec = json.loads(self.app.post_json('/v1/executions', LIVE_ACTION_1).body)
- get_resp = self._do_get_one(test_exec.get('id'), expect_errors=True)
+ """Test that an attempt to retrieve a valid execution that isn't an Inquiry fails"""
+ test_exec = json.loads(self.app.post_json("/v1/executions", LIVE_ACTION_1).body)
+ get_resp = self._do_get_one(test_exec.get("id"), expect_errors=True)
self.assertEqual(get_resp.status_int, http_client.BAD_REQUEST)
- self.assertIn('is not an inquiry', get_resp.json['faultstring'])
+ self.assertIn("is not an inquiry", get_resp.json["faultstring"])
def test_get_one_nondefault_params(self):
- """Ensure an Inquiry with custom parameters contains those in result
- """
+ """Ensure an Inquiry with custom parameters contains those in result"""
post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_2)
inquiry_id = self._get_inquiry_id(post_resp)
get_resp = self._do_get_one(inquiry_id)
@@ -282,14 +261,15 @@ def test_get_one_nondefault_params(self):
self.assertEqual(get_resp.json.get(param), RESULT_2.get(param))
@mock.patch.object(
- action_service, 'get_root_liveaction',
- mock.MagicMock(return_value=ROOT_LIVEACTION_DB))
+ action_service,
+ "get_root_liveaction",
+ mock.MagicMock(return_value=ROOT_LIVEACTION_DB),
+ )
@mock.patch.object(
- action_service, 'request_resume',
- mock.MagicMock(return_value=None))
+ action_service, "request_resume", mock.MagicMock(return_value=None)
+ )
def test_respond(self):
- """Test that a correct response is successful
- """
+ """Test that a correct response is successful"""
post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT)
inquiry_id = self._get_inquiry_id(post_resp)
@@ -300,21 +280,22 @@ def test_respond(self):
# The inquiry no longer exists, since the status should not be "pending"
# Get the execution and confirm this.
inquiry_execution = self._do_get_execution(inquiry_id)
- self.assertEqual(inquiry_execution.json.get('status'), 'succeeded')
+ self.assertEqual(inquiry_execution.json.get("status"), "succeeded")
# This Inquiry is in a workflow, so has a parent. Assert that the resume
# was requested for this parent.
action_service.request_resume.assert_called_once()
@mock.patch.object(
- action_service, 'get_root_liveaction',
- mock.MagicMock(return_value=ROOT_LIVEACTION_DB))
+ action_service,
+ "get_root_liveaction",
+ mock.MagicMock(return_value=ROOT_LIVEACTION_DB),
+ )
@mock.patch.object(
- action_service, 'request_resume',
- mock.MagicMock(return_value=None))
+ action_service, "request_resume", mock.MagicMock(return_value=None)
+ )
def test_respond_multiple(self):
- """Test that a more complicated response is successful
- """
+ """Test that a more complicated response is successful"""
post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_MULTIPLE)
inquiry_id = self._get_inquiry_id(post_resp)
@@ -324,38 +305,35 @@ def test_respond_multiple(self):
# The inquiry no longer exists, since the status should not be "pending"
# Get the execution and confirm this.
inquiry_execution = self._do_get_execution(inquiry_id)
- self.assertEqual(inquiry_execution.json.get('status'), 'succeeded')
+ self.assertEqual(inquiry_execution.json.get("status"), "succeeded")
# This Inquiry is in a workflow, so has a parent. Assert that the resume
# was requested for this parent.
action_service.request_resume.assert_called_once()
def test_respond_fail(self):
- """Test that an incorrect response is unsuccessful
- """
+ """Test that an incorrect response is unsuccessful"""
post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT)
inquiry_id = self._get_inquiry_id(post_resp)
response = {"continue": 123}
put_resp = self._do_respond(inquiry_id, response, expect_errors=True)
self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST)
- self.assertIn('did not pass schema validation', put_resp.json['faultstring'])
+ self.assertIn("did not pass schema validation", put_resp.json["faultstring"])
def test_respond_not_an_inquiry(self):
- """Test that attempts to respond to an execution ID that isn't an Inquiry fails
- """
- test_exec = json.loads(self.app.post_json('/v1/executions', LIVE_ACTION_1).body)
+ """Test that attempts to respond to an execution ID that isn't an Inquiry fails"""
+ test_exec = json.loads(self.app.post_json("/v1/executions", LIVE_ACTION_1).body)
response = {"continue": 123}
- put_resp = self._do_respond(test_exec.get('id'), response, expect_errors=True)
+ put_resp = self._do_respond(test_exec.get("id"), response, expect_errors=True)
self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST)
- self.assertIn('is not an inquiry', put_resp.json['faultstring'])
+ self.assertIn("is not an inquiry", put_resp.json["faultstring"])
@mock.patch.object(
- action_service, 'request_resume',
- mock.MagicMock(return_value=None))
+ action_service, "request_resume", mock.MagicMock(return_value=None)
+ )
def test_respond_no_parent(self):
- """Test that a resume was not requested for an Inquiry without a parent
- """
+ """Test that a resume was not requested for an Inquiry without a parent"""
post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT)
inquiry_id = self._get_inquiry_id(post_resp)
@@ -365,8 +343,7 @@ def test_respond_no_parent(self):
action_service.request_resume.assert_not_called()
def test_respond_duplicate_rejected(self):
- """Test that responding to an already-responded Inquiry fails
- """
+ """Test that responding to an already-responded Inquiry fails"""
post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT)
inquiry_id = self._get_inquiry_id(post_resp)
@@ -377,28 +354,30 @@ def test_respond_duplicate_rejected(self):
# The inquiry no longer exists, since the status should not be "pending"
# Get the execution and confirm this.
inquiry_execution = self._do_get_execution(inquiry_id)
- self.assertEqual(inquiry_execution.json.get('status'), 'succeeded')
+ self.assertEqual(inquiry_execution.json.get("status"), "succeeded")
# A second, equivalent response attempt should not succeed, since the Inquiry
# has already been successfully responded to
put_resp = self._do_respond(inquiry_id, response, expect_errors=True)
self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST)
- self.assertIn('has already been responded to', put_resp.json['faultstring'])
+ self.assertIn("has already been responded to", put_resp.json["faultstring"])
def test_respond_timeout_rejected(self):
- """Test that responding to a timed-out Inquiry fails
- """
+ """Test that responding to a timed-out Inquiry fails"""
- post_resp = self._do_create_inquiry(INQUIRY_TIMEOUT, RESULT_DEFAULT, status='timeout')
+ post_resp = self._do_create_inquiry(
+ INQUIRY_TIMEOUT, RESULT_DEFAULT, status="timeout"
+ )
inquiry_id = self._get_inquiry_id(post_resp)
response = {"continue": True}
put_resp = self._do_respond(inquiry_id, response, expect_errors=True)
self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST)
- self.assertIn('timed out and cannot be responded to', put_resp.json['faultstring'])
+ self.assertIn(
+ "timed out and cannot be responded to", put_resp.json["faultstring"]
+ )
def test_respond_restrict_users(self):
- """Test that Inquiries can reject responses from users not in a list
- """
+ """Test that Inquiries can reject responses from users not in a list"""
# Default user for tests is "stanley", which is not in the 'users' list
# Should be rejected
@@ -407,7 +386,9 @@ def test_respond_restrict_users(self):
response = {"continue": True}
put_resp = self._do_respond(inquiry_id, response, expect_errors=True)
self.assertEqual(put_resp.status_int, http_client.FORBIDDEN)
- self.assertIn('does not have permission to respond', put_resp.json['faultstring'])
+ self.assertIn(
+ "does not have permission to respond", put_resp.json["faultstring"]
+ )
# Responding as a user in the list should be accepted
old_user = cfg.CONF.system_user.user
@@ -425,8 +406,8 @@ def test_get_all_invalid_exclude_and_include_parameter(self):
pass
def _insert_mock_models(self):
- id_1 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json['id']
- id_2 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json['id']
+ id_1 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json["id"]
+ id_2 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json["id"]
return [id_1, id_2]
diff --git a/st2api/tests/unit/controllers/v1/test_kvps.py b/st2api/tests/unit/controllers/v1/test_kvps.py
index 06103134bd..61a903a3ad 100644
--- a/st2api/tests/unit/controllers/v1/test_kvps.py
+++ b/st2api/tests/unit/controllers/v1/test_kvps.py
@@ -21,83 +21,66 @@
from six.moves import http_client
-__all__ = [
- 'KeyValuePairControllerTestCase'
-]
+__all__ = ["KeyValuePairControllerTestCase"]
-KVP = {
- 'name': 'keystone_endpoint',
- 'value': 'http://127.0.0.1:5000/v3'
-}
+KVP = {"name": "keystone_endpoint", "value": "http://127.0.0.1:5000/v3"}
-KVP_2 = {
- 'name': 'keystone_version',
- 'value': 'v3'
-}
+KVP_2 = {"name": "keystone_version", "value": "v3"}
-KVP_2_USER = {
- 'name': 'keystone_version',
- 'value': 'user_v3',
- 'scope': 'st2kv.user'
-}
+KVP_2_USER = {"name": "keystone_version", "value": "user_v3", "scope": "st2kv.user"}
-KVP_2_USER_LEGACY = {
- 'name': 'keystone_version',
- 'value': 'user_v3',
- 'scope': 'user'
-}
+KVP_2_USER_LEGACY = {"name": "keystone_version", "value": "user_v3", "scope": "user"}
KVP_3_USER = {
- 'name': 'keystone_endpoint',
- 'value': 'http://127.0.1.1:5000/v3',
- 'scope': 'st2kv.user'
+ "name": "keystone_endpoint",
+ "value": "http://127.0.1.1:5000/v3",
+ "scope": "st2kv.user",
}
KVP_4_USER = {
- 'name': 'customer_ssn',
- 'value': '123-456-7890',
- 'secret': True,
- 'scope': 'st2kv.user'
+ "name": "customer_ssn",
+ "value": "123-456-7890",
+ "secret": True,
+ "scope": "st2kv.user",
}
KVP_WITH_TTL = {
- 'name': 'keystone_endpoint',
- 'value': 'http://127.0.0.1:5000/v3',
- 'ttl': 10
+ "name": "keystone_endpoint",
+ "value": "http://127.0.0.1:5000/v3",
+ "ttl": 10,
}
-SECRET_KVP = {
- 'name': 'secret_key1',
- 'value': 'secret_value1',
- 'secret': True
-}
+SECRET_KVP = {"name": "secret_key1", "value": "secret_value1", "secret": True}
# value = S3cret!Value
# encrypted with st2tests/conf/st2_kvstore_tests.crypto.key.json
ENCRYPTED_KVP = {
- 'name': 'secret_key1',
- 'value': ('3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E'
- 'B30170DACF79498F30520236A629912C3584847098D'),
- 'encrypted': True
+ "name": "secret_key1",
+ "value": (
+ "3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E"
+ "B30170DACF79498F30520236A629912C3584847098D"
+ ),
+ "encrypted": True,
}
ENCRYPTED_KVP_SECRET_FALSE = {
- 'name': 'secret_key2',
- 'value': ('3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E'
- 'B30170DACF79498F30520236A629912C3584847098D'),
- 'secret': True,
- 'encrypted': True
+ "name": "secret_key2",
+ "value": (
+ "3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E"
+ "B30170DACF79498F30520236A629912C3584847098D"
+ ),
+ "secret": True,
+ "encrypted": True,
}
class KeyValuePairControllerTestCase(FunctionalTest):
-
def test_get_all(self):
- resp = self.app.get('/v1/keys')
+ resp = self.app.get("/v1/keys")
self.assertEqual(resp.status_int, 200)
def test_get_one(self):
- put_resp = self.__do_put('key1', KVP)
+ put_resp = self.__do_put("key1", KVP)
kvp_id = self.__get_kvp_id(put_resp)
get_resp = self.__do_get_one(kvp_id)
self.assertEqual(get_resp.status_int, 200)
@@ -107,484 +90,534 @@ def test_get_one(self):
def test_get_all_all_scope(self):
# Tests various scenarios which ensure non-admin users can't read / view keys
# from other users
- user_db_1 = UserDB(name='user1')
- user_db_2 = UserDB(name='user2')
- user_db_3 = UserDB(name='user3')
+ user_db_1 = UserDB(name="user1")
+ user_db_2 = UserDB(name="user2")
+ user_db_3 = UserDB(name="user3")
# Insert some mock data
# System scoped keys
- put_resp = self.__do_put('system1', {'name': 'system1', 'value': 'val1',
- 'scope': 'st2kv.system'})
+ put_resp = self.__do_put(
+ "system1", {"name": "system1", "value": "val1", "scope": "st2kv.system"}
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['name'], 'system1')
- self.assertEqual(put_resp.json['scope'], 'st2kv.system')
+ self.assertEqual(put_resp.json["name"], "system1")
+ self.assertEqual(put_resp.json["scope"], "st2kv.system")
- put_resp = self.__do_put('system2', {'name': 'system2', 'value': 'val2',
- 'scope': 'st2kv.system'})
+ put_resp = self.__do_put(
+ "system2", {"name": "system2", "value": "val2", "scope": "st2kv.system"}
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['name'], 'system2')
- self.assertEqual(put_resp.json['scope'], 'st2kv.system')
+ self.assertEqual(put_resp.json["name"], "system2")
+ self.assertEqual(put_resp.json["scope"], "st2kv.system")
# user1 scoped keys
self.use_user(user_db_1)
- put_resp = self.__do_put('user1', {'name': 'user1', 'value': 'user1',
- 'scope': 'st2kv.user'})
+ put_resp = self.__do_put(
+ "user1", {"name": "user1", "value": "user1", "scope": "st2kv.user"}
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['name'], 'user1')
- self.assertEqual(put_resp.json['scope'], 'st2kv.user')
- self.assertEqual(put_resp.json['value'], 'user1')
+ self.assertEqual(put_resp.json["name"], "user1")
+ self.assertEqual(put_resp.json["scope"], "st2kv.user")
+ self.assertEqual(put_resp.json["value"], "user1")
- put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user1',
- 'scope': 'st2kv.user'})
+ put_resp = self.__do_put(
+ "userkey", {"name": "userkey", "value": "user1", "scope": "st2kv.user"}
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['name'], 'userkey')
- self.assertEqual(put_resp.json['scope'], 'st2kv.user')
- self.assertEqual(put_resp.json['value'], 'user1')
+ self.assertEqual(put_resp.json["name"], "userkey")
+ self.assertEqual(put_resp.json["scope"], "st2kv.user")
+ self.assertEqual(put_resp.json["value"], "user1")
# user2 scoped keys
self.use_user(user_db_2)
- put_resp = self.__do_put('user2', {'name': 'user2', 'value': 'user2',
- 'scope': 'st2kv.user'})
+ put_resp = self.__do_put(
+ "user2", {"name": "user2", "value": "user2", "scope": "st2kv.user"}
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['name'], 'user2')
- self.assertEqual(put_resp.json['scope'], 'st2kv.user')
- self.assertEqual(put_resp.json['value'], 'user2')
+ self.assertEqual(put_resp.json["name"], "user2")
+ self.assertEqual(put_resp.json["scope"], "st2kv.user")
+ self.assertEqual(put_resp.json["value"], "user2")
- put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user2',
- 'scope': 'st2kv.user'})
+ put_resp = self.__do_put(
+ "userkey", {"name": "userkey", "value": "user2", "scope": "st2kv.user"}
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['name'], 'userkey')
- self.assertEqual(put_resp.json['scope'], 'st2kv.user')
- self.assertEqual(put_resp.json['value'], 'user2')
+ self.assertEqual(put_resp.json["name"], "userkey")
+ self.assertEqual(put_resp.json["scope"], "st2kv.user")
+ self.assertEqual(put_resp.json["value"], "user2")
# user3 scoped keys
self.use_user(user_db_3)
- put_resp = self.__do_put('user3', {'name': 'user3', 'value': 'user3',
- 'scope': 'st2kv.user'})
+ put_resp = self.__do_put(
+ "user3", {"name": "user3", "value": "user3", "scope": "st2kv.user"}
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['name'], 'user3')
- self.assertEqual(put_resp.json['scope'], 'st2kv.user')
- self.assertEqual(put_resp.json['value'], 'user3')
+ self.assertEqual(put_resp.json["name"], "user3")
+ self.assertEqual(put_resp.json["scope"], "st2kv.user")
+ self.assertEqual(put_resp.json["value"], "user3")
- put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user3',
- 'scope': 'st2kv.user'})
+ put_resp = self.__do_put(
+ "userkey", {"name": "userkey", "value": "user3", "scope": "st2kv.user"}
+ )
self.assertEqual(put_resp.status_int, 200)
- self.assertEqual(put_resp.json['name'], 'userkey')
- self.assertEqual(put_resp.json['scope'], 'st2kv.user')
- self.assertEqual(put_resp.json['value'], 'user3')
+ self.assertEqual(put_resp.json["name"], "userkey")
+ self.assertEqual(put_resp.json["scope"], "st2kv.user")
+ self.assertEqual(put_resp.json["value"], "user3")
# 1. "all" scope as user1 - should only be able to view system + current user items
self.use_user(user_db_1)
- resp = self.app.get('/v1/keys?scope=all')
+ resp = self.app.get("/v1/keys?scope=all")
self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user
- self.assertEqual(resp.json[0]['name'], 'system1')
- self.assertEqual(resp.json[0]['scope'], 'st2kv.system')
+ self.assertEqual(resp.json[0]["name"], "system1")
+ self.assertEqual(resp.json[0]["scope"], "st2kv.system")
- self.assertEqual(resp.json[1]['name'], 'system2')
- self.assertEqual(resp.json[1]['scope'], 'st2kv.system')
+ self.assertEqual(resp.json[1]["name"], "system2")
+ self.assertEqual(resp.json[1]["scope"], "st2kv.system")
- self.assertEqual(resp.json[2]['name'], 'user1')
- self.assertEqual(resp.json[2]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[2]['user'], 'user1')
+ self.assertEqual(resp.json[2]["name"], "user1")
+ self.assertEqual(resp.json[2]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[2]["user"], "user1")
- self.assertEqual(resp.json[3]['name'], 'userkey')
- self.assertEqual(resp.json[3]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[3]['user'], 'user1')
+ self.assertEqual(resp.json[3]["name"], "userkey")
+ self.assertEqual(resp.json[3]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[3]["user"], "user1")
# Verify user can't retrieve values for other users by manipulating "prefix"
- resp = self.app.get('/v1/keys?scope=all&prefix=user2:')
+ resp = self.app.get("/v1/keys?scope=all&prefix=user2:")
self.assertEqual(resp.json, [])
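# (User-scoped keys are namespaced per requesting user on the server side, an
# implementation detail assumed here, so another user's prefix can never match.)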
- resp = self.app.get('/v1/keys?scope=all&prefix=user')
+ resp = self.app.get("/v1/keys?scope=all&prefix=user")
self.assertEqual(len(resp.json), 2) # 2 user
- self.assertEqual(resp.json[0]['name'], 'user1')
- self.assertEqual(resp.json[0]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[0]['user'], 'user1')
+ self.assertEqual(resp.json[0]["name"], "user1")
+ self.assertEqual(resp.json[0]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[0]["user"], "user1")
- self.assertEqual(resp.json[1]['name'], 'userkey')
- self.assertEqual(resp.json[1]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[1]['user'], 'user1')
+ self.assertEqual(resp.json[1]["name"], "userkey")
+ self.assertEqual(resp.json[1]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[1]["user"], "user1")
# 2. "all" scope user user2 - should only be able to view system + current user items
self.use_user(user_db_2)
- resp = self.app.get('/v1/keys?scope=all')
+ resp = self.app.get("/v1/keys?scope=all")
self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user
- self.assertEqual(resp.json[0]['name'], 'system1')
- self.assertEqual(resp.json[0]['scope'], 'st2kv.system')
+ self.assertEqual(resp.json[0]["name"], "system1")
+ self.assertEqual(resp.json[0]["scope"], "st2kv.system")
- self.assertEqual(resp.json[1]['name'], 'system2')
- self.assertEqual(resp.json[1]['scope'], 'st2kv.system')
+ self.assertEqual(resp.json[1]["name"], "system2")
+ self.assertEqual(resp.json[1]["scope"], "st2kv.system")
- self.assertEqual(resp.json[2]['name'], 'user2')
- self.assertEqual(resp.json[2]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[2]['user'], 'user2')
+ self.assertEqual(resp.json[2]["name"], "user2")
+ self.assertEqual(resp.json[2]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[2]["user"], "user2")
- self.assertEqual(resp.json[3]['name'], 'userkey')
- self.assertEqual(resp.json[3]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[3]['user'], 'user2')
+ self.assertEqual(resp.json[3]["name"], "userkey")
+ self.assertEqual(resp.json[3]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[3]["user"], "user2")
# Verify user can't retrieve values for other users by manipulating "prefix"
- resp = self.app.get('/v1/keys?scope=all&prefix=user1:')
+ resp = self.app.get("/v1/keys?scope=all&prefix=user1:")
self.assertEqual(resp.json, [])
- resp = self.app.get('/v1/keys?scope=all&prefix=user')
+ resp = self.app.get("/v1/keys?scope=all&prefix=user")
self.assertEqual(len(resp.json), 2) # 2 user
- self.assertEqual(resp.json[0]['name'], 'user2')
- self.assertEqual(resp.json[0]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[0]['user'], 'user2')
+ self.assertEqual(resp.json[0]["name"], "user2")
+ self.assertEqual(resp.json[0]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[0]["user"], "user2")
- self.assertEqual(resp.json[1]['name'], 'userkey')
- self.assertEqual(resp.json[1]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[1]['user'], 'user2')
+ self.assertEqual(resp.json[1]["name"], "userkey")
+ self.assertEqual(resp.json[1]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[1]["user"], "user2")
# Verify a non-admin user can't retrieve keys for arbitrary users
- resp = self.app.get('/v1/keys?scope=user&user=user1', expect_errors=True)
- expected_error = '"user" attribute can only be provided by admins when RBAC is enabled'
+ resp = self.app.get("/v1/keys?scope=user&user=user1", expect_errors=True)
+ expected_error = (
+ '"user" attribute can only be provided by admins when RBAC is enabled'
+ )
self.assertEqual(resp.status_int, http_client.FORBIDDEN)
- self.assertEqual(resp.json['faultstring'], expected_error)
+ self.assertEqual(resp.json["faultstring"], expected_error)
# 3. "all" scope user user3 - should only be able to view system + current user items
self.use_user(user_db_3)
- resp = self.app.get('/v1/keys?scope=all')
+ resp = self.app.get("/v1/keys?scope=all")
self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user
- self.assertEqual(resp.json[0]['name'], 'system1')
- self.assertEqual(resp.json[0]['scope'], 'st2kv.system')
+ self.assertEqual(resp.json[0]["name"], "system1")
+ self.assertEqual(resp.json[0]["scope"], "st2kv.system")
- self.assertEqual(resp.json[1]['name'], 'system2')
- self.assertEqual(resp.json[1]['scope'], 'st2kv.system')
+ self.assertEqual(resp.json[1]["name"], "system2")
+ self.assertEqual(resp.json[1]["scope"], "st2kv.system")
- self.assertEqual(resp.json[2]['name'], 'user3')
- self.assertEqual(resp.json[2]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[2]['user'], 'user3')
+ self.assertEqual(resp.json[2]["name"], "user3")
+ self.assertEqual(resp.json[2]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[2]["user"], "user3")
- self.assertEqual(resp.json[3]['name'], 'userkey')
- self.assertEqual(resp.json[3]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[3]['user'], 'user3')
+ self.assertEqual(resp.json[3]["name"], "userkey")
+ self.assertEqual(resp.json[3]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[3]["user"], "user3")
# Verify user can't retrieve values for other users by manipulating "prefix"
- resp = self.app.get('/v1/keys?scope=all&prefix=user1:')
+ resp = self.app.get("/v1/keys?scope=all&prefix=user1:")
self.assertEqual(resp.json, [])
- resp = self.app.get('/v1/keys?scope=all&prefix=user')
+ resp = self.app.get("/v1/keys?scope=all&prefix=user")
self.assertEqual(len(resp.json), 2) # 2 user
- self.assertEqual(resp.json[0]['name'], 'user3')
- self.assertEqual(resp.json[0]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[0]['user'], 'user3')
+ self.assertEqual(resp.json[0]["name"], "user3")
+ self.assertEqual(resp.json[0]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[0]["user"], "user3")
- self.assertEqual(resp.json[1]['name'], 'userkey')
- self.assertEqual(resp.json[1]['scope'], 'st2kv.user')
- self.assertEqual(resp.json[1]['user'], 'user3')
+ self.assertEqual(resp.json[1]["name"], "userkey")
+ self.assertEqual(resp.json[1]["scope"], "st2kv.user")
+ self.assertEqual(resp.json[1]["user"], "user3")
# Clean up
- self.__do_delete('system1')
- self.__do_delete('system2')
+ self.__do_delete("system1")
+ self.__do_delete("system2")
self.use_user(user_db_1)
- self.__do_delete('user1?scope=user')
- self.__do_delete('userkey?scope=user')
+ self.__do_delete("user1?scope=user")
+ self.__do_delete("userkey?scope=user")
self.use_user(user_db_2)
- self.__do_delete('user2?scope=user')
- self.__do_delete('userkey?scope=user')
+ self.__do_delete("user2?scope=user")
+ self.__do_delete("userkey?scope=user")
self.use_user(user_db_3)
- self.__do_delete('user3?scope=user')
- self.__do_delete('userkey?scope=user')
+ self.__do_delete("user3?scope=user")
+ self.__do_delete("userkey?scope=user")
def test_get_all_user_query_param_can_only_be_used_with_rbac(self):
- resp = self.app.get('/v1/keys?user=foousera', expect_errors=True)
+ resp = self.app.get("/v1/keys?user=foousera", expect_errors=True)
- expected_error = '"user" attribute can only be provided by admins when RBAC is enabled'
+ expected_error = (
+ '"user" attribute can only be provided by admins when RBAC is enabled'
+ )
self.assertEqual(resp.status_int, http_client.FORBIDDEN)
- self.assertEqual(resp.json['faultstring'], expected_error)
+ self.assertEqual(resp.json["faultstring"], expected_error)
def test_get_one_user_query_param_can_only_be_used_with_rbac(self):
- resp = self.app.get('/v1/keys/keystone_endpoint?user=foousera', expect_errors=True)
+ resp = self.app.get(
+ "/v1/keys/keystone_endpoint?user=foousera", expect_errors=True
+ )
- expected_error = '"user" attribute can only be provided by admins when RBAC is enabled'
+ expected_error = (
+ '"user" attribute can only be provided by admins when RBAC is enabled'
+ )
self.assertEqual(resp.status_int, http_client.FORBIDDEN)
- self.assertEqual(resp.json['faultstring'], expected_error)
+ self.assertEqual(resp.json["faultstring"], expected_error)
def test_get_all_prefix_filtering(self):
- put_resp1 = self.__do_put(KVP['name'], KVP)
- put_resp2 = self.__do_put(KVP_2['name'], KVP_2)
+ put_resp1 = self.__do_put(KVP["name"], KVP)
+ put_resp2 = self.__do_put(KVP_2["name"], KVP_2)
self.assertEqual(put_resp1.status_int, 200)
self.assertEqual(put_resp2.status_int, 200)
# No keys with that prefix
- resp = self.app.get('/v1/keys?prefix=something')
+ resp = self.app.get("/v1/keys?prefix=something")
self.assertEqual(resp.json, [])
# Two keys with the provided prefix
- resp = self.app.get('/v1/keys?prefix=keystone')
+ resp = self.app.get("/v1/keys?prefix=keystone")
self.assertEqual(len(resp.json), 2)
# One key with the provided prefix
- resp = self.app.get('/v1/keys?prefix=keystone_endpoint')
+ resp = self.app.get("/v1/keys?prefix=keystone_endpoint")
self.assertEqual(len(resp.json), 1)
self.__do_delete(self.__get_kvp_id(put_resp1))
self.__do_delete(self.__get_kvp_id(put_resp2))
def test_get_one_fail(self):
- resp = self.app.get('/v1/keys/1', expect_errors=True)
+ resp = self.app.get("/v1/keys/1", expect_errors=True)
self.assertEqual(resp.status_int, 404)
def test_put(self):
- put_resp = self.__do_put('key1', KVP)
+ put_resp = self.__do_put("key1", KVP)
update_input = put_resp.json
- update_input['value'] = 'http://127.0.0.1:35357/v3'
+ update_input["value"] = "http://127.0.0.1:35357/v3"
put_resp = self.__do_put(self.__get_kvp_id(put_resp), update_input)
self.assertEqual(put_resp.status_int, 200)
self.__do_delete(self.__get_kvp_id(put_resp))
def test_put_with_scope(self):
- self.app.put_json('/v1/keys/%s' % 'keystone_endpoint', KVP,
- expect_errors=False)
- self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2,
- expect_errors=False)
-
- get_resp_1 = self.app.get('/v1/keys/keystone_endpoint')
+ self.app.put_json("/v1/keys/%s" % "keystone_endpoint", KVP, expect_errors=False)
+ self.app.put_json(
+ "/v1/keys/%s?scope=st2kv.system" % "keystone_version",
+ KVP_2,
+ expect_errors=False,
+ )
+
+ get_resp_1 = self.app.get("/v1/keys/keystone_endpoint")
self.assertTrue(get_resp_1.status_int, 200)
- self.assertEqual(self.__get_kvp_id(get_resp_1), 'keystone_endpoint')
- get_resp_2 = self.app.get('/v1/keys/keystone_version?scope=st2kv.system')
+ self.assertEqual(self.__get_kvp_id(get_resp_1), "keystone_endpoint")
+ get_resp_2 = self.app.get("/v1/keys/keystone_version?scope=st2kv.system")
self.assertTrue(get_resp_2.status_int, 200)
- self.assertEqual(self.__get_kvp_id(get_resp_2), 'keystone_version')
- get_resp_3 = self.app.get('/v1/keys/keystone_version')
+ self.assertEqual(self.__get_kvp_id(get_resp_2), "keystone_version")
+ get_resp_3 = self.app.get("/v1/keys/keystone_version")
self.assertTrue(get_resp_3.status_int, 200)
- self.assertEqual(self.__get_kvp_id(get_resp_3), 'keystone_version')
- self.app.delete('/v1/keys/keystone_endpoint?scope=st2kv.system')
- self.app.delete('/v1/keys/keystone_version?scope=st2kv.system')
+ self.assertEqual(self.__get_kvp_id(get_resp_3), "keystone_version")
+ self.app.delete("/v1/keys/keystone_endpoint?scope=st2kv.system")
+ self.app.delete("/v1/keys/keystone_version?scope=st2kv.system")
def test_put_user_scope_and_system_scope_dont_overlap(self):
- self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2,
- expect_errors=False)
- self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER,
- expect_errors=False)
- get_resp = self.app.get('/v1/keys/keystone_version?scope=st2kv.system')
- self.assertEqual(get_resp.json['value'], KVP_2['value'])
-
- get_resp = self.app.get('/v1/keys/keystone_version?scope=st2kv.user')
- self.assertEqual(get_resp.json['value'], KVP_2_USER['value'])
- self.app.delete('/v1/keys/keystone_version?scope=st2kv.system')
- self.app.delete('/v1/keys/keystone_version?scope=st2kv.user')
+ self.app.put_json(
+ "/v1/keys/%s?scope=st2kv.system" % "keystone_version",
+ KVP_2,
+ expect_errors=False,
+ )
+ self.app.put_json(
+ "/v1/keys/%s?scope=st2kv.user" % "keystone_version",
+ KVP_2_USER,
+ expect_errors=False,
+ )
+ get_resp = self.app.get("/v1/keys/keystone_version?scope=st2kv.system")
+ self.assertEqual(get_resp.json["value"], KVP_2["value"])
+
+ get_resp = self.app.get("/v1/keys/keystone_version?scope=st2kv.user")
+ self.assertEqual(get_resp.json["value"], KVP_2_USER["value"])
+ self.app.delete("/v1/keys/keystone_version?scope=st2kv.system")
+ self.app.delete("/v1/keys/keystone_version?scope=st2kv.user")
def test_put_invalid_scope(self):
- put_resp = self.app.put_json('/v1/keys/keystone_version?scope=st2', KVP_2,
- expect_errors=True)
+ put_resp = self.app.put_json(
+ "/v1/keys/keystone_version?scope=st2", KVP_2, expect_errors=True
+ )
self.assertTrue(put_resp.status_int, 400)
def test_get_all_with_scope(self):
- self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2,
- expect_errors=False)
- self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER,
- expect_errors=False)
+ self.app.put_json(
+ "/v1/keys/%s?scope=st2kv.system" % "keystone_version",
+ KVP_2,
+ expect_errors=False,
+ )
+ self.app.put_json(
+ "/v1/keys/%s?scope=st2kv.user" % "keystone_version",
+ KVP_2_USER,
+ expect_errors=False,
+ )
# Note that the following two calls overwrite the st2kv.system and st2kv.user scoped
# variables with the same name.
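# ("system" and "user" are the legacy aliases for the "st2kv.system" and
# "st2kv.user" scopes, which is why these writes land on the same keys.)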
- self.app.put_json('/v1/keys/%s?scope=system' % 'keystone_version', KVP_2,
- expect_errors=False)
- self.app.put_json('/v1/keys/%s?scope=user' % 'keystone_version', KVP_2_USER_LEGACY,
- expect_errors=False)
-
- get_resp_all = self.app.get('/v1/keys?scope=all')
+ self.app.put_json(
+ "/v1/keys/%s?scope=system" % "keystone_version", KVP_2, expect_errors=False
+ )
+ self.app.put_json(
+ "/v1/keys/%s?scope=user" % "keystone_version",
+ KVP_2_USER_LEGACY,
+ expect_errors=False,
+ )
+
+ get_resp_all = self.app.get("/v1/keys?scope=all")
self.assertTrue(len(get_resp_all.json), 2)
- get_resp_sys = self.app.get('/v1/keys?scope=st2kv.system')
+ get_resp_sys = self.app.get("/v1/keys?scope=st2kv.system")
self.assertTrue(len(get_resp_sys.json), 1)
- self.assertEqual(get_resp_sys.json[0]['value'], KVP_2['value'])
+ self.assertEqual(get_resp_sys.json[0]["value"], KVP_2["value"])
- get_resp_sys = self.app.get('/v1/keys?scope=system')
+ get_resp_sys = self.app.get("/v1/keys?scope=system")
self.assertTrue(len(get_resp_sys.json), 1)
- self.assertEqual(get_resp_sys.json[0]['value'], KVP_2['value'])
+ self.assertEqual(get_resp_sys.json[0]["value"], KVP_2["value"])
- get_resp_sys = self.app.get('/v1/keys?scope=st2kv.user')
+ get_resp_sys = self.app.get("/v1/keys?scope=st2kv.user")
self.assertTrue(len(get_resp_sys.json), 1)
- self.assertEqual(get_resp_sys.json[0]['value'], KVP_2_USER['value'])
+ self.assertEqual(get_resp_sys.json[0]["value"], KVP_2_USER["value"])
- get_resp_sys = self.app.get('/v1/keys?scope=user')
+ get_resp_sys = self.app.get("/v1/keys?scope=user")
self.assertTrue(len(get_resp_sys.json), 1)
- self.assertEqual(get_resp_sys.json[0]['value'], KVP_2_USER['value'])
+ self.assertEqual(get_resp_sys.json[0]["value"], KVP_2_USER["value"])
- self.app.delete('/v1/keys/keystone_version?scope=st2kv.system')
- self.app.delete('/v1/keys/keystone_version?scope=st2kv.user')
+ self.app.delete("/v1/keys/keystone_version?scope=st2kv.system")
+ self.app.delete("/v1/keys/keystone_version?scope=st2kv.user")
def test_get_all_with_scope_and_prefix_filtering(self):
- self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER,
- expect_errors=False)
- self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_endpoint', KVP_3_USER,
- expect_errors=False)
- self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'customer_ssn', KVP_4_USER,
- expect_errors=False)
- get_prefix = self.app.get('/v1/keys?scope=st2kv.user&prefix=keystone')
+ self.app.put_json(
+ "/v1/keys/%s?scope=st2kv.user" % "keystone_version",
+ KVP_2_USER,
+ expect_errors=False,
+ )
+ self.app.put_json(
+ "/v1/keys/%s?scope=st2kv.user" % "keystone_endpoint",
+ KVP_3_USER,
+ expect_errors=False,
+ )
+ self.app.put_json(
+ "/v1/keys/%s?scope=st2kv.user" % "customer_ssn",
+ KVP_4_USER,
+ expect_errors=False,
+ )
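+        # Only the two keystone-prefixed keys should match the prefix filter.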
+ get_prefix = self.app.get("/v1/keys?scope=st2kv.user&prefix=keystone")
self.assertEqual(len(get_prefix.json), 2)
- self.app.delete('/v1/keys/keystone_version?scope=st2kv.user')
- self.app.delete('/v1/keys/keystone_endpoint?scope=st2kv.user')
- self.app.delete('/v1/keys/customer_ssn?scope=st2kv.user')
+ self.app.delete("/v1/keys/keystone_version?scope=st2kv.user")
+ self.app.delete("/v1/keys/keystone_endpoint?scope=st2kv.user")
+ self.app.delete("/v1/keys/customer_ssn?scope=st2kv.user")
def test_put_with_ttl(self):
- put_resp = self.__do_put('key_with_ttl', KVP_WITH_TTL)
+ put_resp = self.__do_put("key_with_ttl", KVP_WITH_TTL)
self.assertEqual(put_resp.status_int, 200)
- get_resp = self.app.get('/v1/keys')
- self.assertTrue(get_resp.json[0]['expire_timestamp'])
+ get_resp = self.app.get("/v1/keys")
+ self.assertTrue(get_resp.json[0]["expire_timestamp"])
self.__do_delete(self.__get_kvp_id(put_resp))
def test_put_secret(self):
- put_resp = self.__do_put('secret_key1', SECRET_KVP)
+ put_resp = self.__do_put("secret_key1", SECRET_KVP)
kvp_id = self.__get_kvp_id(put_resp)
get_resp = self.__do_get_one(kvp_id)
- self.assertTrue(get_resp.json['encrypted'])
- crypto_val = get_resp.json['value']
- self.assertNotEqual(SECRET_KVP['value'], crypto_val)
+ self.assertTrue(get_resp.json["encrypted"])
+ crypto_val = get_resp.json["value"]
+ self.assertNotEqual(SECRET_KVP["value"], crypto_val)
self.__do_delete(self.__get_kvp_id(put_resp))
def test_get_one_secret_no_decrypt(self):
- put_resp = self.__do_put('secret_key1', SECRET_KVP)
+ put_resp = self.__do_put("secret_key1", SECRET_KVP)
kvp_id = self.__get_kvp_id(put_resp)
- get_resp = self.app.get('/v1/keys/secret_key1')
+ get_resp = self.app.get("/v1/keys/secret_key1")
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self.__get_kvp_id(get_resp), kvp_id)
- self.assertTrue(get_resp.json['secret'])
- self.assertTrue(get_resp.json['encrypted'])
+ self.assertTrue(get_resp.json["secret"])
+ self.assertTrue(get_resp.json["encrypted"])
self.__do_delete(kvp_id)
def test_get_one_secret_decrypt(self):
- put_resp = self.__do_put('secret_key1', SECRET_KVP)
+ put_resp = self.__do_put("secret_key1", SECRET_KVP)
kvp_id = self.__get_kvp_id(put_resp)
- get_resp = self.app.get('/v1/keys/secret_key1?decrypt=true')
+ get_resp = self.app.get("/v1/keys/secret_key1?decrypt=true")
self.assertEqual(get_resp.status_int, 200)
self.assertEqual(self.__get_kvp_id(get_resp), kvp_id)
- self.assertTrue(get_resp.json['secret'])
- self.assertFalse(get_resp.json['encrypted'])
- self.assertEqual(get_resp.json['value'], SECRET_KVP['value'])
+ self.assertTrue(get_resp.json["secret"])
+ self.assertFalse(get_resp.json["encrypted"])
+ self.assertEqual(get_resp.json["value"], SECRET_KVP["value"])
self.__do_delete(kvp_id)
def test_get_all_decrypt(self):
- put_resp = self.__do_put('secret_key1', SECRET_KVP)
+ put_resp = self.__do_put("secret_key1", SECRET_KVP)
kvp_id_1 = self.__get_kvp_id(put_resp)
- put_resp = self.__do_put('key1', KVP)
+ put_resp = self.__do_put("key1", KVP)
kvp_id_2 = self.__get_kvp_id(put_resp)
- kvps = {'key1': KVP, 'secret_key1': SECRET_KVP}
- stored_kvps = self.app.get('/v1/keys?decrypt=true').json
+ kvps = {"key1": KVP, "secret_key1": SECRET_KVP}
+ stored_kvps = self.app.get("/v1/keys?decrypt=true").json
self.assertTrue(len(stored_kvps), 2)
for stored_kvp in stored_kvps:
- self.assertFalse(stored_kvp['encrypted'])
- exp_kvp = kvps.get(stored_kvp['name'])
+ self.assertFalse(stored_kvp["encrypted"])
+ exp_kvp = kvps.get(stored_kvp["name"])
self.assertIsNotNone(exp_kvp)
- self.assertEqual(exp_kvp['value'], stored_kvp['value'])
+ self.assertEqual(exp_kvp["value"], stored_kvp["value"])
self.__do_delete(kvp_id_1)
self.__do_delete(kvp_id_2)
def test_put_encrypted_value(self):
# 1. encrypted=True, secret=True
- put_resp = self.__do_put('secret_key1', ENCRYPTED_KVP)
+ put_resp = self.__do_put("secret_key1", ENCRYPTED_KVP)
kvp_id = self.__get_kvp_id(put_resp)
# Verify there is no secret leakage
self.assertEqual(put_resp.status_code, 200)
- self.assertEqual(put_resp.json['name'], 'secret_key1')
- self.assertEqual(put_resp.json['scope'], 'st2kv.system')
- self.assertEqual(put_resp.json['encrypted'], True)
- self.assertEqual(put_resp.json['secret'], True)
- self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value'])
- self.assertTrue(put_resp.json['value'] != 'S3cret!Value')
- self.assertTrue(len(put_resp.json['value']) > len('S3cret!Value') * 2)
-
- get_resp = self.__do_get_one(kvp_id + '?decrypt=True')
- self.assertEqual(put_resp.json['name'], 'secret_key1')
- self.assertEqual(put_resp.json['scope'], 'st2kv.system')
- self.assertEqual(put_resp.json['encrypted'], True)
- self.assertEqual(put_resp.json['secret'], True)
- self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value'])
+ self.assertEqual(put_resp.json["name"], "secret_key1")
+ self.assertEqual(put_resp.json["scope"], "st2kv.system")
+ self.assertEqual(put_resp.json["encrypted"], True)
+ self.assertEqual(put_resp.json["secret"], True)
+ self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"])
+ self.assertTrue(put_resp.json["value"] != "S3cret!Value")
+ self.assertTrue(len(put_resp.json["value"]) > len("S3cret!Value") * 2)
+
+ get_resp = self.__do_get_one(kvp_id + "?decrypt=True")
+ self.assertEqual(put_resp.json["name"], "secret_key1")
+ self.assertEqual(put_resp.json["scope"], "st2kv.system")
+ self.assertEqual(put_resp.json["encrypted"], True)
+ self.assertEqual(put_resp.json["secret"], True)
+ self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"])
# Verify data integrity post decryption
- get_resp = self.__do_get_one(kvp_id + '?decrypt=True')
- self.assertFalse(get_resp.json['encrypted'])
- self.assertEqual(get_resp.json['value'], 'S3cret!Value')
+ get_resp = self.__do_get_one(kvp_id + "?decrypt=True")
+ self.assertFalse(get_resp.json["encrypted"])
+ self.assertEqual(get_resp.json["value"], "S3cret!Value")
self.__do_delete(self.__get_kvp_id(put_resp))
# 2. encrypted=True, secret=False
# encrypted should always imply secret=True
- put_resp = self.__do_put('secret_key2', ENCRYPTED_KVP_SECRET_FALSE)
+ put_resp = self.__do_put("secret_key2", ENCRYPTED_KVP_SECRET_FALSE)
kvp_id = self.__get_kvp_id(put_resp)
# Verify there is no secret leakage
self.assertEqual(put_resp.status_code, 200)
- self.assertEqual(put_resp.json['name'], 'secret_key2')
- self.assertEqual(put_resp.json['scope'], 'st2kv.system')
- self.assertEqual(put_resp.json['encrypted'], True)
- self.assertEqual(put_resp.json['secret'], True)
- self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value'])
- self.assertTrue(put_resp.json['value'] != 'S3cret!Value')
- self.assertTrue(len(put_resp.json['value']) > len('S3cret!Value') * 2)
-
- get_resp = self.__do_get_one(kvp_id + '?decrypt=True')
- self.assertEqual(put_resp.json['name'], 'secret_key2')
- self.assertEqual(put_resp.json['scope'], 'st2kv.system')
- self.assertEqual(put_resp.json['encrypted'], True)
- self.assertEqual(put_resp.json['secret'], True)
- self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value'])
+ self.assertEqual(put_resp.json["name"], "secret_key2")
+ self.assertEqual(put_resp.json["scope"], "st2kv.system")
+ self.assertEqual(put_resp.json["encrypted"], True)
+ self.assertEqual(put_resp.json["secret"], True)
+ self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"])
+ self.assertTrue(put_resp.json["value"] != "S3cret!Value")
+ self.assertTrue(len(put_resp.json["value"]) > len("S3cret!Value") * 2)
+
+ get_resp = self.__do_get_one(kvp_id + "?decrypt=True")
+ self.assertEqual(put_resp.json["name"], "secret_key2")
+ self.assertEqual(put_resp.json["scope"], "st2kv.system")
+ self.assertEqual(put_resp.json["encrypted"], True)
+ self.assertEqual(put_resp.json["secret"], True)
+ self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"])
# Verify data integrity post decryption
- get_resp = self.__do_get_one(kvp_id + '?decrypt=True')
- self.assertFalse(get_resp.json['encrypted'])
- self.assertEqual(get_resp.json['value'], 'S3cret!Value')
+ get_resp = self.__do_get_one(kvp_id + "?decrypt=True")
+ self.assertFalse(get_resp.json["encrypted"])
+ self.assertEqual(get_resp.json["value"], "S3cret!Value")
self.__do_delete(self.__get_kvp_id(put_resp))
def test_put_encrypted_value_integrity_check_failed(self):
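+        # Corrupting the submitted ciphertext must trip the server-side integrity check.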
data = copy.deepcopy(ENCRYPTED_KVP)
- data['value'] = 'corrupted'
- put_resp = self.__do_put('secret_key1', data, expect_errors=True)
+ data["value"] = "corrupted"
+ put_resp = self.__do_put("secret_key1", data, expect_errors=True)
self.assertEqual(put_resp.status_code, 400)
- expected_error = ('Failed to verify the integrity of the provided value for key '
- '"secret_key1".')
- self.assertIn(expected_error, put_resp.json['faultstring'])
+ expected_error = (
+ "Failed to verify the integrity of the provided value for key "
+ '"secret_key1".'
+ )
+ self.assertIn(expected_error, put_resp.json["faultstring"])
data = copy.deepcopy(ENCRYPTED_KVP)
- data['value'] = str(data['value'][:-2])
- put_resp = self.__do_put('secret_key1', data, expect_errors=True)
+ data["value"] = str(data["value"][:-2])
+ put_resp = self.__do_put("secret_key1", data, expect_errors=True)
self.assertEqual(put_resp.status_code, 400)
- expected_error = ('Failed to verify the integrity of the provided value for key '
- '"secret_key1".')
- self.assertIn(expected_error, put_resp.json['faultstring'])
+ expected_error = (
+ "Failed to verify the integrity of the provided value for key "
+ '"secret_key1".'
+ )
+ self.assertIn(expected_error, put_resp.json["faultstring"])
def test_put_delete(self):
- put_resp = self.__do_put('key1', KVP)
+ put_resp = self.__do_put("key1", KVP)
self.assertEqual(put_resp.status_int, 200)
self.__do_delete(self.__get_kvp_id(put_resp))
def test_delete(self):
- put_resp = self.__do_put('key1', KVP)
+ put_resp = self.__do_put("key1", KVP)
del_resp = self.__do_delete(self.__get_kvp_id(put_resp))
self.assertEqual(del_resp.status_int, 204)
def test_delete_fail(self):
- resp = self.__do_delete('inexistentkey', expect_errors=True)
+ resp = self.__do_delete("inexistentkey", expect_errors=True)
self.assertEqual(resp.status_int, 404)
@staticmethod
def __get_kvp_id(resp):
- return resp.json['name']
+ return resp.json["name"]
def __do_get_one(self, kvp_id, expect_errors=False):
- return self.app.get('/v1/keys/%s' % kvp_id, expect_errors=expect_errors)
+ return self.app.get("/v1/keys/%s" % kvp_id, expect_errors=expect_errors)
def __do_put(self, kvp_id, kvp, expect_errors=False):
- return self.app.put_json('/v1/keys/%s' % kvp_id, kvp, expect_errors=expect_errors)
+ return self.app.put_json(
+ "/v1/keys/%s" % kvp_id, kvp, expect_errors=expect_errors
+ )
def __do_delete(self, kvp_id, expect_errors=False):
- return self.app.delete('/v1/keys/%s' % kvp_id, expect_errors=expect_errors)
+ return self.app.delete("/v1/keys/%s" % kvp_id, expect_errors=expect_errors)
diff --git a/st2api/tests/unit/controllers/v1/test_pack_config_schema.py b/st2api/tests/unit/controllers/v1/test_pack_config_schema.py
index bff5935e38..a38c278f07 100644
--- a/st2api/tests/unit/controllers/v1/test_pack_config_schema.py
+++ b/st2api/tests/unit/controllers/v1/test_pack_config_schema.py
@@ -19,12 +19,10 @@
from st2tests.fixturesloader import get_fixtures_packs_base_path
-__all__ = [
- 'PackConfigSchemasControllerTestCase'
-]
+__all__ = ["PackConfigSchemasControllerTestCase"]
PACKS_PATH = get_fixtures_packs_base_path()
-CONFIG_SCHEMA_COUNT = len(glob.glob('%s/*/config.schema.yaml' % (PACKS_PATH)))
+CONFIG_SCHEMA_COUNT = len(glob.glob("%s/*/config.schema.yaml" % (PACKS_PATH)))
assert CONFIG_SCHEMA_COUNT > 1
@@ -32,29 +30,34 @@ class PackConfigSchemasControllerTestCase(FunctionalTest):
register_packs = True
def test_get_all(self):
- resp = self.app.get('/v1/config_schemas')
+ resp = self.app.get("/v1/config_schemas")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), CONFIG_SCHEMA_COUNT,
- '/v1/config_schemas did not return all schemas.')
+ self.assertEqual(
+ len(resp.json),
+ CONFIG_SCHEMA_COUNT,
+ "/v1/config_schemas did not return all schemas.",
+ )
def test_get_one_success(self):
- resp = self.app.get('/v1/config_schemas/dummy_pack_1')
+ resp = self.app.get("/v1/config_schemas/dummy_pack_1")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['pack'], 'dummy_pack_1')
- self.assertIn('api_key', resp.json['attributes'])
+ self.assertEqual(resp.json["pack"], "dummy_pack_1")
+ self.assertIn("api_key", resp.json["attributes"])
def test_get_one_doesnt_exist(self):
# Pack exists, schema doesn't
- resp = self.app.get('/v1/config_schemas/dummy_pack_2',
- expect_errors=True)
+ resp = self.app.get("/v1/config_schemas/dummy_pack_2", expect_errors=True)
self.assertEqual(resp.status_int, 404)
- self.assertIn('Unable to identify resource with pack_ref ', resp.json['faultstring'])
+ self.assertIn(
+ "Unable to identify resource with pack_ref ", resp.json["faultstring"]
+ )
# Pack doesn't exist
- ref_or_id = 'pack_doesnt_exist'
- resp = self.app.get('/v1/config_schemas/%s' % ref_or_id,
- expect_errors=True)
+ ref_or_id = "pack_doesnt_exist"
+ resp = self.app.get("/v1/config_schemas/%s" % ref_or_id, expect_errors=True)
self.assertEqual(resp.status_int, 404)
# Changed from: 'Unable to find the PackDB instance'
- self.assertTrue('Resource with a ref or id "%s" not found' % ref_or_id in
- resp.json['faultstring'])
+ self.assertTrue(
+ 'Resource with a ref or id "%s" not found' % ref_or_id
+ in resp.json["faultstring"]
+ )
diff --git a/st2api/tests/unit/controllers/v1/test_pack_configs.py b/st2api/tests/unit/controllers/v1/test_pack_configs.py
index 6e789c413a..5a87719eaa 100644
--- a/st2api/tests/unit/controllers/v1/test_pack_configs.py
+++ b/st2api/tests/unit/controllers/v1/test_pack_configs.py
@@ -21,12 +21,10 @@
from st2api.controllers.v1.pack_configs import PackConfigsController
from st2tests.fixturesloader import get_fixtures_packs_base_path
-__all__ = [
- 'PackConfigsControllerTestCase'
-]
+__all__ = ["PackConfigsControllerTestCase"]
PACKS_PATH = get_fixtures_packs_base_path()
-CONFIGS_COUNT = len(glob.glob('%s/configs/*.yaml' % (PACKS_PATH)))
+CONFIGS_COUNT = len(glob.glob("%s/configs/*.yaml" % (PACKS_PATH)))
assert CONFIGS_COUNT > 1
@@ -35,60 +33,80 @@ class PackConfigsControllerTestCase(FunctionalTest):
register_pack_configs = True
def test_get_all(self):
- resp = self.app.get('/v1/configs')
+ resp = self.app.get("/v1/configs")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), CONFIGS_COUNT, '/v1/configs did not return all configs.')
+ self.assertEqual(
+ len(resp.json), CONFIGS_COUNT, "/v1/configs did not return all configs."
+ )
def test_get_one_success(self):
- resp = self.app.get('/v1/configs/dummy_pack_1', params={'show_secrets': True},
- expect_errors=True)
+ resp = self.app.get(
+ "/v1/configs/dummy_pack_1",
+ params={"show_secrets": True},
+ expect_errors=True,
+ )
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['pack'], 'dummy_pack_1')
- self.assertEqual(resp.json['values']['api_key'], '{{st2kv.user.api_key}}')
- self.assertEqual(resp.json['values']['region'], 'us-west-1')
+ self.assertEqual(resp.json["pack"], "dummy_pack_1")
+ self.assertEqual(resp.json["values"]["api_key"], "{{st2kv.user.api_key}}")
+ self.assertEqual(resp.json["values"]["region"], "us-west-1")
def test_get_one_mask_secret(self):
- resp = self.app.get('/v1/configs/dummy_pack_1')
+ resp = self.app.get("/v1/configs/dummy_pack_1")
self.assertEqual(resp.status_int, 200)
- self.assertNotEqual(resp.json['values']['api_key'], '{{st2kv.user.api_key}}')
+ self.assertNotEqual(resp.json["values"]["api_key"], "{{st2kv.user.api_key}}")
def test_get_one_pack_config_doesnt_exist(self):
# Pack exists, config doesn't
- resp = self.app.get('/v1/configs/dummy_pack_2',
- expect_errors=True)
+ resp = self.app.get("/v1/configs/dummy_pack_2", expect_errors=True)
self.assertEqual(resp.status_int, 404)
- self.assertIn('Unable to identify resource with pack_ref ', resp.json['faultstring'])
+ self.assertIn(
+ "Unable to identify resource with pack_ref ", resp.json["faultstring"]
+ )
# Pack doesn't exist
- resp = self.app.get('/v1/configs/pack_doesnt_exist',
- expect_errors=True)
+ resp = self.app.get("/v1/configs/pack_doesnt_exist", expect_errors=True)
self.assertEqual(resp.status_int, 404)
# Changed from: 'Unable to find the PackDB instance.'
- self.assertIn('Unable to identify resource with pack_ref', resp.json['faultstring'])
+ self.assertIn(
+ "Unable to identify resource with pack_ref", resp.json["faultstring"]
+ )
- @mock.patch.object(PackConfigsController, '_dump_config_to_disk', mock.MagicMock())
+ @mock.patch.object(PackConfigsController, "_dump_config_to_disk", mock.MagicMock())
def test_put_pack_config(self):
- get_resp = self.app.get('/v1/configs/dummy_pack_1', params={'show_secrets': True},
- expect_errors=True)
- config = copy.copy(get_resp.json['values'])
- config['region'] = 'us-west-2'
+ get_resp = self.app.get(
+ "/v1/configs/dummy_pack_1",
+ params={"show_secrets": True},
+ expect_errors=True,
+ )
+ config = copy.copy(get_resp.json["values"])
+ config["region"] = "us-west-2"
- put_resp = self.app.put_json('/v1/configs/dummy_pack_1', config)
+ put_resp = self.app.put_json("/v1/configs/dummy_pack_1", config)
self.assertEqual(put_resp.status_int, 200)
- put_resp_undo = self.app.put_json('/v1/configs/dummy_pack_1?show_secrets=true',
- get_resp.json['values'], expect_errors=True)
+ put_resp_undo = self.app.put_json(
+ "/v1/configs/dummy_pack_1?show_secrets=true",
+ get_resp.json["values"],
+ expect_errors=True,
+ )
self.assertEqual(put_resp.status_int, 200)
self.assertEqual(get_resp.json, put_resp_undo.json)
- @mock.patch.object(PackConfigsController, '_dump_config_to_disk', mock.MagicMock())
+ @mock.patch.object(PackConfigsController, "_dump_config_to_disk", mock.MagicMock())
def test_put_invalid_pack_config(self):
- get_resp = self.app.get('/v1/configs/dummy_pack_11', params={'show_secrets': True},
- expect_errors=True)
- config = copy.copy(get_resp.json['values'])
- put_resp = self.app.put_json('/v1/configs/dummy_pack_11', config, expect_errors=True)
+ get_resp = self.app.get(
+ "/v1/configs/dummy_pack_11",
+ params={"show_secrets": True},
+ expect_errors=True,
+ )
+ config = copy.copy(get_resp.json["values"])
+ put_resp = self.app.put_json(
+ "/v1/configs/dummy_pack_11", config, expect_errors=True
+ )
self.assertEqual(put_resp.status_int, 400)
- expected_msg = ('Values specified as "secret: True" in config schema are automatically '
- 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed '
- 'for such values. Please check the specified values in the config or '
- 'the default values in the schema.')
- self.assertIn(expected_msg, put_resp.json['faultstring'])
+ expected_msg = (
+ 'Values specified as "secret: True" in config schema are automatically '
+ 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed '
+ "for such values. Please check the specified values in the config or "
+ "the default values in the schema."
+ )
+ self.assertIn(expected_msg, put_resp.json["faultstring"])
diff --git a/st2api/tests/unit/controllers/v1/test_packs.py b/st2api/tests/unit/controllers/v1/test_packs.py
index 9406a50af0..07cacd0be8 100644
--- a/st2api/tests/unit/controllers/v1/test_packs.py
+++ b/st2api/tests/unit/controllers/v1/test_packs.py
@@ -33,9 +33,7 @@
from st2tests.fixturesloader import get_fixtures_base_path
-__all__ = [
- 'PacksControllerTestCase'
-]
+__all__ = ["PacksControllerTestCase"]
PACK_INDEX = {
"test": {
@@ -45,7 +43,7 @@
"author": "st2-dev",
"keywords": ["some", "search", "another", "terms"],
"email": "info@stackstorm.com",
- "description": "st2 pack to test package management pipeline"
+ "description": "st2 pack to test package management pipeline",
},
"test2": {
"version": "0.5.0",
@@ -54,13 +52,13 @@
"author": "stanley",
"keywords": ["some", "special", "terms"],
"email": "info@stackstorm.com",
- "description": "another st2 pack to test package management pipeline"
- }
+ "description": "another st2 pack to test package management pipeline",
+ },
}
PACK_INDEXES = {
- 'http://main.example.com': PACK_INDEX,
- 'http://fallback.example.com': {
+ "http://main.example.com": PACK_INDEX,
+ "http://fallback.example.com": {
"test": {
"version": "0.1.0",
"name": "test",
@@ -68,10 +66,10 @@
"author": "st2-dev",
"keywords": ["some", "search", "another", "terms"],
"email": "info@stackstorm.com",
- "description": "st2 pack to test package management pipeline"
+ "description": "st2 pack to test package management pipeline",
}
},
- 'http://override.example.com': {
+ "http://override.example.com": {
"test2": {
"version": "1.0.0",
"name": "test2",
@@ -79,10 +77,12 @@
"author": "stanley",
"keywords": ["some", "special", "terms"],
"email": "info@stackstorm.com",
- "description": "another st2 pack to test package management pipeline"
+ "description": "another st2 pack to test package management pipeline",
}
},
- 'http://broken.example.com': requests.exceptions.RequestException('index is broken')
+ "http://broken.example.com": requests.exceptions.RequestException(
+ "index is broken"
+ ),
}
@@ -93,10 +93,7 @@ def mock_index_get(url, *args, **kwargs):
raise index
status = 200
- content = {
- 'metadata': {},
- 'packs': index
- }
+ content = {"metadata": {}, "packs": index}
# Return mock response object
@@ -104,311 +101,371 @@ def mock_index_get(url, *args, **kwargs):
mock_resp.raise_for_status = mock.Mock()
mock_resp.status_code = status
mock_resp.content = content
- mock_resp.json = mock.Mock(
- return_value=content
- )
+ mock_resp.json = mock.Mock(return_value=content)
return mock_resp
-class PacksControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/packs'
+class PacksControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/packs"
controller_cls = PacksController
- include_attribute_field_name = 'version'
- exclude_attribute_field_name = 'author'
+ include_attribute_field_name = "version"
+ exclude_attribute_field_name = "author"
@classmethod
def setUpClass(cls):
super(PacksControllerTestCase, cls).setUpClass()
- cls.pack_db_1 = PackDB(name='pack1', description='foo', version='0.1.0', author='foo',
- email='test@example.com', ref='pack1')
- cls.pack_db_2 = PackDB(name='pack2', description='foo', version='0.1.0', author='foo',
- email='test@example.com', ref='pack2')
- cls.pack_db_3 = PackDB(name='pack3-name', ref='pack3-ref', description='foo',
- version='0.1.0', author='foo',
- email='test@example.com')
+ cls.pack_db_1 = PackDB(
+ name="pack1",
+ description="foo",
+ version="0.1.0",
+ author="foo",
+ email="test@example.com",
+ ref="pack1",
+ )
+ cls.pack_db_2 = PackDB(
+ name="pack2",
+ description="foo",
+ version="0.1.0",
+ author="foo",
+ email="test@example.com",
+ ref="pack2",
+ )
+ cls.pack_db_3 = PackDB(
+ name="pack3-name",
+ ref="pack3-ref",
+ description="foo",
+ version="0.1.0",
+ author="foo",
+ email="test@example.com",
+ )
Pack.add_or_update(cls.pack_db_1)
Pack.add_or_update(cls.pack_db_2)
Pack.add_or_update(cls.pack_db_3)
def test_get_all(self):
- resp = self.app.get('/v1/packs')
+ resp = self.app.get("/v1/packs")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 3, '/v1/actionalias did not return all packs.')
+        self.assertEqual(len(resp.json), 3, "/v1/packs did not return all packs.")
def test_get_one(self):
# Get by id
- resp = self.app.get('/v1/packs/%s' % (self.pack_db_1.id))
+ resp = self.app.get("/v1/packs/%s" % (self.pack_db_1.id))
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['name'], self.pack_db_1.name)
+ self.assertEqual(resp.json["name"], self.pack_db_1.name)
# Get by name
- resp = self.app.get('/v1/packs/%s' % (self.pack_db_1.ref))
+ resp = self.app.get("/v1/packs/%s" % (self.pack_db_1.ref))
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['ref'], self.pack_db_1.ref)
- self.assertEqual(resp.json['name'], self.pack_db_1.name)
+ self.assertEqual(resp.json["ref"], self.pack_db_1.ref)
+ self.assertEqual(resp.json["name"], self.pack_db_1.name)
# Get by ref (ref != name)
- resp = self.app.get('/v1/packs/%s' % (self.pack_db_3.ref))
+ resp = self.app.get("/v1/packs/%s" % (self.pack_db_3.ref))
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['ref'], self.pack_db_3.ref)
+ self.assertEqual(resp.json["ref"], self.pack_db_3.ref)
def test_get_one_doesnt_exist(self):
- resp = self.app.get('/v1/packs/doesntexistfoo', expect_errors=True)
+ resp = self.app.get("/v1/packs/doesntexistfoo", expect_errors=True)
self.assertEqual(resp.status_int, 404)
- @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution')
+ @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution")
def test_install(self, _handle_schedule_execution):
- _handle_schedule_execution.return_value = Response(json={'id': '123'})
- payload = {'packs': ['some']}
+ _handle_schedule_execution.return_value = Response(json={"id": "123"})
+ payload = {"packs": ["some"]}
- resp = self.app.post_json('/v1/packs/install', payload)
+ resp = self.app.post_json("/v1/packs/install", payload)
self.assertEqual(resp.status_int, 202)
- self.assertEqual(resp.json, {'execution_id': '123'})
+ self.assertEqual(resp.json, {"execution_id": "123"})
- @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution')
+ @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution")
def test_install_with_force_parameter(self, _handle_schedule_execution):
- _handle_schedule_execution.return_value = Response(json={'id': '123'})
- payload = {'packs': ['some'], 'force': True}
+ _handle_schedule_execution.return_value = Response(json={"id": "123"})
+ payload = {"packs": ["some"], "force": True}
- resp = self.app.post_json('/v1/packs/install', payload)
+ resp = self.app.post_json("/v1/packs/install", payload)
self.assertEqual(resp.status_int, 202)
- self.assertEqual(resp.json, {'execution_id': '123'})
+ self.assertEqual(resp.json, {"execution_id": "123"})
- @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution')
+ @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution")
def test_install_with_skip_dependencies_parameter(self, _handle_schedule_execution):
- _handle_schedule_execution.return_value = Response(json={'id': '123'})
- payload = {'packs': ['some'], 'skip_dependencies': True}
+ _handle_schedule_execution.return_value = Response(json={"id": "123"})
+ payload = {"packs": ["some"], "skip_dependencies": True}
- resp = self.app.post_json('/v1/packs/install', payload)
+ resp = self.app.post_json("/v1/packs/install", payload)
self.assertEqual(resp.status_int, 202)
- self.assertEqual(resp.json, {'execution_id': '123'})
+ self.assertEqual(resp.json, {"execution_id": "123"})
- @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution')
+ @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution")
def test_uninstall(self, _handle_schedule_execution):
- _handle_schedule_execution.return_value = Response(json={'id': '123'})
- payload = {'packs': ['some']}
+ _handle_schedule_execution.return_value = Response(json={"id": "123"})
+ payload = {"packs": ["some"]}
- resp = self.app.post_json('/v1/packs/uninstall', payload)
+ resp = self.app.post_json("/v1/packs/uninstall", payload)
self.assertEqual(resp.status_int, 202)
- self.assertEqual(resp.json, {'execution_id': '123'})
+ self.assertEqual(resp.json, {"execution_id": "123"})
- @mock.patch.object(pack_service, 'fetch_pack_index',
- mock.MagicMock(return_value=(PACK_INDEX, {})))
+ @mock.patch.object(
+ pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {}))
+ )
def test_search_with_query(self):
test_scenarios = [
{
- 'input': {'query': 'test'},
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test'], PACK_INDEX['test2']]
+ "input": {"query": "test"},
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test"], PACK_INDEX["test2"]],
},
{
- 'input': {'query': 'stanley'},
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test2']]
+ "input": {"query": "stanley"},
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test2"]],
},
{
- 'input': {'query': 'special'},
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test2']]
+ "input": {"query": "special"},
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test2"]],
},
{
- 'input': {'query': 'TEST'}, # Search should be case insensitive by default
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test'], PACK_INDEX['test2']]
+ "input": {
+ "query": "TEST"
+ }, # Search should be case insensitive by default
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test"], PACK_INDEX["test2"]],
},
{
- 'input': {'query': 'SPECIAL'},
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test2']]
+ "input": {"query": "SPECIAL"},
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test2"]],
},
{
- 'input': {'query': 'sPeCiAL'},
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test2']]
+ "input": {"query": "sPeCiAL"},
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test2"]],
},
{
- 'input': {'query': 'st2-dev'},
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test']]
+ "input": {"query": "st2-dev"},
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test"]],
},
{
- 'input': {'query': 'ST2-dev'},
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test']]
+ "input": {"query": "ST2-dev"},
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test"]],
},
{
- 'input': {'query': '-dev'},
- 'expected_code': 200,
- 'expected_result': [PACK_INDEX['test']]
+ "input": {"query": "-dev"},
+ "expected_code": 200,
+ "expected_result": [PACK_INDEX["test"]],
},
- {
- 'input': {'query': 'core'},
- 'expected_code': 200,
- 'expected_result': []
- }
+ {"input": {"query": "core"}, "expected_code": 200, "expected_result": []},
]
for scenario in test_scenarios:
- resp = self.app.post_json('/v1/packs/index/search', scenario['input'])
- self.assertEqual(resp.status_int, scenario['expected_code'])
- self.assertEqual(resp.json, scenario['expected_result'])
-
- @mock.patch.object(pack_service, 'get_pack_from_index',
- mock.MagicMock(return_value=PACK_INDEX['test']))
+ resp = self.app.post_json("/v1/packs/index/search", scenario["input"])
+ self.assertEqual(resp.status_int, scenario["expected_code"])
+ self.assertEqual(resp.json, scenario["expected_result"])
+
+ @mock.patch.object(
+ pack_service,
+ "get_pack_from_index",
+ mock.MagicMock(return_value=PACK_INDEX["test"]),
+ )
def test_search_with_pack_has_result(self):
- resp = self.app.post_json('/v1/packs/index/search', {'pack': 'st2-dev'})
+ resp = self.app.post_json("/v1/packs/index/search", {"pack": "st2-dev"})
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, PACK_INDEX['test'])
+ self.assertEqual(resp.json, PACK_INDEX["test"])
- @mock.patch.object(pack_service, 'get_pack_from_index',
- mock.MagicMock(return_value=None))
+ @mock.patch.object(
+ pack_service, "get_pack_from_index", mock.MagicMock(return_value=None)
+ )
def test_search_with_pack_no_result(self):
- resp = self.app.post_json('/v1/packs/index/search', {'pack': 'not-found'})
+ resp = self.app.post_json("/v1/packs/index/search", {"pack": "not-found"})
self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.json, [])
- @mock.patch.object(pack_service, 'fetch_pack_index',
- mock.MagicMock(return_value=(PACK_INDEX, {})))
+ @mock.patch.object(
+ pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {}))
+ )
def test_show(self):
- resp = self.app.post_json('/v1/packs/index/search', {'pack': 'test'})
+ resp = self.app.post_json("/v1/packs/index/search", {"pack": "test"})
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, PACK_INDEX['test'])
+ self.assertEqual(resp.json, PACK_INDEX["test"])
- resp = self.app.post_json('/v1/packs/index/search', {'pack': 'test2'})
+ resp = self.app.post_json("/v1/packs/index/search", {"pack": "test2"})
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, PACK_INDEX['test2'])
+ self.assertEqual(resp.json, PACK_INDEX["test2"])
- @mock.patch.object(pack_service, '_build_index_list',
- mock.MagicMock(return_value=['http://main.example.com']))
- @mock.patch.object(requests, 'get', mock_index_get)
+ @mock.patch.object(
+ pack_service,
+ "_build_index_list",
+ mock.MagicMock(return_value=["http://main.example.com"]),
+ )
+ @mock.patch.object(requests, "get", mock_index_get)
def test_index_health(self):
- resp = self.app.get('/v1/packs/index/health')
+ resp = self.app.get("/v1/packs/index/health")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, {
- 'packs': {
- 'count': 2
+ self.assertEqual(
+ resp.json,
+ {
+ "packs": {"count": 2},
+ "indexes": {
+ "count": 1,
+ "status": [
+ {
+ "url": "http://main.example.com",
+ "message": "Success.",
+ "packs": 2,
+ "error": None,
+ }
+ ],
+ "valid": 1,
+ "errors": {},
+ "invalid": 0,
+ },
},
- 'indexes': {
- 'count': 1,
- 'status': [{
- 'url': 'http://main.example.com',
- 'message': 'Success.',
- 'packs': 2,
- 'error': None
- }],
- 'valid': 1,
- 'errors': {},
- 'invalid': 0
- }
- })
-
- @mock.patch.object(pack_service, '_build_index_list',
- mock.MagicMock(return_value=['http://main.example.com',
- 'http://broken.example.com']))
- @mock.patch.object(requests, 'get', mock_index_get)
+ )
+
+ @mock.patch.object(
+ pack_service,
+ "_build_index_list",
+ mock.MagicMock(
+ return_value=["http://main.example.com", "http://broken.example.com"]
+ ),
+ )
+ @mock.patch.object(requests, "get", mock_index_get)
def test_index_health_broken(self):
- resp = self.app.get('/v1/packs/index/health')
+ resp = self.app.get("/v1/packs/index/health")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, {
- 'packs': {
- 'count': 2
- },
- 'indexes': {
- 'count': 2,
- 'status': [{
- 'url': 'http://main.example.com',
- 'message': 'Success.',
- 'packs': 2,
- 'error': None
- }, {
- 'url': 'http://broken.example.com',
- 'message': "RequestException('index is broken',)",
- 'packs': 0,
- 'error': 'unresponsive'
- }],
- 'valid': 1,
- 'errors': {
- 'unresponsive': 1
+ self.assertEqual(
+ resp.json,
+ {
+ "packs": {"count": 2},
+ "indexes": {
+ "count": 2,
+ "status": [
+ {
+ "url": "http://main.example.com",
+ "message": "Success.",
+ "packs": 2,
+ "error": None,
+ },
+ {
+ "url": "http://broken.example.com",
+ "message": "RequestException('index is broken',)",
+ "packs": 0,
+ "error": "unresponsive",
+ },
+ ],
+ "valid": 1,
+ "errors": {"unresponsive": 1},
+ "invalid": 1,
},
- 'invalid': 1
- }
- })
+ },
+ )
- @mock.patch.object(pack_service, '_build_index_list',
- mock.MagicMock(return_value=['http://main.example.com']))
- @mock.patch.object(requests, 'get', mock_index_get)
+ @mock.patch.object(
+ pack_service,
+ "_build_index_list",
+ mock.MagicMock(return_value=["http://main.example.com"]),
+ )
+ @mock.patch.object(requests, "get", mock_index_get)
def test_index(self):
- resp = self.app.get('/v1/packs/index')
+ resp = self.app.get("/v1/packs/index")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, {
- 'status': [{
- 'url': 'http://main.example.com',
- 'message': 'Success.',
- 'packs': 2,
- 'error': None
- }],
- 'index': PACK_INDEX
- })
-
- @mock.patch.object(pack_service, '_build_index_list',
- mock.MagicMock(return_value=['http://fallback.example.com',
- 'http://main.example.com']))
- @mock.patch.object(requests, 'get', mock_index_get)
+ self.assertEqual(
+ resp.json,
+ {
+ "status": [
+ {
+ "url": "http://main.example.com",
+ "message": "Success.",
+ "packs": 2,
+ "error": None,
+ }
+ ],
+ "index": PACK_INDEX,
+ },
+ )
+
+ @mock.patch.object(
+ pack_service,
+ "_build_index_list",
+ mock.MagicMock(
+ return_value=["http://fallback.example.com", "http://main.example.com"]
+ ),
+ )
+ @mock.patch.object(requests, "get", mock_index_get)
def test_index_fallback(self):
- resp = self.app.get('/v1/packs/index')
+ resp = self.app.get("/v1/packs/index")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, {
- 'status': [{
- 'url': 'http://fallback.example.com',
- 'message': 'Success.',
- 'packs': 1,
- 'error': None
- }, {
- 'url': 'http://main.example.com',
- 'message': 'Success.',
- 'packs': 2,
- 'error': None
- }],
- 'index': PACK_INDEX
- })
-
- @mock.patch.object(pack_service, '_build_index_list',
- mock.MagicMock(return_value=['http://main.example.com',
- 'http://override.example.com']))
- @mock.patch.object(requests, 'get', mock_index_get)
+ self.assertEqual(
+ resp.json,
+ {
+ "status": [
+ {
+ "url": "http://fallback.example.com",
+ "message": "Success.",
+ "packs": 1,
+ "error": None,
+ },
+ {
+ "url": "http://main.example.com",
+ "message": "Success.",
+ "packs": 2,
+ "error": None,
+ },
+ ],
+ "index": PACK_INDEX,
+ },
+ )
+
+ @mock.patch.object(
+ pack_service,
+ "_build_index_list",
+ mock.MagicMock(
+ return_value=["http://main.example.com", "http://override.example.com"]
+ ),
+ )
+ @mock.patch.object(requests, "get", mock_index_get)
def test_index_override(self):
- resp = self.app.get('/v1/packs/index')
+ resp = self.app.get("/v1/packs/index")
self.assertEqual(resp.status_int, 200)
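+        # test2 from the override index should shadow the entry from the main index.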
- self.assertEqual(resp.json, {
- 'status': [{
- 'url': 'http://main.example.com',
- 'message': 'Success.',
- 'packs': 2,
- 'error': None
- }, {
- 'url': 'http://override.example.com',
- 'message': 'Success.',
- 'packs': 1,
- 'error': None
- }],
- 'index': {
- 'test': PACK_INDEX['test'],
- 'test2': PACK_INDEXES['http://override.example.com']['test2']
- }
- })
+ self.assertEqual(
+ resp.json,
+ {
+ "status": [
+ {
+ "url": "http://main.example.com",
+ "message": "Success.",
+ "packs": 2,
+ "error": None,
+ },
+ {
+ "url": "http://override.example.com",
+ "message": "Success.",
+ "packs": 1,
+ "error": None,
+ },
+ ],
+ "index": {
+ "test": PACK_INDEX["test"],
+ "test2": PACK_INDEXES["http://override.example.com"]["test2"],
+ },
+ },
+ )
def test_packs_register_endpoint_resource_register_order(self):
# Verify that resources are registered in the same order as they are inside
@@ -416,17 +473,17 @@ def test_packs_register_endpoint_resource_register_order(self):
# Note: Sadly there is no easier / better way to test this
resource_types = list(ENTITIES.keys())
expected_order = [
- 'trigger',
- 'sensor',
- 'action',
- 'rule',
- 'alias',
- 'policy',
- 'config'
+ "trigger",
+ "sensor",
+ "action",
+ "rule",
+ "alias",
+ "policy",
+ "config",
]
self.assertEqual(resource_types, expected_order)
- @mock.patch.object(ContentPackLoader, 'get_packs')
+ @mock.patch.object(ContentPackLoader, "get_packs")
def test_packs_register_endpoint(self, mock_get_packs):
# Register resources from all packs - make sure the count values are correctly added
# together
@@ -434,12 +491,12 @@ def test_packs_register_endpoint(self, mock_get_packs):
# Note: We only register a couple of packs and not all on disk to speed
# things up. Registering all the packs takes a long time.
fixtures_base_path = get_fixtures_base_path()
- packs_base_path = os.path.join(fixtures_base_path, 'packs')
+ packs_base_path = os.path.join(fixtures_base_path, "packs")
pack_names = [
- 'dummy_pack_1',
- 'dummy_pack_2',
- 'dummy_pack_3',
- 'dummy_pack_10',
+ "dummy_pack_1",
+ "dummy_pack_2",
+ "dummy_pack_3",
+ "dummy_pack_10",
]
mock_return_value = {}
for pack_name in pack_names:
@@ -447,160 +504,180 @@ def test_packs_register_endpoint(self, mock_get_packs):
mock_get_packs.return_value = mock_return_value
- resp = self.app.post_json('/v1/packs/register', {'fail_on_failure': False})
+ resp = self.app.post_json("/v1/packs/register", {"fail_on_failure": False})
self.assertEqual(resp.status_int, 200)
- self.assertIn('runners', resp.json)
- self.assertIn('actions', resp.json)
- self.assertIn('triggers', resp.json)
- self.assertIn('sensors', resp.json)
- self.assertIn('rules', resp.json)
- self.assertIn('rule_types', resp.json)
- self.assertIn('aliases', resp.json)
- self.assertIn('policy_types', resp.json)
- self.assertIn('policies', resp.json)
- self.assertIn('configs', resp.json)
-
- self.assertTrue(resp.json['actions'] >= 3)
- self.assertTrue(resp.json['configs'] >= 1)
+ self.assertIn("runners", resp.json)
+ self.assertIn("actions", resp.json)
+ self.assertIn("triggers", resp.json)
+ self.assertIn("sensors", resp.json)
+ self.assertIn("rules", resp.json)
+ self.assertIn("rule_types", resp.json)
+ self.assertIn("aliases", resp.json)
+ self.assertIn("policy_types", resp.json)
+ self.assertIn("policies", resp.json)
+ self.assertIn("configs", resp.json)
+
+ self.assertTrue(resp.json["actions"] >= 3)
+ self.assertTrue(resp.json["configs"] >= 1)
# Register resources from a specific pack
- resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'],
- 'fail_on_failure': False})
+ resp = self.app.post_json(
+ "/v1/packs/register", {"packs": ["dummy_pack_1"], "fail_on_failure": False}
+ )
self.assertEqual(resp.status_int, 200)
- self.assertTrue(resp.json['actions'] >= 1)
- self.assertTrue(resp.json['sensors'] >= 1)
- self.assertTrue(resp.json['configs'] >= 1)
+ self.assertTrue(resp.json["actions"] >= 1)
+ self.assertTrue(resp.json["sensors"] >= 1)
+ self.assertTrue(resp.json["configs"] >= 1)
# Verify metadata_file attribute is set
- action_dbs = Action.query(pack='dummy_pack_1')
- self.assertEqual(action_dbs[0].metadata_file, 'actions/my_action.yaml')
+ action_dbs = Action.query(pack="dummy_pack_1")
+ self.assertEqual(action_dbs[0].metadata_file, "actions/my_action.yaml")
# Registering 'all' resource types should include any possible content for the pack
- resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'],
- 'fail_on_failure': False,
- 'types': ['all']})
+ resp = self.app.post_json(
+ "/v1/packs/register",
+ {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["all"]},
+ )
self.assertEqual(resp.status_int, 200)
- self.assertIn('runners', resp.json)
- self.assertIn('actions', resp.json)
- self.assertIn('triggers', resp.json)
- self.assertIn('sensors', resp.json)
- self.assertIn('rules', resp.json)
- self.assertIn('rule_types', resp.json)
- self.assertIn('aliases', resp.json)
- self.assertIn('policy_types', resp.json)
- self.assertIn('policies', resp.json)
- self.assertIn('configs', resp.json)
+ self.assertIn("runners", resp.json)
+ self.assertIn("actions", resp.json)
+ self.assertIn("triggers", resp.json)
+ self.assertIn("sensors", resp.json)
+ self.assertIn("rules", resp.json)
+ self.assertIn("rule_types", resp.json)
+ self.assertIn("aliases", resp.json)
+ self.assertIn("policy_types", resp.json)
+ self.assertIn("policies", resp.json)
+ self.assertIn("configs", resp.json)
# Registering single resource type should also cause dependent resources
# to be registered
# * actions -> runners
# * rules -> rule types
# * policies -> policy types
- resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'],
- 'fail_on_failure': False,
- 'types': ['actions']})
+ resp = self.app.post_json(
+ "/v1/packs/register",
+ {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["actions"]},
+ )
self.assertEqual(resp.status_int, 200)
- self.assertTrue(resp.json['runners'] >= 1)
- self.assertTrue(resp.json['actions'] >= 1)
+ self.assertTrue(resp.json["runners"] >= 1)
+ self.assertTrue(resp.json["actions"] >= 1)
- resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'],
- 'fail_on_failure': False,
- 'types': ['rules']})
+ resp = self.app.post_json(
+ "/v1/packs/register",
+ {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["rules"]},
+ )
self.assertEqual(resp.status_int, 200)
- self.assertTrue(resp.json['rule_types'] >= 1)
- self.assertTrue(resp.json['rules'] >= 1)
+ self.assertTrue(resp.json["rule_types"] >= 1)
+ self.assertTrue(resp.json["rules"] >= 1)
- resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_2'],
- 'fail_on_failure': False,
- 'types': ['policies']})
+ resp = self.app.post_json(
+ "/v1/packs/register",
+ {
+ "packs": ["dummy_pack_2"],
+ "fail_on_failure": False,
+ "types": ["policies"],
+ },
+ )
self.assertEqual(resp.status_int, 200)
- self.assertTrue(resp.json['policy_types'] >= 1)
- self.assertTrue(resp.json['policies'] >= 0)
+ self.assertTrue(resp.json["policy_types"] >= 1)
+ self.assertTrue(resp.json["policies"] >= 0)
# Register a specific type for all packs
- resp = self.app.post_json('/v1/packs/register', {'types': ['sensor'],
- 'fail_on_failure': False})
+ resp = self.app.post_json(
+ "/v1/packs/register", {"types": ["sensor"], "fail_on_failure": False}
+ )
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, {'sensors': 3})
+ self.assertEqual(resp.json, {"sensors": 3})
# Verify that plural name form also works
- resp = self.app.post_json('/v1/packs/register', {'types': ['sensors'],
- 'fail_on_failure': False})
+ resp = self.app.post_json(
+ "/v1/packs/register", {"types": ["sensors"], "fail_on_failure": False}
+ )
self.assertEqual(resp.status_int, 200)
# Register a specific type for a single pack
- resp = self.app.post_json('/v1/packs/register',
- {'packs': ['dummy_pack_1'], 'types': ['action']})
+ resp = self.app.post_json(
+ "/v1/packs/register", {"packs": ["dummy_pack_1"], "types": ["action"]}
+ )
self.assertEqual(resp.status_int, 200)
# 13 real plus 1 mock runner
- self.assertEqual(resp.json, {'actions': 1, 'runners': 14})
+ self.assertEqual(resp.json, {"actions": 1, "runners": 14})
# Verify that plural name form also works
- resp = self.app.post_json('/v1/packs/register',
- {'packs': ['dummy_pack_1'], 'types': ['actions']})
+ resp = self.app.post_json(
+ "/v1/packs/register", {"packs": ["dummy_pack_1"], "types": ["actions"]}
+ )
self.assertEqual(resp.status_int, 200)
# 13 real plus 1 mock runner
- self.assertEqual(resp.json, {'actions': 1, 'runners': 14})
+ self.assertEqual(resp.json, {"actions": 1, "runners": 14})
# Register single resource from a single pack specified multiple times - verify that
# resources from the same pack are only registered once
- resp = self.app.post_json('/v1/packs/register',
- {'packs': ['dummy_pack_1', 'dummy_pack_1', 'dummy_pack_1'],
- 'types': ['actions'],
- 'fail_on_failure': False})
+ resp = self.app.post_json(
+ "/v1/packs/register",
+ {
+ "packs": ["dummy_pack_1", "dummy_pack_1", "dummy_pack_1"],
+ "types": ["actions"],
+ "fail_on_failure": False,
+ },
+ )
self.assertEqual(resp.status_int, 200)
# 13 real plus 1 mock runner
- self.assertEqual(resp.json, {'actions': 1, 'runners': 14})
+ self.assertEqual(resp.json, {"actions": 1, "runners": 14})
# Register resources from a single (non-existent) pack
- resp = self.app.post_json('/v1/packs/register', {'packs': ['doesntexist']},
- expect_errors=True)
+ resp = self.app.post_json(
+ "/v1/packs/register", {"packs": ["doesntexist"]}, expect_errors=True
+ )
self.assertEqual(resp.status_int, 400)
- self.assertIn('Pack "doesntexist" not found on disk:', resp.json['faultstring'])
+ self.assertIn('Pack "doesntexist" not found on disk:', resp.json["faultstring"])
# Fail on failure is enabled by default
- resp = self.app.post_json('/v1/packs/register', expect_errors=True)
+ resp = self.app.post_json("/v1/packs/register", expect_errors=True)
expected_msg = 'Failed to register pack "dummy_pack_10":'
self.assertEqual(resp.status_int, 400)
- self.assertIn(expected_msg, resp.json['faultstring'])
+ self.assertIn(expected_msg, resp.json["faultstring"])
# Fail on failure (broken pack metadata)
- resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1']},
- expect_errors=True)
+ resp = self.app.post_json(
+ "/v1/packs/register", {"packs": ["dummy_pack_1"]}, expect_errors=True
+ )
expected_msg = 'Referenced policy_type "action.mock_policy_error" doesnt exist'
self.assertEqual(resp.status_int, 400)
- self.assertIn(expected_msg, resp.json['faultstring'])
+ self.assertIn(expected_msg, resp.json["faultstring"])
# Fail on failure (broken action metadata)
- resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_15']},
- expect_errors=True)
+ resp = self.app.post_json(
+ "/v1/packs/register", {"packs": ["dummy_pack_15"]}, expect_errors=True
+ )
- expected_msg = 'Failed to register action'
+ expected_msg = "Failed to register action"
self.assertEqual(resp.status_int, 400)
- self.assertIn(expected_msg, resp.json['faultstring'])
+ self.assertIn(expected_msg, resp.json["faultstring"])
- expected_msg = '\'stringa\' is not valid under any of the given schemas'
+ expected_msg = "'stringa' is not valid under any of the given schemas"
self.assertEqual(resp.status_int, 400)
- self.assertIn(expected_msg, resp.json['faultstring'])
+ self.assertIn(expected_msg, resp.json["faultstring"])
def test_get_all_invalid_exclude_and_include_parameter(self):
pass
def _insert_mock_models(self):
- return [self.pack_db_1['id'], self.pack_db_2['id'], self.pack_db_3['id']]
+ return [self.pack_db_1["id"], self.pack_db_2["id"], self.pack_db_3["id"]]
def _do_delete(self, object_ids):
pass
diff --git a/st2api/tests/unit/controllers/v1/test_packs_views.py b/st2api/tests/unit/controllers/v1/test_packs_views.py
index 5535a6e22b..a1b96a4aea 100644
--- a/st2api/tests/unit/controllers/v1/test_packs_views.py
+++ b/st2api/tests/unit/controllers/v1/test_packs_views.py
@@ -21,7 +21,7 @@
from st2tests.api import FunctionalTest
-@mock.patch('st2common.bootstrap.base.REGISTERED_PACKS_CACHE', {})
+@mock.patch("st2common.bootstrap.base.REGISTERED_PACKS_CACHE", {})
class PacksViewsControllerTestCase(FunctionalTest):
@classmethod
def setUpClass(cls):
@@ -31,32 +31,34 @@ def setUpClass(cls):
actions_registrar.register_actions(use_pack_cache=False)
def test_get_pack_files_success(self):
- resp = self.app.get('/v1/packs/views/files/dummy_pack_1')
+ resp = self.app.get("/v1/packs/views/files/dummy_pack_1")
self.assertEqual(resp.status_int, http_client.OK)
self.assertTrue(len(resp.json) > 1)
- item = [_item for _item in resp.json if _item['file_path'] == 'pack.yaml'][0]
- self.assertEqual(item['file_path'], 'pack.yaml')
- item = [_item for _item in resp.json if _item['file_path'] == 'actions/my_action.py'][0]
- self.assertEqual(item['file_path'], 'actions/my_action.py')
+ item = [_item for _item in resp.json if _item["file_path"] == "pack.yaml"][0]
+ self.assertEqual(item["file_path"], "pack.yaml")
+ item = [
+ _item for _item in resp.json if _item["file_path"] == "actions/my_action.py"
+ ][0]
+ self.assertEqual(item["file_path"], "actions/my_action.py")
def test_get_pack_files_pack_doesnt_exist(self):
- resp = self.app.get('/v1/packs/views/files/doesntexist', expect_errors=True)
+ resp = self.app.get("/v1/packs/views/files/doesntexist", expect_errors=True)
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
def test_get_pack_files_binary_files_are_excluded(self):
binary_files = [
- 'icon.png',
- 'etc/permissions.png',
- 'etc/travisci.png',
- 'etc/generate_new_token.png'
+ "icon.png",
+ "etc/permissions.png",
+ "etc/travisci.png",
+ "etc/generate_new_token.png",
]
- pack_db = Pack.get_by_ref('dummy_pack_1')
+ pack_db = Pack.get_by_ref("dummy_pack_1")
all_files_count = len(pack_db.files)
non_binary_files_count = all_files_count - len(binary_files)
- resp = self.app.get('/v1/packs/views/files/dummy_pack_1')
+ resp = self.app.get("/v1/packs/views/files/dummy_pack_1")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), non_binary_files_count)
@@ -65,63 +67,75 @@ def test_get_pack_files_binary_files_are_excluded(self):
# But not in files controller response
for file_path in binary_files:
- item = [item for item in resp.json if item['file_path'] == file_path]
+ item = [item for item in resp.json if item["file_path"] == file_path]
self.assertFalse(item)
def test_get_pack_file_success(self):
- resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml')
+ resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertIn(b'name : dummy_pack_1', resp.body)
+ self.assertIn(b"name : dummy_pack_1", resp.body)
def test_get_pack_file_pack_doesnt_exist(self):
- resp = self.app.get('/v1/packs/views/files/doesntexist/pack.yaml', expect_errors=True)
+ resp = self.app.get(
+ "/v1/packs/views/files/doesntexist/pack.yaml", expect_errors=True
+ )
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
- @mock.patch('st2api.controllers.v1.pack_views.MAX_FILE_SIZE', 1)
+ @mock.patch("st2api.controllers.v1.pack_views.MAX_FILE_SIZE", 1)
def test_pack_file_file_larger_then_maximum_size(self):
- resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', expect_errors=True)
+ resp = self.app.get(
+ "/v1/packs/views/file/dummy_pack_1/pack.yaml", expect_errors=True
+ )
self.assertEqual(resp.status_int, http_client.BAD_REQUEST)
- self.assertIn('File pack.yaml exceeds maximum allowed file size', resp)
+ self.assertIn("File pack.yaml exceeds maximum allowed file size", resp)
def test_headers_get_pack_file(self):
- resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml')
+ resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertIn(b'name : dummy_pack_1', resp.body)
- self.assertIsNotNone(resp.headers['ETag'])
- self.assertIsNotNone(resp.headers['Last-Modified'])
+ self.assertIn(b"name : dummy_pack_1", resp.body)
+ self.assertIsNotNone(resp.headers["ETag"])
+ self.assertIsNotNone(resp.headers["Last-Modified"])
def test_no_change_get_pack_file(self):
- resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml')
+ resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertIn(b'name : dummy_pack_1', resp.body)
+ self.assertIn(b"name : dummy_pack_1", resp.body)
# Confirm NOT_MODIFIED
- resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml',
- headers={'If-None-Match': resp.headers['ETag']})
+ resp = self.app.get(
+ "/v1/packs/views/file/dummy_pack_1/pack.yaml",
+ headers={"If-None-Match": resp.headers["ETag"]},
+ )
self.assertEqual(resp.status_code, http_client.NOT_MODIFIED)
- resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml',
- headers={'If-Modified-Since': resp.headers['Last-Modified']})
+ resp = self.app.get(
+ "/v1/packs/views/file/dummy_pack_1/pack.yaml",
+ headers={"If-Modified-Since": resp.headers["Last-Modified"]},
+ )
self.assertEqual(resp.status_code, http_client.NOT_MODIFIED)
# Confirm value is returned if headers do not match
- resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml',
- headers={'If-None-Match': 'ETAG'})
+ resp = self.app.get(
+ "/v1/packs/views/file/dummy_pack_1/pack.yaml",
+ headers={"If-None-Match": "ETAG"},
+ )
self.assertEqual(resp.status_code, http_client.OK)
- self.assertIn(b'name : dummy_pack_1', resp.body)
+ self.assertIn(b"name : dummy_pack_1", resp.body)
- resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml',
- headers={'If-Modified-Since': 'Last-Modified'})
+ resp = self.app.get(
+ "/v1/packs/views/file/dummy_pack_1/pack.yaml",
+ headers={"If-Modified-Since": "Last-Modified"},
+ )
self.assertEqual(resp.status_code, http_client.OK)
- self.assertIn(b'name : dummy_pack_1', resp.body)
+ self.assertIn(b"name : dummy_pack_1", resp.body)
def test_get_pack_files_and_pack_file_ref_doesnt_equal_pack_name(self):
# Ref is not equal to the name; the controller should still work
- resp = self.app.get('/v1/packs/views/files/dummy_pack_16')
+ resp = self.app.get("/v1/packs/views/files/dummy_pack_16")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
- self.assertEqual(resp.json[0]['file_path'], 'pack.yaml')
+ self.assertEqual(resp.json[0]["file_path"], "pack.yaml")
- resp = self.app.get('/v1/packs/views/file/dummy_pack_16/pack.yaml')
+ resp = self.app.get("/v1/packs/views/file/dummy_pack_16/pack.yaml")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertIn(b'ref: dummy_pack_16', resp.body)
+ self.assertIn(b"ref: dummy_pack_16", resp.body)
diff --git a/st2api/tests/unit/controllers/v1/test_policies.py b/st2api/tests/unit/controllers/v1/test_policies.py
index a26c3dea24..3127b3aeb7 100644
--- a/st2api/tests/unit/controllers/v1/test_policies.py
+++ b/st2api/tests/unit/controllers/v1/test_policies.py
@@ -27,36 +27,28 @@
from st2tests.api import FunctionalTest
from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase
-__all__ = [
- 'PolicyTypeControllerTestCase',
- 'PolicyControllerTestCase'
-]
+__all__ = ["PolicyTypeControllerTestCase", "PolicyControllerTestCase"]
TEST_FIXTURES = {
- 'policytypes': [
- 'fake_policy_type_1.yaml',
- 'fake_policy_type_2.yaml'
- ],
- 'policies': [
- 'policy_1.yaml',
- 'policy_2.yaml'
- ]
+ "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"],
+ "policies": ["policy_1.yaml", "policy_2.yaml"],
}
-PACK = 'generic'
+PACK = "generic"
LOADER = FixturesLoader()
FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
-class PolicyTypeControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/policytypes'
+class PolicyTypeControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/policytypes"
controller_cls = PolicyTypeController
- include_attribute_field_name = 'module'
- exclude_attribute_field_name = 'parameters'
+ include_attribute_field_name = "module"
+ exclude_attribute_field_name = "parameters"
- base_url = '/v1/policytypes'
+ base_url = "/v1/policytypes"
@classmethod
def setUpClass(cls):
@@ -64,7 +56,7 @@ def setUpClass(cls):
cls.policy_type_dbs = []
- for _, fixture in six.iteritems(FIXTURES['policytypes']):
+ for _, fixture in six.iteritems(FIXTURES["policytypes"]):
instance = PolicyTypeAPI(**fixture)
policy_type_db = PolicyType.add_or_update(PolicyTypeAPI.to_model(instance))
cls.policy_type_dbs.append(policy_type_db)
@@ -80,23 +72,25 @@ def test_policy_type_filter(self):
self.assertGreater(len(resp.json), 0)
selected = resp.json[0]
- resp = self.__do_get_all(filter='resource_type=%s&name=%s' %
- (selected['resource_type'], selected['name']))
+ resp = self.__do_get_all(
+ filter="resource_type=%s&name=%s"
+ % (selected["resource_type"], selected["name"])
+ )
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 1)
- self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id'])
+ self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"])
- resp = self.__do_get_all(filter='name=%s' % selected['name'])
+ resp = self.__do_get_all(filter="name=%s" % selected["name"])
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 1)
- self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id'])
+ self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"])
- resp = self.__do_get_all(filter='resource_type=%s' % selected['resource_type'])
+ resp = self.__do_get_all(filter="resource_type=%s" % selected["resource_type"])
self.assertEqual(resp.status_int, 200)
self.assertGreater(len(resp.json), 1)
def test_policy_type_filter_empty(self):
- resp = self.__do_get_all(filter='resource_type=yo&name=whatever')
+ resp = self.__do_get_all(filter="resource_type=yo&name=whatever")
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 0)
@@ -106,16 +100,16 @@ def test_policy_type_get_one(self):
self.assertGreater(len(resp.json), 0)
selected = resp.json[0]
- resp = self.__do_get_one(selected['id'])
+ resp = self.__do_get_one(selected["id"])
self.assertEqual(resp.status_int, 200)
- self.assertEqual(self.__get_obj_id(resp), selected['id'])
+ self.assertEqual(self.__get_obj_id(resp), selected["id"])
- resp = self.__do_get_one(selected['ref'])
+ resp = self.__do_get_one(selected["ref"])
self.assertEqual(resp.status_int, 200)
- self.assertEqual(self.__get_obj_id(resp), selected['id'])
+ self.assertEqual(self.__get_obj_id(resp), selected["id"])
def test_policy_type_get_one_fail(self):
- resp = self.__do_get_one('1')
+ resp = self.__do_get_one("1")
self.assertEqual(resp.status_int, 404)
def _insert_mock_models(self):
@@ -130,36 +124,37 @@ def _delete_mock_models(self, object_ids):
@staticmethod
def __get_obj_id(resp, idx=-1):
- return resp.json['id'] if idx < 0 else resp.json[idx]['id']
+ return resp.json["id"] if idx < 0 else resp.json[idx]["id"]
def __do_get_all(self, filter=None):
- url = '%s?%s' % (self.base_url, filter) if filter else self.base_url
+ url = "%s?%s" % (self.base_url, filter) if filter else self.base_url
return self.app.get(url, expect_errors=True)
def __do_get_one(self, id):
- return self.app.get('%s/%s' % (self.base_url, id), expect_errors=True)
+ return self.app.get("%s/%s" % (self.base_url, id), expect_errors=True)
-class PolicyControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/policies'
+class PolicyControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/policies"
controller_cls = PolicyController
- include_attribute_field_name = 'policy_type'
- exclude_attribute_field_name = 'parameters'
+ include_attribute_field_name = "policy_type"
+ exclude_attribute_field_name = "parameters"
- base_url = '/v1/policies'
+ base_url = "/v1/policies"
@classmethod
def setUpClass(cls):
super(PolicyControllerTestCase, cls).setUpClass()
- for _, fixture in six.iteritems(FIXTURES['policytypes']):
+ for _, fixture in six.iteritems(FIXTURES["policytypes"]):
instance = PolicyTypeAPI(**fixture)
PolicyType.add_or_update(PolicyTypeAPI.to_model(instance))
cls.policy_dbs = []
- for _, fixture in six.iteritems(FIXTURES['policies']):
+ for _, fixture in six.iteritems(FIXTURES["policies"]):
instance = PolicyAPI(**fixture)
policy_db = Policy.add_or_update(PolicyAPI.to_model(instance))
cls.policy_dbs.append(policy_db)
@@ -175,22 +170,24 @@ def test_filter(self):
self.assertGreater(len(resp.json), 0)
selected = resp.json[0]
- resp = self.__do_get_all(filter='pack=%s&name=%s' % (selected['pack'], selected['name']))
+ resp = self.__do_get_all(
+ filter="pack=%s&name=%s" % (selected["pack"], selected["name"])
+ )
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 1)
- self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id'])
+ self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"])
- resp = self.__do_get_all(filter='name=%s' % selected['name'])
+ resp = self.__do_get_all(filter="name=%s" % selected["name"])
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 1)
- self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id'])
+ self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"])
- resp = self.__do_get_all(filter='pack=%s' % selected['pack'])
+ resp = self.__do_get_all(filter="pack=%s" % selected["pack"])
self.assertEqual(resp.status_int, 200)
self.assertGreater(len(resp.json), 1)
def test_filter_empty(self):
- resp = self.__do_get_all(filter='pack=yo&name=whatever')
+ resp = self.__do_get_all(filter="pack=yo&name=whatever")
self.assertEqual(resp.status_int, 200)
self.assertEqual(len(resp.json), 0)
@@ -200,16 +197,16 @@ def test_get_one(self):
self.assertGreater(len(resp.json), 0)
selected = resp.json[0]
- resp = self.__do_get_one(selected['id'])
+ resp = self.__do_get_one(selected["id"])
self.assertEqual(resp.status_int, 200)
- self.assertEqual(self.__get_obj_id(resp), selected['id'])
+ self.assertEqual(self.__get_obj_id(resp), selected["id"])
- resp = self.__do_get_one(selected['ref'])
+ resp = self.__do_get_one(selected["ref"])
self.assertEqual(resp.status_int, 200)
- self.assertEqual(self.__get_obj_id(resp), selected['id'])
+ self.assertEqual(self.__get_obj_id(resp), selected["id"])
def test_get_one_fail(self):
- resp = self.__do_get_one('1')
+ resp = self.__do_get_one("1")
self.assertEqual(resp.status_int, 404)
def test_crud(self):
@@ -221,10 +218,10 @@ def test_crud(self):
self.assertEqual(get_resp.status_int, http_client.OK)
updated_input = get_resp.json
- updated_input['enabled'] = not updated_input['enabled']
+ updated_input["enabled"] = not updated_input["enabled"]
put_resp = self.__do_put(self.__get_obj_id(post_resp), updated_input)
self.assertEqual(put_resp.status_int, http_client.OK)
- self.assertEqual(put_resp.json['enabled'], updated_input['enabled'])
+ self.assertEqual(put_resp.json["enabled"], updated_input["enabled"])
del_resp = self.__do_delete(self.__get_obj_id(post_resp))
self.assertEqual(del_resp.status_int, http_client.NO_CONTENT)
@@ -243,41 +240,45 @@ def test_post_duplicate(self):
def test_put_not_found(self):
updated_input = self.__create_instance()
- put_resp = self.__do_put('12345', updated_input)
+ put_resp = self.__do_put("12345", updated_input)
self.assertEqual(put_resp.status_int, http_client.NOT_FOUND)
def test_put_sys_pack(self):
instance = self.__create_instance()
- instance['pack'] = 'core'
+ instance["pack"] = "core"
post_resp = self.__do_post(instance)
self.assertEqual(post_resp.status_int, http_client.CREATED)
updated_input = post_resp.json
- updated_input['enabled'] = not updated_input['enabled']
+ updated_input["enabled"] = not updated_input["enabled"]
put_resp = self.__do_put(self.__get_obj_id(post_resp), updated_input)
self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST)
- self.assertEqual(put_resp.json['faultstring'],
- "Resources belonging to system level packs can't be manipulated")
+ self.assertEqual(
+ put_resp.json["faultstring"],
+ "Resources belonging to system level packs can't be manipulated",
+ )
# Clean up manually since API won't delete object in sys pack.
Policy.delete(Policy.get_by_id(self.__get_obj_id(post_resp)))
def test_delete_not_found(self):
- del_resp = self.__do_delete('12345')
+ del_resp = self.__do_delete("12345")
self.assertEqual(del_resp.status_int, http_client.NOT_FOUND)
def test_delete_sys_pack(self):
instance = self.__create_instance()
- instance['pack'] = 'core'
+ instance["pack"] = "core"
post_resp = self.__do_post(instance)
self.assertEqual(post_resp.status_int, http_client.CREATED)
del_resp = self.__do_delete(self.__get_obj_id(post_resp))
self.assertEqual(del_resp.status_int, http_client.BAD_REQUEST)
- self.assertEqual(del_resp.json['faultstring'],
- "Resources belonging to system level packs can't be manipulated")
+ self.assertEqual(
+ del_resp.json["faultstring"],
+ "Resources belonging to system level packs can't be manipulated",
+ )
# Clean up manually since API won't delete object in sys pack.
Policy.delete(Policy.get_by_id(self.__get_obj_id(post_resp)))
@@ -295,34 +296,34 @@ def _delete_mock_models(self, object_ids):
@staticmethod
def __create_instance():
return {
- 'name': 'myaction.mypolicy',
- 'pack': 'mypack',
- 'resource_ref': 'mypack.myaction',
- 'policy_type': 'action.mock_policy_error',
- 'parameters': {
- 'k1': 'v1'
- }
+ "name": "myaction.mypolicy",
+ "pack": "mypack",
+ "resource_ref": "mypack.myaction",
+ "policy_type": "action.mock_policy_error",
+ "parameters": {"k1": "v1"},
}
@staticmethod
def __get_obj_id(resp, idx=-1):
- return resp.json['id'] if idx < 0 else resp.json[idx]['id']
+ return resp.json["id"] if idx < 0 else resp.json[idx]["id"]
def __do_get_all(self, filter=None):
- url = '%s?%s' % (self.base_url, filter) if filter else self.base_url
+ url = "%s?%s" % (self.base_url, filter) if filter else self.base_url
return self.app.get(url, expect_errors=True)
def __do_get_one(self, id):
- return self.app.get('%s/%s' % (self.base_url, id), expect_errors=True)
+ return self.app.get("%s/%s" % (self.base_url, id), expect_errors=True)
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def __do_post(self, instance):
return self.app.post_json(self.base_url, instance, expect_errors=True)
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def __do_put(self, id, instance):
- return self.app.put_json('%s/%s' % (self.base_url, id), instance, expect_errors=True)
+ return self.app.put_json(
+ "%s/%s" % (self.base_url, id), instance, expect_errors=True
+ )
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def __do_delete(self, id):
- return self.app.delete('%s/%s' % (self.base_url, id), expect_errors=True)
+ return self.app.delete("%s/%s" % (self.base_url, id), expect_errors=True)
diff --git a/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py b/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py
index 84c2a66b4a..0a7a104d35 100644
--- a/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py
+++ b/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py
@@ -21,87 +21,109 @@
from st2tests.api import FunctionalTest
from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase
-__all__ = [
- 'RuleEnforcementViewsControllerTestCase'
-]
+__all__ = ["RuleEnforcementViewsControllerTestCase"]
http_client = six.moves.http_client
TEST_FIXTURES = {
- 'enforcements': ['enforcement1.yaml', 'enforcement2.yaml', 'enforcement3.yaml'],
- 'executions': ['execution1.yaml'],
- 'triggerinstances': ['trigger_instance_1.yaml']
+ "enforcements": ["enforcement1.yaml", "enforcement2.yaml", "enforcement3.yaml"],
+ "executions": ["execution1.yaml"],
+ "triggerinstances": ["trigger_instance_1.yaml"],
}
-FIXTURES_PACK = 'rule_enforcements'
+FIXTURES_PACK = "rule_enforcements"
-class RuleEnforcementViewsControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/ruleenforcements/views'
+class RuleEnforcementViewsControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/ruleenforcements/views"
controller_cls = RuleEnforcementViewController
- include_attribute_field_name = 'enforced_at'
- exclude_attribute_field_name = 'status'
+ include_attribute_field_name = "enforced_at"
+ exclude_attribute_field_name = "status"
fixtures_loader = FixturesLoader()
@classmethod
def setUpClass(cls):
super(RuleEnforcementViewsControllerTestCase, cls).setUpClass()
- cls.models = RuleEnforcementViewsControllerTestCase.fixtures_loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES,
- use_object_ids=True)
- cls.ENFORCEMENT_1 = cls.models['enforcements']['enforcement1.yaml']
+ cls.models = (
+ RuleEnforcementViewsControllerTestCase.fixtures_loader.save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK,
+ fixtures_dict=TEST_FIXTURES,
+ use_object_ids=True,
+ )
+ )
+ cls.ENFORCEMENT_1 = cls.models["enforcements"]["enforcement1.yaml"]
def test_get_all(self):
- resp = self.app.get('/v1/ruleenforcements/views')
+ resp = self.app.get("/v1/ruleenforcements/views")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 3)
# Verify it includes corresponding execution and trigger instance object
- self.assertEqual(resp.json[0]['trigger_instance']['id'], '565e15ce32ed350857dfa623')
- self.assertEqual(resp.json[0]['trigger_instance']['payload'], {'foo': 'bar', 'name': 'Joe'})
-
- self.assertEqual(resp.json[0]['execution']['action']['ref'], 'core.local')
- self.assertEqual(resp.json[0]['execution']['action']['parameters'],
- {'sudo': {'immutable': True}})
- self.assertEqual(resp.json[0]['execution']['runner']['name'], 'action-chain')
- self.assertEqual(resp.json[0]['execution']['runner']['runner_parameters'],
- {'foo': {'type': 'string'}})
- self.assertEqual(resp.json[0]['execution']['parameters'], {'cmd': 'echo bar'})
- self.assertEqual(resp.json[0]['execution']['status'], 'scheduled')
-
- self.assertEqual(resp.json[1]['trigger_instance'], {})
- self.assertEqual(resp.json[1]['execution'], {})
-
- self.assertEqual(resp.json[2]['trigger_instance'], {})
- self.assertEqual(resp.json[2]['execution'], {})
+ self.assertEqual(
+ resp.json[0]["trigger_instance"]["id"], "565e15ce32ed350857dfa623"
+ )
+ self.assertEqual(
+ resp.json[0]["trigger_instance"]["payload"], {"foo": "bar", "name": "Joe"}
+ )
+
+ self.assertEqual(resp.json[0]["execution"]["action"]["ref"], "core.local")
+ self.assertEqual(
+ resp.json[0]["execution"]["action"]["parameters"],
+ {"sudo": {"immutable": True}},
+ )
+ self.assertEqual(resp.json[0]["execution"]["runner"]["name"], "action-chain")
+ self.assertEqual(
+ resp.json[0]["execution"]["runner"]["runner_parameters"],
+ {"foo": {"type": "string"}},
+ )
+ self.assertEqual(resp.json[0]["execution"]["parameters"], {"cmd": "echo bar"})
+ self.assertEqual(resp.json[0]["execution"]["status"], "scheduled")
+
+ self.assertEqual(resp.json[1]["trigger_instance"], {})
+ self.assertEqual(resp.json[1]["execution"], {})
+
+ self.assertEqual(resp.json[2]["trigger_instance"], {})
+ self.assertEqual(resp.json[2]["execution"], {})
def test_filter_by_rule_ref(self):
- resp = self.app.get('/v1/ruleenforcements/views?rule_ref=wolfpack.golden_rule')
+ resp = self.app.get("/v1/ruleenforcements/views?rule_ref=wolfpack.golden_rule")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
- self.assertEqual(resp.json[0]['rule']['ref'], 'wolfpack.golden_rule')
+ self.assertEqual(resp.json[0]["rule"]["ref"], "wolfpack.golden_rule")
def test_get_one_success(self):
- resp = self.app.get('/v1/ruleenforcements/views/%s' % (str(self.ENFORCEMENT_1.id)))
- self.assertEqual(resp.json['id'], str(self.ENFORCEMENT_1.id))
-
- self.assertEqual(resp.json['trigger_instance']['id'], '565e15ce32ed350857dfa623')
- self.assertEqual(resp.json['trigger_instance']['payload'], {'foo': 'bar', 'name': 'Joe'})
-
- self.assertEqual(resp.json['execution']['action']['ref'], 'core.local')
- self.assertEqual(resp.json['execution']['action']['parameters'],
- {'sudo': {'immutable': True}})
- self.assertEqual(resp.json['execution']['runner']['name'], 'action-chain')
- self.assertEqual(resp.json['execution']['runner']['runner_parameters'],
- {'foo': {'type': 'string'}})
- self.assertEqual(resp.json['execution']['parameters'], {'cmd': 'echo bar'})
- self.assertEqual(resp.json['execution']['status'], 'scheduled')
+ resp = self.app.get(
+ "/v1/ruleenforcements/views/%s" % (str(self.ENFORCEMENT_1.id))
+ )
+ self.assertEqual(resp.json["id"], str(self.ENFORCEMENT_1.id))
+
+ self.assertEqual(
+ resp.json["trigger_instance"]["id"], "565e15ce32ed350857dfa623"
+ )
+ self.assertEqual(
+ resp.json["trigger_instance"]["payload"], {"foo": "bar", "name": "Joe"}
+ )
+
+ self.assertEqual(resp.json["execution"]["action"]["ref"], "core.local")
+ self.assertEqual(
+ resp.json["execution"]["action"]["parameters"],
+ {"sudo": {"immutable": True}},
+ )
+ self.assertEqual(resp.json["execution"]["runner"]["name"], "action-chain")
+ self.assertEqual(
+ resp.json["execution"]["runner"]["runner_parameters"],
+ {"foo": {"type": "string"}},
+ )
+ self.assertEqual(resp.json["execution"]["parameters"], {"cmd": "echo bar"})
+ self.assertEqual(resp.json["execution"]["status"], "scheduled")
def _insert_mock_models(self):
- enfrocement_ids = [enforcement['id'] for enforcement in
- self.models['enforcements'].values()]
- return enfrocement_ids
+ enforcement_ids = [
+ enforcement["id"] for enforcement in self.models["enforcements"].values()
+ ]
+ return enforcement_ids
def _delete_mock_models(self, object_ids):
diff --git a/st2api/tests/unit/controllers/v1/test_rule_enforcements.py b/st2api/tests/unit/controllers/v1/test_rule_enforcements.py
index 172b186098..f2de1e2b2a 100644
--- a/st2api/tests/unit/controllers/v1/test_rule_enforcements.py
+++ b/st2api/tests/unit/controllers/v1/test_rule_enforcements.py
@@ -24,92 +24,106 @@
http_client = six.moves.http_client
TEST_FIXTURES = {
- 'enforcements': ['enforcement1.yaml', 'enforcement2.yaml', 'enforcement3.yaml']
+ "enforcements": ["enforcement1.yaml", "enforcement2.yaml", "enforcement3.yaml"]
}
-FIXTURES_PACK = 'rule_enforcements'
+FIXTURES_PACK = "rule_enforcements"
-class RuleEnforcementControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/ruleenforcements'
+class RuleEnforcementControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/ruleenforcements"
controller_cls = RuleEnforcementController
- include_attribute_field_name = 'enforced_at'
- exclude_attribute_field_name = 'status'
+ include_attribute_field_name = "enforced_at"
+ exclude_attribute_field_name = "status"
fixtures_loader = FixturesLoader()
@classmethod
def setUpClass(cls):
super(RuleEnforcementControllerTestCase, cls).setUpClass()
- cls.models = RuleEnforcementControllerTestCase.fixtures_loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES)
- RuleEnforcementControllerTestCase.ENFORCEMENT_1 = \
- cls.models['enforcements']['enforcement1.yaml']
+ cls.models = (
+ RuleEnforcementControllerTestCase.fixtures_loader.save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
+ )
+ RuleEnforcementControllerTestCase.ENFORCEMENT_1 = cls.models["enforcements"][
+ "enforcement1.yaml"
+ ]
def test_get_all(self):
- resp = self.app.get('/v1/ruleenforcements')
+ resp = self.app.get("/v1/ruleenforcements")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 3)
def test_get_all_minus_one(self):
- resp = self.app.get('/v1/ruleenforcements/?limit=-1')
+ resp = self.app.get("/v1/ruleenforcements/?limit=-1")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 3)
def test_get_all_limit(self):
- resp = self.app.get('/v1/ruleenforcements/?limit=1')
+ resp = self.app.get("/v1/ruleenforcements/?limit=1")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
def test_get_all_limit_negative_number(self):
- resp = self.app.get('/v1/ruleenforcements?limit=-22', expect_errors=True)
+ resp = self.app.get("/v1/ruleenforcements?limit=-22", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
def test_get_one_by_id(self):
e_id = str(RuleEnforcementControllerTestCase.ENFORCEMENT_1.id)
- resp = self.app.get('/v1/ruleenforcements/%s' % e_id)
+ resp = self.app.get("/v1/ruleenforcements/%s" % e_id)
self.assertEqual(resp.status_int, http_client.OK)
- self.assertEqual(resp.json['id'], e_id)
+ self.assertEqual(resp.json["id"], e_id)
def test_get_one_fail(self):
- resp = self.app.get('/v1/ruleenforcements/1', expect_errors=True)
+ resp = self.app.get("/v1/ruleenforcements/1", expect_errors=True)
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
def test_filter_by_rule_ref(self):
- resp = self.app.get('/v1/ruleenforcements?rule_ref=wolfpack.golden_rule')
+ resp = self.app.get("/v1/ruleenforcements?rule_ref=wolfpack.golden_rule")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
def test_filter_by_rule_id(self):
- resp = self.app.get('/v1/ruleenforcements?rule_id=565e15c032ed35086c54f331')
+ resp = self.app.get("/v1/ruleenforcements?rule_id=565e15c032ed35086c54f331")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 2)
def test_filter_by_execution_id(self):
- resp = self.app.get('/v1/ruleenforcements?execution=565e15cd32ed350857dfa620')
+ resp = self.app.get("/v1/ruleenforcements?execution=565e15cd32ed350857dfa620")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
def test_filter_by_trigger_instance_id(self):
- resp = self.app.get('/v1/ruleenforcements?trigger_instance=565e15ce32ed350857dfa623')
+ resp = self.app.get(
+ "/v1/ruleenforcements?trigger_instance=565e15ce32ed350857dfa623"
+ )
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
def test_filter_by_enforced_at(self):
- resp = self.app.get('/v1/ruleenforcements?enforced_at_gt=2015-12-01T21:49:01.000000Z')
+ resp = self.app.get(
+ "/v1/ruleenforcements?enforced_at_gt=2015-12-01T21:49:01.000000Z"
+ )
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 2)
- resp = self.app.get('/v1/ruleenforcements?enforced_at_lt=2015-12-01T21:49:01.000000Z')
+ resp = self.app.get(
+ "/v1/ruleenforcements?enforced_at_lt=2015-12-01T21:49:01.000000Z"
+ )
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
def _insert_mock_models(self):
- enfrocement_ids = [enforcement['id'] for enforcement in
- self.models['enforcements'].values()]
- return enfrocement_ids
+ enforcement_ids = [
+ enforcement["id"] for enforcement in self.models["enforcements"].values()
+ ]
+ return enforcement_ids
def _delete_mock_models(self, object_ids):
diff --git a/st2api/tests/unit/controllers/v1/test_rule_views.py b/st2api/tests/unit/controllers/v1/test_rule_views.py
index f8a25e5d3d..95839c3110 100644
--- a/st2api/tests/unit/controllers/v1/test_rule_views.py
+++ b/st2api/tests/unit/controllers/v1/test_rule_views.py
@@ -25,25 +25,24 @@
http_client = six.moves.http_client
TEST_FIXTURES = {
- 'runners': ['testrunner1.yaml'],
- 'actions': ['action1.yaml', 'action2.yaml'],
- 'triggers': ['trigger1.yaml'],
- 'triggertypes': ['triggertype1.yaml']
+ "runners": ["testrunner1.yaml"],
+ "actions": ["action1.yaml", "action2.yaml"],
+ "triggers": ["trigger1.yaml"],
+ "triggertypes": ["triggertype1.yaml"],
}
-TEST_FIXTURES_RULES = {
- 'rules': ['rule1.yaml', 'rule4.yaml', 'rule5.yaml']
-}
+TEST_FIXTURES_RULES = {"rules": ["rule1.yaml", "rule4.yaml", "rule5.yaml"]}
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
-class RuleViewControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/rules/views'
+class RuleViewControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/rules/views"
controller_cls = RuleViewController
- include_attribute_field_name = 'criteria'
- exclude_attribute_field_name = 'enabled'
+ include_attribute_field_name = "criteria"
+ exclude_attribute_field_name = "enabled"
fixtures_loader = FixturesLoader()
@@ -51,17 +50,21 @@ class RuleViewControllerTestCase(FunctionalTest,
def setUpClass(cls):
super(RuleViewControllerTestCase, cls).setUpClass()
models = RuleViewControllerTestCase.fixtures_loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES)
- RuleViewControllerTestCase.ACTION_1 = models['actions']['action1.yaml']
- RuleViewControllerTestCase.TRIGGER_TYPE_1 = models['triggertypes']['triggertype1.yaml']
-
- file_name = 'rule1.yaml'
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
+ RuleViewControllerTestCase.ACTION_1 = models["actions"]["action1.yaml"]
+ RuleViewControllerTestCase.TRIGGER_TYPE_1 = models["triggertypes"][
+ "triggertype1.yaml"
+ ]
+
+ file_name = "rule1.yaml"
cls.rules = RuleViewControllerTestCase.fixtures_loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_RULES)['rules']
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_RULES
+ )["rules"]
RuleViewControllerTestCase.RULE_1 = cls.rules[file_name]
def test_get_all(self):
- resp = self.app.get('/v1/rules/views')
+ resp = self.app.get("/v1/rules/views")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 3)
@@ -70,25 +73,29 @@ def test_get_one_by_id(self):
get_resp = self.__do_get_one(rule_id)
self.assertEqual(get_resp.status_int, http_client.OK)
self.assertEqual(self.__get_rule_id(get_resp), rule_id)
- self.assertEqual(get_resp.json['action']['description'],
- RuleViewControllerTestCase.ACTION_1.description)
- self.assertEqual(get_resp.json['trigger']['description'],
- RuleViewControllerTestCase.TRIGGER_TYPE_1.description)
+ self.assertEqual(
+ get_resp.json["action"]["description"],
+ RuleViewControllerTestCase.ACTION_1.description,
+ )
+ self.assertEqual(
+ get_resp.json["trigger"]["description"],
+ RuleViewControllerTestCase.TRIGGER_TYPE_1.description,
+ )
def test_get_one_by_ref(self):
rule_name = RuleViewControllerTestCase.RULE_1.name
rule_pack = RuleViewControllerTestCase.RULE_1.pack
ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack)
get_resp = self.__do_get_one(ref)
- self.assertEqual(get_resp.json['name'], rule_name)
+ self.assertEqual(get_resp.json["name"], rule_name)
self.assertEqual(get_resp.status_int, http_client.OK)
def test_get_one_fail(self):
- resp = self.app.get('/v1/rules/1', expect_errors=True)
+ resp = self.app.get("/v1/rules/1", expect_errors=True)
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
def _insert_mock_models(self):
- rule_ids = [rule['id'] for rule in self.rules.values()]
+ rule_ids = [rule["id"] for rule in self.rules.values()]
return rule_ids
def _delete_mock_models(self, object_ids):
@@ -96,7 +103,7 @@ def _delete_mock_models(self, object_ids):
@staticmethod
def __get_rule_id(resp):
- return resp.json['id']
+ return resp.json["id"]
def __do_get_one(self, rule_id):
- return self.app.get('/v1/rules/views/%s' % rule_id, expect_errors=True)
+ return self.app.get("/v1/rules/views/%s" % rule_id, expect_errors=True)
diff --git a/st2api/tests/unit/controllers/v1/test_rules.py b/st2api/tests/unit/controllers/v1/test_rules.py
index f52b4294ca..daf6845bcb 100644
--- a/st2api/tests/unit/controllers/v1/test_rules.py
+++ b/st2api/tests/unit/controllers/v1/test_rules.py
@@ -34,21 +34,23 @@
http_client = six.moves.http_client
TEST_FIXTURES = {
- 'runners': ['testrunner1.yaml'],
- 'actions': ['action1.yaml'],
- 'triggers': ['trigger1.yaml'],
- 'triggertypes': ['triggertype1.yaml', 'triggertype_with_parameters_2.yaml']
+ "runners": ["testrunner1.yaml"],
+ "actions": ["action1.yaml"],
+ "triggers": ["trigger1.yaml"],
+ "triggertypes": ["triggertype1.yaml", "triggertype_with_parameters_2.yaml"],
}
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
-class RulesControllerTestCase(FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/rules'
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
+class RulesControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/rules"
controller_cls = RuleController
- include_attribute_field_name = 'criteria'
- exclude_attribute_field_name = 'enabled'
+ include_attribute_field_name = "criteria"
+ exclude_attribute_field_name = "enabled"
VALIDATE_TRIGGER_PAYLOAD = None
@@ -64,71 +66,96 @@ def setUpClass(cls):
cls.VALIDATE_TRIGGER_PAYLOAD = cfg.CONF.system.validate_trigger_parameters
models = RulesControllerTestCase.fixtures_loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES)
- RulesControllerTestCase.RUNNER_TYPE = models['runners']['testrunner1.yaml']
- RulesControllerTestCase.ACTION = models['actions']['action1.yaml']
- RulesControllerTestCase.TRIGGER = models['triggers']['trigger1.yaml']
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
+ RulesControllerTestCase.RUNNER_TYPE = models["runners"]["testrunner1.yaml"]
+ RulesControllerTestCase.ACTION = models["actions"]["action1.yaml"]
+ RulesControllerTestCase.TRIGGER = models["triggers"]["trigger1.yaml"]
# Don't load rule into DB as that is what is being tested.
- file_name = 'rule1.yaml'
- RulesControllerTestCase.RULE_1 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'cron_timer_rule_invalid_parameters.yaml'
- RulesControllerTestCase.RULE_2 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'rule_no_enabled_attribute.yaml'
- RulesControllerTestCase.RULE_3 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'backstop_rule.yaml'
- RulesControllerTestCase.RULE_4 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'date_timer_rule_invalid_parameters.yaml'
- RulesControllerTestCase.RULE_5 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'cron_timer_rule_invalid_parameters_1.yaml'
- RulesControllerTestCase.RULE_6 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'cron_timer_rule_invalid_parameters_2.yaml'
- RulesControllerTestCase.RULE_7 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'cron_timer_rule_invalid_parameters_3.yaml'
- RulesControllerTestCase.RULE_8 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'rule_invalid_trigger_parameter_type.yaml'
- RulesControllerTestCase.RULE_9 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'rule_trigger_with_no_parameters.yaml'
- RulesControllerTestCase.RULE_10 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'rule_invalid_trigger_parameter_type_default_cfg.yaml'
- RulesControllerTestCase.RULE_11 = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
-
- file_name = 'rule space.yaml'
- RulesControllerTestCase.RULE_SPACE = RulesControllerTestCase.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
+ file_name = "rule1.yaml"
+ RulesControllerTestCase.RULE_1 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "cron_timer_rule_invalid_parameters.yaml"
+ RulesControllerTestCase.RULE_2 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "rule_no_enabled_attribute.yaml"
+ RulesControllerTestCase.RULE_3 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "backstop_rule.yaml"
+ RulesControllerTestCase.RULE_4 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "date_timer_rule_invalid_parameters.yaml"
+ RulesControllerTestCase.RULE_5 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "cron_timer_rule_invalid_parameters_1.yaml"
+ RulesControllerTestCase.RULE_6 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "cron_timer_rule_invalid_parameters_2.yaml"
+ RulesControllerTestCase.RULE_7 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "cron_timer_rule_invalid_parameters_3.yaml"
+ RulesControllerTestCase.RULE_8 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "rule_invalid_trigger_parameter_type.yaml"
+ RulesControllerTestCase.RULE_9 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "rule_trigger_with_no_parameters.yaml"
+ RulesControllerTestCase.RULE_10 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "rule_invalid_trigger_parameter_type_default_cfg.yaml"
+ RulesControllerTestCase.RULE_11 = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
+
+ file_name = "rule space.yaml"
+ RulesControllerTestCase.RULE_SPACE = (
+ RulesControllerTestCase.fixtures_loader.load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
+ )
@classmethod
def tearDownClass(cls):
@@ -136,18 +163,19 @@ def tearDownClass(cls):
cfg.CONF.system.validate_trigger_payload = cls.VALIDATE_TRIGGER_PAYLOAD
RulesControllerTestCase.fixtures_loader.delete_fixtures_from_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES)
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
- super(RulesControllerTestCase, cls).setUpClass()
+ super(RulesControllerTestCase, cls).tearDownClass()
def test_get_all_and_minus_one(self):
post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1)
post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3)
- resp = self.app.get('/v1/rules')
+ resp = self.app.get("/v1/rules")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 2)
- resp = self.app.get('/v1/rules/?limit=-1')
+ resp = self.app.get("/v1/rules/?limit=-1")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 2)
@@ -158,10 +186,12 @@ def test_get_all_limit_negative_number(self):
post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1)
post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3)
- resp = self.app.get('/v1/rules?limit=-22', expect_errors=True)
+ resp = self.app.get("/v1/rules?limit=-22", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
self.__do_delete(self.__get_rule_id(post_resp_rule_1))
self.__do_delete(self.__get_rule_id(post_resp_rule_3))
@@ -171,18 +201,18 @@ def test_get_all_enabled(self):
post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3)
# enabled=True
- resp = self.app.get('/v1/rules?enabled=True')
+ resp = self.app.get("/v1/rules?enabled=True")
self.assertEqual(resp.status_int, http_client.OK)
rule = resp.json[0]
- self.assertEqual(self.__get_rule_id(post_resp_rule_1), rule['id'])
- self.assertEqual(rule['enabled'], True)
+ self.assertEqual(self.__get_rule_id(post_resp_rule_1), rule["id"])
+ self.assertEqual(rule["enabled"], True)
# enabled=False
- resp = self.app.get('/v1/rules?enabled=False')
+ resp = self.app.get("/v1/rules?enabled=False")
self.assertEqual(resp.status_int, http_client.OK)
rule = resp.json[0]
- self.assertEqual(self.__get_rule_id(post_resp_rule_3), rule['id'])
- self.assertEqual(rule['enabled'], False)
+ self.assertEqual(self.__get_rule_id(post_resp_rule_3), rule["id"])
+ self.assertEqual(rule["enabled"], False)
self.__do_delete(self.__get_rule_id(post_resp_rule_1))
self.__do_delete(self.__get_rule_id(post_resp_rule_3))
@@ -191,37 +221,45 @@ def test_get_all_action_parameters_secrets_masking(self):
post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1)
# Verify parameter is masked by default
- resp = self.app.get('/v1/rules')
- self.assertEqual('action' in resp.json[0], True)
- self.assertEqual(resp.json[0]['action']['parameters']['action_secret'],
- MASKED_ATTRIBUTE_VALUE)
+ resp = self.app.get("/v1/rules")
+ self.assertEqual("action" in resp.json[0], True)
+ self.assertEqual(
+ resp.json[0]["action"]["parameters"]["action_secret"],
+ MASKED_ATTRIBUTE_VALUE,
+ )
# Verify ?show_secrets=true works
- resp = self.app.get('/v1/rules?include_attributes=action&show_secrets=true')
- self.assertEqual('action' in resp.json[0], True)
- self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], 'secret')
+ resp = self.app.get("/v1/rules?include_attributes=action&show_secrets=true")
+ self.assertEqual("action" in resp.json[0], True)
+ self.assertEqual(
+ resp.json[0]["action"]["parameters"]["action_secret"], "secret"
+ )
self.__do_delete(self.__get_rule_id(post_resp_rule_1))
def test_get_all_parameters_mask_with_exclude_parameters(self):
post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1)
- resp = self.app.get('/v1/rules?exclude_attributes=action')
- self.assertEqual('action' in resp.json[0], False)
+ resp = self.app.get("/v1/rules?exclude_attributes=action")
+ self.assertEqual("action" in resp.json[0], False)
self.__do_delete(self.__get_rule_id(post_resp_rule_1))
def test_get_all_parameters_mask_with_include_parameters(self):
post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1)
# Verify parameter is masked by default
- resp = self.app.get('/v1/rules?include_attributes=action')
- self.assertEqual('action' in resp.json[0], True)
- self.assertEqual(resp.json[0]['action']['parameters']['action_secret'],
- MASKED_ATTRIBUTE_VALUE)
+ resp = self.app.get("/v1/rules?include_attributes=action")
+ self.assertEqual("action" in resp.json[0], True)
+ self.assertEqual(
+ resp.json[0]["action"]["parameters"]["action_secret"],
+ MASKED_ATTRIBUTE_VALUE,
+ )
# Verify ?show_secrets=true works
- resp = self.app.get('/v1/rules?include_attributes=action&show_secrets=true')
- self.assertEqual('action' in resp.json[0], True)
- self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], 'secret')
+ resp = self.app.get("/v1/rules?include_attributes=action&show_secrets=true")
+ self.assertEqual("action" in resp.json[0], True)
+ self.assertEqual(
+ resp.json[0]["action"]["parameters"]["action_secret"], "secret"
+ )
self.__do_delete(self.__get_rule_id(post_resp_rule_1))
@@ -229,13 +267,16 @@ def test_get_one_action_parameters_secrets_masking(self):
post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1)
# Verify parameter is masked by default
- resp = self.app.get('/v1/rules/%s' % (post_resp_rule_1.json['id']))
- self.assertEqual(resp.json['action']['parameters']['action_secret'],
- MASKED_ATTRIBUTE_VALUE)
+ resp = self.app.get("/v1/rules/%s" % (post_resp_rule_1.json["id"]))
+ self.assertEqual(
+ resp.json["action"]["parameters"]["action_secret"], MASKED_ATTRIBUTE_VALUE
+ )
# Verify ?show_secrets=true works
- resp = self.app.get('/v1/rules/%s?show_secrets=true' % (post_resp_rule_1.json['id']))
- self.assertEqual(resp.json['action']['parameters']['action_secret'], 'secret')
+ resp = self.app.get(
+ "/v1/rules/%s?show_secrets=true" % (post_resp_rule_1.json["id"])
+ )
+ self.assertEqual(resp.json["action"]["parameters"]["action_secret"], "secret")
self.__do_delete(self.__get_rule_id(post_resp_rule_1))
@@ -249,27 +290,27 @@ def test_get_one_by_id(self):
def test_get_one_by_ref(self):
post_resp = self.__do_post(RulesControllerTestCase.RULE_1)
- rule_name = post_resp.json['name']
- rule_pack = post_resp.json['pack']
+ rule_name = post_resp.json["name"]
+ rule_pack = post_resp.json["pack"]
ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack)
- rule_id = post_resp.json['id']
+ rule_id = post_resp.json["id"]
get_resp = self.__do_get_one(ref)
- self.assertEqual(get_resp.json['name'], rule_name)
+ self.assertEqual(get_resp.json["name"], rule_name)
self.assertEqual(get_resp.status_int, http_client.OK)
self.__do_delete(rule_id)
post_resp = self.__do_post(RulesControllerTestCase.RULE_SPACE)
- rule_name = post_resp.json['name']
- rule_pack = post_resp.json['pack']
+ rule_name = post_resp.json["name"]
+ rule_pack = post_resp.json["pack"]
ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack)
- rule_id = post_resp.json['id']
+ rule_id = post_resp.json["id"]
get_resp = self.__do_get_one(ref)
- self.assertEqual(get_resp.json['name'], rule_name)
+ self.assertEqual(get_resp.json["name"], rule_name)
self.assertEqual(get_resp.status_int, http_client.OK)
self.__do_delete(rule_id)
def test_get_one_fail(self):
- resp = self.app.get('/v1/rules/1', expect_errors=True)
+ resp = self.app.get("/v1/rules/1", expect_errors=True)
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
def test_post(self):
@@ -283,38 +324,44 @@ def test_post_duplicate(self):
self.assertEqual(post_resp.status_int, http_client.CREATED)
post_resp_2 = self.__do_post(RulesControllerTestCase.RULE_1)
self.assertEqual(post_resp_2.status_int, http_client.CONFLICT)
- self.assertEqual(post_resp_2.json['conflict-id'], org_id)
+ self.assertEqual(post_resp_2.json["conflict-id"], org_id)
self.__do_delete(org_id)
def test_post_invalid_rule_data(self):
- post_resp = self.__do_post({'name': 'rule'})
+ post_resp = self.__do_post({"name": "rule"})
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
expected_msg = "'trigger' is a required property"
- self.assertEqual(post_resp.json['faultstring'], expected_msg)
+ self.assertEqual(post_resp.json["faultstring"], expected_msg)
def test_post_trigger_parameter_schema_validation_fails(self):
post_resp = self.__do_post(RulesControllerTestCase.RULE_2)
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
if six.PY3:
- expected_msg = b'Additional properties are not allowed (\'minutex\' was unexpected)'
+ expected_msg = (
+ b"Additional properties are not allowed ('minutex' was unexpected)"
+ )
else:
- expected_msg = b'Additional properties are not allowed (u\'minutex\' was unexpected)'
+ expected_msg = (
+ b"Additional properties are not allowed (u'minutex' was unexpected)"
+ )
self.assertIn(expected_msg, post_resp.body)
- def test_post_trigger_parameter_schema_validation_fails_missing_required_param(self):
+ def test_post_trigger_parameter_schema_validation_fails_missing_required_param(
+ self,
+ ):
post_resp = self.__do_post(RulesControllerTestCase.RULE_5)
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
- expected_msg = b'\'date\' is a required property'
+ expected_msg = b"'date' is a required property"
self.assertIn(expected_msg, post_resp.body)
def test_post_invalid_crontimer_trigger_parameters(self):
post_resp = self.__do_post(RulesControllerTestCase.RULE_6)
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
- expected_msg = b'1000 is greater than the maximum of 6'
+ expected_msg = b"1000 is greater than the maximum of 6"
self.assertIn(expected_msg, post_resp.body)
post_resp = self.__do_post(RulesControllerTestCase.RULE_7)
@@ -329,7 +376,9 @@ def test_post_invalid_crontimer_trigger_parameters(self):
expected_msg = b'Invalid weekday name \\"a\\"'
self.assertIn(expected_msg, post_resp.body)
- def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled(self):
+ def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled(
+ self,
+ ):
# Invalid custom trigger parameter (invalid type) and non-system trigger parameter
# validation is enabled - trigger creation should fail
cfg.CONF.system.validate_trigger_parameters = True
@@ -338,16 +387,22 @@ def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled(
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
if six.PY3:
- expected_msg_1 = "Failed validating 'type' in schema['properties']['param1']:"
- expected_msg_2 = '12345 is not of type \'string\''
+ expected_msg_1 = (
+ "Failed validating 'type' in schema['properties']['param1']:"
+ )
+ expected_msg_2 = "12345 is not of type 'string'"
else:
- expected_msg_1 = "Failed validating u'type' in schema[u'properties'][u'param1']:"
- expected_msg_2 = '12345 is not of type u\'string\''
+ expected_msg_1 = (
+ "Failed validating u'type' in schema[u'properties'][u'param1']:"
+ )
+ expected_msg_2 = "12345 is not of type u'string'"
- self.assertIn(expected_msg_1, post_resp.json['faultstring'])
- self.assertIn(expected_msg_2, post_resp.json['faultstring'])
+ self.assertIn(expected_msg_1, post_resp.json["faultstring"])
+ self.assertIn(expected_msg_2, post_resp.json["faultstring"])
- def test_post_invalid_custom_trigger_parameter_trigger_param_validation_disabled(self):
+ def test_post_invalid_custom_trigger_parameter_trigger_param_validation_disabled(
+ self,
+ ):
# Invalid custom trigger parameter (invalid type) and non-system trigger parameter
# validation is disabled - trigger creation should succeed
cfg.CONF.system.validate_trigger_parameters = False
@@ -368,33 +423,33 @@ def test_post_invalid_custom_trigger_parameter_trigger_no_parameters_schema(self
def test_post_no_enabled_attribute_disabled_by_default(self):
post_resp = self.__do_post(RulesControllerTestCase.RULE_3)
self.assertEqual(post_resp.status_int, http_client.CREATED)
- self.assertFalse(post_resp.json['enabled'])
+ self.assertFalse(post_resp.json["enabled"])
self.__do_delete(self.__get_rule_id(post_resp))
def test_put(self):
post_resp = self.__do_post(RulesControllerTestCase.RULE_1)
update_input = post_resp.json
- update_input['enabled'] = not update_input['enabled']
+ update_input["enabled"] = not update_input["enabled"]
put_resp = self.__do_put(self.__get_rule_id(post_resp), update_input)
self.assertEqual(put_resp.status_int, http_client.OK)
self.__do_delete(self.__get_rule_id(put_resp))
def test_post_no_pack_info(self):
rule = copy.deepcopy(RulesControllerTestCase.RULE_1)
- del rule['pack']
+ del rule["pack"]
post_resp = self.__do_post(rule)
- self.assertEqual(post_resp.json['pack'], DEFAULT_PACK_NAME)
+ self.assertEqual(post_resp.json["pack"], DEFAULT_PACK_NAME)
self.assertEqual(post_resp.status_int, http_client.CREATED)
self.__do_delete(self.__get_rule_id(post_resp))
def test_put_no_pack_info(self):
post_resp = self.__do_post(RulesControllerTestCase.RULE_1)
test_rule = post_resp.json
- if 'pack' in test_rule:
- del test_rule['pack']
- self.assertNotIn('pack', test_rule)
+ if "pack" in test_rule:
+ del test_rule["pack"]
+ self.assertNotIn("pack", test_rule)
put_resp = self.__do_put(self.__get_rule_id(post_resp), test_rule)
- self.assertEqual(put_resp.json['pack'], DEFAULT_PACK_NAME)
+ self.assertEqual(put_resp.json["pack"], DEFAULT_PACK_NAME)
self.assertEqual(put_resp.status_int, http_client.OK)
self.__do_delete(self.__get_rule_id(put_resp))
@@ -417,7 +472,7 @@ def test_rule_with_tags(self):
get_resp = self.__do_get_one(rule_id)
self.assertEqual(get_resp.status_int, http_client.OK)
self.assertEqual(self.__get_rule_id(get_resp), rule_id)
- self.assertEqual(get_resp.json['tags'], RulesControllerTestCase.RULE_1['tags'])
+ self.assertEqual(get_resp.json["tags"], RulesControllerTestCase.RULE_1["tags"])
self.__do_delete(rule_id)
def test_rule_without_type(self):
@@ -426,10 +481,13 @@ def test_rule_without_type(self):
get_resp = self.__do_get_one(rule_id)
self.assertEqual(get_resp.status_int, http_client.OK)
self.assertEqual(self.__get_rule_id(get_resp), rule_id)
- assigned_rule_type = get_resp.json['type']
- self.assertTrue(assigned_rule_type, 'rule_type should be assigned')
- self.assertEqual(assigned_rule_type['ref'], RULE_TYPE_STANDARD,
- 'rule_type should be standard')
+ assigned_rule_type = get_resp.json["type"]
+ self.assertTrue(assigned_rule_type, "rule_type should be assigned")
+ self.assertEqual(
+ assigned_rule_type["ref"],
+ RULE_TYPE_STANDARD,
+ "rule_type should be standard",
+ )
self.__do_delete(rule_id)
def test_rule_with_type(self):
@@ -438,10 +496,13 @@ def test_rule_with_type(self):
get_resp = self.__do_get_one(rule_id)
self.assertEqual(get_resp.status_int, http_client.OK)
self.assertEqual(self.__get_rule_id(get_resp), rule_id)
- assigned_rule_type = get_resp.json['type']
- self.assertTrue(assigned_rule_type, 'rule_type should be assigned')
- self.assertEqual(assigned_rule_type['ref'], RULE_TYPE_BACKSTOP,
- 'rule_type should be backstop')
+ assigned_rule_type = get_resp.json["type"]
+ self.assertTrue(assigned_rule_type, "rule_type should be assigned")
+ self.assertEqual(
+ assigned_rule_type["ref"],
+ RULE_TYPE_BACKSTOP,
+ "rule_type should be backstop",
+ )
self.__do_delete(rule_id)
def test_update_rule_no_data(self):
@@ -451,7 +512,7 @@ def test_update_rule_no_data(self):
put_resp = self.__do_put(rule_1_id, {})
expected_msg = "'name' is a required property"
self.assertEqual(put_resp.status_code, http_client.BAD_REQUEST)
- self.assertEqual(put_resp.json['faultstring'], expected_msg)
+ self.assertEqual(put_resp.json["faultstring"], expected_msg)
self.__do_delete(rule_1_id)
@@ -460,16 +521,16 @@ def test_update_rule_missing_id_in_body(self):
rule_1_id = self.__get_rule_id(post_resp)
rule_without_id = copy.deepcopy(self.RULE_1)
- rule_without_id.pop('id', None)
+ rule_without_id.pop("id", None)
put_resp = self.__do_put(rule_1_id, rule_without_id)
self.assertEqual(put_resp.status_int, http_client.OK)
- self.assertEqual(put_resp.json['id'], rule_1_id)
+ self.assertEqual(put_resp.json["id"], rule_1_id)
self.__do_delete(rule_1_id)
def _insert_mock_models(self):
rule = copy.deepcopy(RulesControllerTestCase.RULE_1)
- rule['name'] += '-253'
+ rule["name"] += "-253"
post_resp = self.__do_post(rule)
rule_1_id = self.__get_rule_id(post_resp)
return [rule_1_id]
@@ -479,32 +540,32 @@ def _do_delete(self, rule_id):
@staticmethod
def __get_rule_id(resp):
- return resp.json['id']
+ return resp.json["id"]
def __do_get_one(self, rule_id):
- return self.app.get('/v1/rules/%s' % rule_id, expect_errors=True)
+ return self.app.get("/v1/rules/%s" % rule_id, expect_errors=True)
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def __do_post(self, rule):
- return self.app.post_json('/v1/rules', rule, expect_errors=True)
+ return self.app.post_json("/v1/rules", rule, expect_errors=True)
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def __do_put(self, rule_id, rule):
- return self.app.put_json('/v1/rules/%s' % rule_id, rule, expect_errors=True)
+ return self.app.put_json("/v1/rules/%s" % rule_id, rule, expect_errors=True)
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def __do_delete(self, rule_id):
- return self.app.delete('/v1/rules/%s' % rule_id)
+ return self.app.delete("/v1/rules/%s" % rule_id)
TEST_FIXTURES_2 = {
- 'runners': ['testrunner1.yaml'],
- 'actions': ['action1.yaml'],
- 'triggertypes': ['triggertype_with_parameter.yaml']
+ "runners": ["testrunner1.yaml"],
+ "actions": ["action1.yaml"],
+ "triggertypes": ["triggertype_with_parameter.yaml"],
}
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class RulesControllerTestCaseTriggerCreator(FunctionalTest):
fixtures_loader = FixturesLoader()
@@ -513,32 +574,33 @@ class RulesControllerTestCaseTriggerCreator(FunctionalTest):
def setUpClass(cls):
super(RulesControllerTestCaseTriggerCreator, cls).setUpClass()
cls.models = cls.fixtures_loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_2)
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_2
+ )
# Don't load rule into DB as that is what is being tested.
- file_name = 'rule_trigger_params.yaml'
+ file_name = "rule_trigger_params.yaml"
cls.RULE_1 = cls.fixtures_loader.load_fixtures(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict={'rules': [file_name]})['rules'][file_name]
+ fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]}
+ )["rules"][file_name]
def test_ref_count_trigger_increment(self):
post_resp = self.__do_post(self.RULE_1)
rule_1_id = self.__get_rule_id(post_resp)
self.assertEqual(post_resp.status_int, http_client.CREATED)
# ref_count is not served over API. Likely a choice that will prove unwise.
- triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']})
- self.assertEqual(len(triggers), 1, 'Exactly 1 should exist')
- self.assertEqual(triggers[0].ref_count, 1, 'ref_count should be 1')
+ triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]})
+ self.assertEqual(len(triggers), 1, "Exactly 1 should exist")
+ self.assertEqual(triggers[0].ref_count, 1, "ref_count should be 1")
# different rule same params
rule_2 = copy.copy(self.RULE_1)
- rule_2['name'] = rule_2['name'] + '-2'
+ rule_2["name"] = rule_2["name"] + "-2"
post_resp = self.__do_post(rule_2)
rule_2_id = self.__get_rule_id(post_resp)
self.assertEqual(post_resp.status_int, http_client.CREATED)
- triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']})
- self.assertEqual(len(triggers), 1, 'Exactly 1 should exist')
- self.assertEqual(triggers[0].ref_count, 2, 'ref_count should be 1')
+ triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]})
+ self.assertEqual(len(triggers), 1, "Exactly 1 should exist")
+ self.assertEqual(triggers[0].ref_count, 2, "ref_count should be 2")
self.__do_delete(rule_1_id)
self.__do_delete(rule_2_id)
@@ -549,16 +611,16 @@ def test_ref_count_trigger_decrement(self):
self.assertEqual(post_resp.status_int, http_client.CREATED)
rule_2 = copy.copy(self.RULE_1)
- rule_2['name'] = rule_2['name'] + '-2'
+ rule_2["name"] = rule_2["name"] + "-2"
post_resp = self.__do_post(rule_2)
rule_2_id = self.__get_rule_id(post_resp)
self.assertEqual(post_resp.status_int, http_client.CREATED)
# validate decrement
self.__do_delete(rule_1_id)
- triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']})
- self.assertEqual(len(triggers), 1, 'Exactly 1 should exist')
- self.assertEqual(triggers[0].ref_count, 1, 'ref_count should be 1')
+ triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]})
+ self.assertEqual(len(triggers), 1, "Exactly 1 should exist")
+ self.assertEqual(triggers[0].ref_count, 1, "ref_count should be 1")
self.__do_delete(rule_2_id)
def test_trigger_cleanup(self):
@@ -567,34 +629,34 @@ def test_trigger_cleanup(self):
self.assertEqual(post_resp.status_int, http_client.CREATED)
rule_2 = copy.copy(self.RULE_1)
- rule_2['name'] = rule_2['name'] + '-2'
+ rule_2["name"] = rule_2["name"] + "-2"
post_resp = self.__do_post(rule_2)
rule_2_id = self.__get_rule_id(post_resp)
self.assertEqual(post_resp.status_int, http_client.CREATED)
- triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']})
- self.assertEqual(len(triggers), 1, 'Exactly 1 should exist')
- self.assertEqual(triggers[0].ref_count, 2, 'ref_count should be 1')
+ triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]})
+ self.assertEqual(len(triggers), 1, "Exactly 1 should exist")
+ self.assertEqual(triggers[0].ref_count, 2, "ref_count should be 2")
self.__do_delete(rule_1_id)
self.__do_delete(rule_2_id)
# validate cleanup
- triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']})
- self.assertEqual(len(triggers), 0, 'Exactly 1 should exist')
+ triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]})
+ self.assertEqual(len(triggers), 0, "Exactly 0 should exist")
@staticmethod
def __get_rule_id(resp):
- return resp.json['id']
+ return resp.json["id"]
def __do_get_one(self, rule_id):
- return self.app.get('/v1/rules/%s' % rule_id, expect_errors=True)
+ return self.app.get("/v1/rules/%s" % rule_id, expect_errors=True)
def __do_post(self, rule):
- return self.app.post_json('/v1/rules', rule, expect_errors=True)
+ return self.app.post_json("/v1/rules", rule, expect_errors=True)
def __do_put(self, rule_id, rule):
- return self.app.put_json('/v1/rules/%s' % rule_id, rule, expect_errors=True)
+ return self.app.put_json("/v1/rules/%s" % rule_id, rule, expect_errors=True)
def __do_delete(self, rule_id):
- return self.app.delete('/v1/rules/%s' % rule_id)
+ return self.app.delete("/v1/rules/%s" % rule_id)
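As a quick aside for readers skimming this patch: the @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) decorators that black re-quoted above stub out the message-bus publish for the duration of a single test. A minimal, self-contained sketch of the same mechanism (the Publisher class here is hypothetical; st2's PoolPublisher is patched identically):

import unittest
from unittest import mock


class Publisher:
    # Hypothetical stand-in for a class whose method we want to stub.
    def publish(self, message):
        raise RuntimeError("would talk to a real message bus")


class ExampleTestCase(unittest.TestCase):
    # Replaces Publisher.publish with a MagicMock while this test runs;
    # the real method is restored automatically afterwards.
    @mock.patch.object(Publisher, "publish", mock.MagicMock())
    def test_publish_is_stubbed(self):
        Publisher().publish("hello")  # absorbed by the mock, no RuntimeError
        self.assertTrue(Publisher.publish.called)


if __name__ == "__main__":
    unittest.main()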
diff --git a/st2api/tests/unit/controllers/v1/test_ruletypes.py b/st2api/tests/unit/controllers/v1/test_ruletypes.py
index 5cba961409..87b1c4c584 100644
--- a/st2api/tests/unit/controllers/v1/test_ruletypes.py
+++ b/st2api/tests/unit/controllers/v1/test_ruletypes.py
@@ -26,20 +26,26 @@ def setUpClass(cls):
ruletypes_registrar.register_rule_types()
def test_get_one(self):
- list_resp = self.app.get('/v1/ruletypes')
+ list_resp = self.app.get("/v1/ruletypes")
self.assertEqual(list_resp.status_int, 200)
- self.assertTrue(len(list_resp.json) > 0, '/v1/ruletypes did not return correct ruletypes.')
- ruletype_id = list_resp.json[0]['id']
- get_resp = self.app.get('/v1/ruletypes/%s' % ruletype_id)
- retrieved_id = get_resp.json['id']
+ self.assertTrue(
+ len(list_resp.json) > 0, "/v1/ruletypes did not return correct ruletypes."
+ )
+ ruletype_id = list_resp.json[0]["id"]
+ get_resp = self.app.get("/v1/ruletypes/%s" % ruletype_id)
+ retrieved_id = get_resp.json["id"]
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(retrieved_id, ruletype_id, '/v1/ruletypes returned incorrect ruletype.')
+ self.assertEqual(
+ retrieved_id, ruletype_id, "/v1/ruletypes returned incorrect ruletype."
+ )
def test_get_all(self):
- resp = self.app.get('/v1/ruletypes')
+ resp = self.app.get("/v1/ruletypes")
self.assertEqual(resp.status_int, 200)
- self.assertTrue(len(resp.json) > 0, '/v1/ruletypes did not return correct ruletypes.')
+ self.assertTrue(
+ len(resp.json) > 0, "/v1/ruletypes did not return correct ruletypes."
+ )
def test_get_one_fail_doesnt_exist(self):
- resp = self.app.get('/v1/ruletypes/1', expect_errors=True)
+ resp = self.app.get("/v1/ruletypes/1", expect_errors=True)
self.assertEqual(resp.status_int, 404)
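The hunks in this file are representative of the whole reformat: black normalizes single quotes to double quotes, wraps calls that would exceed its 88-character default line length, and adds a trailing comma whenever it explodes a literal across lines. A small before/after sketch of those conventions:

# Before the reformat (single quotes, one line, no trailing comma):
#     fixtures = {'runners': ['testrunner1.yaml'], 'actions': ['action1.yaml']}
#
# After black: double quotes, and - once a magic trailing comma is present -
# the literal stays exploded one entry per line, which keeps later diffs small.
fixtures = {
    "runners": ["testrunner1.yaml"],
    "actions": ["action1.yaml"],
}
print(fixtures)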
diff --git a/st2api/tests/unit/controllers/v1/test_runnertypes.py b/st2api/tests/unit/controllers/v1/test_runnertypes.py
index edaacdf6dd..34c243c545 100644
--- a/st2api/tests/unit/controllers/v1/test_runnertypes.py
+++ b/st2api/tests/unit/controllers/v1/test_runnertypes.py
@@ -18,67 +18,76 @@
from st2tests.api import FunctionalTest
from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase
-__all__ = [
- 'RunnerTypesControllerTestCase'
-]
+__all__ = ["RunnerTypesControllerTestCase"]
-class RunnerTypesControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/runnertypes'
+class RunnerTypesControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/runnertypes"
controller_cls = RunnerTypesController
- include_attribute_field_name = 'runner_package'
- exclude_attribute_field_name = 'runner_module'
- test_exact_object_count = False # runners are registered dynamically in base test class
+ include_attribute_field_name = "runner_package"
+ exclude_attribute_field_name = "runner_module"
+ test_exact_object_count = (
+ False # runners are registered dynamically in base test class
+ )
def test_get_one(self):
- resp = self.app.get('/v1/runnertypes')
+ resp = self.app.get("/v1/runnertypes")
self.assertEqual(resp.status_int, 200)
- self.assertTrue(len(resp.json) > 0, '/v1/runnertypes did not return correct runnertypes.')
+ self.assertTrue(
+ len(resp.json) > 0, "/v1/runnertypes did not return correct runnertypes."
+ )
runnertype_id = RunnerTypesControllerTestCase.__get_runnertype_id(resp.json[0])
- resp = self.app.get('/v1/runnertypes/%s' % runnertype_id)
+ resp = self.app.get("/v1/runnertypes/%s" % runnertype_id)
retrieved_id = RunnerTypesControllerTestCase.__get_runnertype_id(resp.json)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(retrieved_id, runnertype_id,
- '/v1/runnertypes returned incorrect runnertype.')
+ self.assertEqual(
+ retrieved_id,
+ runnertype_id,
+ "/v1/runnertypes returned incorrect runnertype.",
+ )
def test_get_all(self):
- resp = self.app.get('/v1/runnertypes')
+ resp = self.app.get("/v1/runnertypes")
self.assertEqual(resp.status_int, 200)
- self.assertTrue(len(resp.json) > 0, '/v1/runnertypes did not return correct runnertypes.')
+ self.assertTrue(
+ len(resp.json) > 0, "/v1/runnertypes did not return correct runnertypes."
+ )
def test_get_one_fail_doesnt_exist(self):
- resp = self.app.get('/v1/runnertypes/1', expect_errors=True)
+ resp = self.app.get("/v1/runnertypes/1", expect_errors=True)
self.assertEqual(resp.status_int, 404)
def test_put_disable_runner(self):
- runnertype_id = 'action-chain'
- resp = self.app.get('/v1/runnertypes/%s' % runnertype_id)
- self.assertTrue(resp.json['enabled'])
+ runnertype_id = "action-chain"
+ resp = self.app.get("/v1/runnertypes/%s" % runnertype_id)
+ self.assertTrue(resp.json["enabled"])
# Disable the runner
update_input = resp.json
- update_input['enabled'] = False
- update_input['name'] = 'foobar'
+ update_input["enabled"] = False
+ update_input["name"] = "foobar"
put_resp = self.__do_put(runnertype_id, update_input)
- self.assertFalse(put_resp.json['enabled'])
+ self.assertFalse(put_resp.json["enabled"])
# Verify that the name hasn't been updated - we only allow updating the
# enabled attribute on the runner
- self.assertEqual(put_resp.json['name'], 'action-chain')
+ self.assertEqual(put_resp.json["name"], "action-chain")
# Enable the runner
update_input = resp.json
- update_input['enabled'] = True
+ update_input["enabled"] = True
put_resp = self.__do_put(runnertype_id, update_input)
- self.assertTrue(put_resp.json['enabled'])
+ self.assertTrue(put_resp.json["enabled"])
def __do_put(self, runner_type_id, runner_type):
- return self.app.put_json('/v1/runnertypes/%s' % runner_type_id, runner_type,
- expect_errors=True)
+ return self.app.put_json(
+ "/v1/runnertypes/%s" % runner_type_id, runner_type, expect_errors=True
+ )
@staticmethod
def __get_runnertype_id(resp_json):
- return resp_json['id']
+ return resp_json["id"]
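For context, the self.app.get(...) / self.app.put_json(...) calls used throughout these controller tests come from webtest's TestApp wrapper, which FunctionalTest sets up around the st2 WSGI app. A minimal sketch of the same pattern against a hypothetical one-route WSGI app, assuming only the webtest package:

import json

from webtest import TestApp


def wsgi_app(environ, start_response):
    # Hypothetical stand-in for the st2 API application.
    if environ["PATH_INFO"] == "/v1/runnertypes/action-chain":
        body = json.dumps({"id": "action-chain", "enabled": True})
        start_response("200 OK", [("Content-Type", "application/json")])
        return [body.encode("utf-8")]
    start_response("404 Not Found", [("Content-Type", "application/json")])
    return [b'{"faultstring": "not found"}']


app = TestApp(wsgi_app)

resp = app.get("/v1/runnertypes/action-chain")
assert resp.status_int == 200 and resp.json["enabled"] is True

# expect_errors=True stops webtest from raising on 4xx/5xx responses,
# so the test can assert on the error status itself.
resp = app.get("/v1/runnertypes/doesnt-exist", expect_errors=True)
assert resp.status_int == 404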
diff --git a/st2api/tests/unit/controllers/v1/test_sensortypes.py b/st2api/tests/unit/controllers/v1/test_sensortypes.py
index 8e66cdfb40..c59a1c28e2 100644
--- a/st2api/tests/unit/controllers/v1/test_sensortypes.py
+++ b/st2api/tests/unit/controllers/v1/test_sensortypes.py
@@ -25,17 +25,16 @@
http_client = six.moves.http_client
-__all__ = [
- 'SensorTypeControllerTestCase'
-]
+__all__ = ["SensorTypeControllerTestCase"]
-class SensorTypeControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/sensortypes'
+class SensorTypeControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/sensortypes"
controller_cls = SensorTypeController
- include_attribute_field_name = 'entry_point'
- exclude_attribute_field_name = 'artifact_uri'
+ include_attribute_field_name = "entry_point"
+ exclude_attribute_field_name = "artifact_uri"
test_exact_object_count = False
@classmethod
@@ -46,106 +45,108 @@ def setUpClass(cls):
sensors_registrar.register_sensors(use_pack_cache=False)
def test_get_all_and_minus_one(self):
- resp = self.app.get('/v1/sensortypes')
+ resp = self.app.get("/v1/sensortypes")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 3)
- self.assertEqual(resp.json[0]['name'], 'SampleSensor')
+ self.assertEqual(resp.json[0]["name"], "SampleSensor")
- resp = self.app.get('/v1/sensortypes/?limit=-1')
+ resp = self.app.get("/v1/sensortypes/?limit=-1")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 3)
- self.assertEqual(resp.json[0]['name'], 'SampleSensor')
+ self.assertEqual(resp.json[0]["name"], "SampleSensor")
def test_get_all_negative_limit(self):
- resp = self.app.get('/v1/sensortypes/?limit=-22', expect_errors=True)
+ resp = self.app.get("/v1/sensortypes/?limit=-22", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
def test_get_all_filters(self):
- resp = self.app.get('/v1/sensortypes')
+ resp = self.app.get("/v1/sensortypes")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 3)
# ?name filter
- resp = self.app.get('/v1/sensortypes?name=foobar')
+ resp = self.app.get("/v1/sensortypes?name=foobar")
self.assertEqual(len(resp.json), 0)
- resp = self.app.get('/v1/sensortypes?name=SampleSensor2')
+ resp = self.app.get("/v1/sensortypes?name=SampleSensor2")
self.assertEqual(len(resp.json), 1)
- self.assertEqual(resp.json[0]['name'], 'SampleSensor2')
- self.assertEqual(resp.json[0]['ref'], 'dummy_pack_1.SampleSensor2')
+ self.assertEqual(resp.json[0]["name"], "SampleSensor2")
+ self.assertEqual(resp.json[0]["ref"], "dummy_pack_1.SampleSensor2")
- resp = self.app.get('/v1/sensortypes?name=SampleSensor3')
+ resp = self.app.get("/v1/sensortypes?name=SampleSensor3")
self.assertEqual(len(resp.json), 1)
- self.assertEqual(resp.json[0]['name'], 'SampleSensor3')
+ self.assertEqual(resp.json[0]["name"], "SampleSensor3")
# ?pack filter
- resp = self.app.get('/v1/sensortypes?pack=foobar')
+ resp = self.app.get("/v1/sensortypes?pack=foobar")
self.assertEqual(len(resp.json), 0)
- resp = self.app.get('/v1/sensortypes?pack=dummy_pack_1')
+ resp = self.app.get("/v1/sensortypes?pack=dummy_pack_1")
self.assertEqual(len(resp.json), 3)
# ?enabled filter
- resp = self.app.get('/v1/sensortypes?enabled=False')
+ resp = self.app.get("/v1/sensortypes?enabled=False")
self.assertEqual(len(resp.json), 1)
- self.assertEqual(resp.json[0]['enabled'], False)
+ self.assertEqual(resp.json[0]["enabled"], False)
- resp = self.app.get('/v1/sensortypes?enabled=True')
+ resp = self.app.get("/v1/sensortypes?enabled=True")
self.assertEqual(len(resp.json), 2)
- self.assertEqual(resp.json[0]['enabled'], True)
- self.assertEqual(resp.json[1]['enabled'], True)
+ self.assertEqual(resp.json[0]["enabled"], True)
+ self.assertEqual(resp.json[1]["enabled"], True)
# ?trigger filter
- resp = self.app.get('/v1/sensortypes?trigger=dummy_pack_1.event3')
+ resp = self.app.get("/v1/sensortypes?trigger=dummy_pack_1.event3")
self.assertEqual(len(resp.json), 1)
- self.assertEqual(resp.json[0]['trigger_types'], ['dummy_pack_1.event3'])
+ self.assertEqual(resp.json[0]["trigger_types"], ["dummy_pack_1.event3"])
- resp = self.app.get('/v1/sensortypes?trigger=dummy_pack_1.event')
+ resp = self.app.get("/v1/sensortypes?trigger=dummy_pack_1.event")
self.assertEqual(len(resp.json), 2)
- self.assertEqual(resp.json[0]['trigger_types'], ['dummy_pack_1.event'])
- self.assertEqual(resp.json[1]['trigger_types'], ['dummy_pack_1.event'])
+ self.assertEqual(resp.json[0]["trigger_types"], ["dummy_pack_1.event"])
+ self.assertEqual(resp.json[1]["trigger_types"], ["dummy_pack_1.event"])
def test_get_one_success(self):
- resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor')
+ resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertEqual(resp.json['name'], 'SampleSensor')
- self.assertEqual(resp.json['ref'], 'dummy_pack_1.SampleSensor')
+ self.assertEqual(resp.json["name"], "SampleSensor")
+ self.assertEqual(resp.json["ref"], "dummy_pack_1.SampleSensor")
def test_get_one_doesnt_exist(self):
- resp = self.app.get('/v1/sensortypes/1', expect_errors=True)
+ resp = self.app.get("/v1/sensortypes/1", expect_errors=True)
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
def test_disable_and_enable_sensor(self):
# Verify initial state
- resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor')
+ resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertTrue(resp.json['enabled'])
+ self.assertTrue(resp.json["enabled"])
sensor_data = resp.json
# Disable sensor
data = copy.deepcopy(sensor_data)
- data['enabled'] = False
- put_resp = self.app.put_json('/v1/sensortypes/dummy_pack_1.SampleSensor', data)
+ data["enabled"] = False
+ put_resp = self.app.put_json("/v1/sensortypes/dummy_pack_1.SampleSensor", data)
self.assertEqual(put_resp.status_int, http_client.OK)
- self.assertEqual(put_resp.json['ref'], 'dummy_pack_1.SampleSensor')
- self.assertFalse(put_resp.json['enabled'])
+ self.assertEqual(put_resp.json["ref"], "dummy_pack_1.SampleSensor")
+ self.assertFalse(put_resp.json["enabled"])
# Verify sensor has been disabled
- resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor')
+ resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertFalse(resp.json['enabled'])
+ self.assertFalse(resp.json["enabled"])
# Enable sensor
data = copy.deepcopy(sensor_data)
- data['enabled'] = True
- put_resp = self.app.put_json('/v1/sensortypes/dummy_pack_1.SampleSensor', data)
+ data["enabled"] = True
+ put_resp = self.app.put_json("/v1/sensortypes/dummy_pack_1.SampleSensor", data)
self.assertEqual(put_resp.status_int, http_client.OK)
- self.assertTrue(put_resp.json['enabled'])
+ self.assertTrue(put_resp.json["enabled"])
# Verify sensor has been enabled
- resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor')
+ resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertTrue(resp.json['enabled'])
+ self.assertTrue(resp.json["enabled"])
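One detail worth noting in the disable/enable test above: each PUT body is built from a deep copy of the original GET response, so the second update starts from the pristine payload rather than from whatever the first PUT left behind. The idiom in isolation, with hypothetical values:

import copy

# Hypothetical resource payload, as returned by the initial GET.
sensor_data = {"ref": "dummy_pack_1.SampleSensor", "enabled": True}

data = copy.deepcopy(sensor_data)  # mutate a copy, never the original
data["enabled"] = False
# ... PUT `data` and assert on the response ...

data = copy.deepcopy(sensor_data)  # start again from the pristine payload
data["enabled"] = True
# ... PUT `data` again ...

assert sensor_data["enabled"] is True  # the original was never touched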
diff --git a/st2api/tests/unit/controllers/v1/test_service_registry.py b/st2api/tests/unit/controllers/v1/test_service_registry.py
index efeb7d432a..d195c2361e 100644
--- a/st2api/tests/unit/controllers/v1/test_service_registry.py
+++ b/st2api/tests/unit/controllers/v1/test_service_registry.py
@@ -22,9 +22,7 @@
from st2tests.api import FunctionalTest
-__all__ = [
- 'ServiceyRegistryControllerTestCase'
-]
+__all__ = ["ServiceyRegistryControllerTestCase"]
class ServiceyRegistryControllerTestCase(FunctionalTest):
@@ -41,10 +39,11 @@ def setUpClass(cls):
# NOTE: We mock the common_setup call to emulate the service being registered
# in the service registry during the bootstrap phase
- register_service_in_service_registry(service='mock_service',
- capabilities={'key1': 'value1',
- 'name': 'mock_service'},
- start_heart=True)
+ register_service_in_service_registry(
+ service="mock_service",
+ capabilities={"key1": "value1", "name": "mock_service"},
+ start_heart=True,
+ )
@classmethod
def tearDownClass(cls):
@@ -53,33 +52,40 @@ def tearDownClass(cls):
coordination.coordinator_teardown(cls.coordinator)
def test_get_groups(self):
- list_resp = self.app.get('/v1/service_registry/groups')
+ list_resp = self.app.get("/v1/service_registry/groups")
self.assertEqual(list_resp.status_int, 200)
- self.assertEqual(list_resp.json, {'groups': ['mock_service']})
+ self.assertEqual(list_resp.json, {"groups": ["mock_service"]})
def test_get_group_members(self):
proc_info = system_info.get_process_info()
member_id = get_member_id()
# 1. Group doesn't exist
- resp = self.app.get('/v1/service_registry/groups/doesnt-exist/members', expect_errors=True)
+ resp = self.app.get(
+ "/v1/service_registry/groups/doesnt-exist/members", expect_errors=True
+ )
self.assertEqual(resp.status_int, 404)
- self.assertEqual(resp.json['faultstring'], 'Group with ID "doesnt-exist" not found.')
+ self.assertEqual(
+ resp.json["faultstring"], 'Group with ID "doesnt-exist" not found.'
+ )
# 2. Group exists and has a single member
- resp = self.app.get('/v1/service_registry/groups/mock_service/members')
+ resp = self.app.get("/v1/service_registry/groups/mock_service/members")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json, {
- 'members': [
- {
- 'group_id': 'mock_service',
- 'member_id': member_id.decode('utf-8'),
- 'capabilities': {
- 'key1': 'value1',
- 'name': 'mock_service',
- 'hostname': proc_info['hostname'],
- 'pid': proc_info['pid']
+ self.assertEqual(
+ resp.json,
+ {
+ "members": [
+ {
+ "group_id": "mock_service",
+ "member_id": member_id.decode("utf-8"),
+ "capabilities": {
+ "key1": "value1",
+ "name": "mock_service",
+ "hostname": proc_info["hostname"],
+ "pid": proc_info["pid"],
+ },
}
- }
- ]
- })
+ ]
+ },
+ )
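A note on the member_id.decode("utf-8") above: the member id comes back from the coordination layer as bytes (which is why the test decodes it), while the API serializes JSON strings. In isolation:

# Hypothetical raw member id, as handed back by the coordination backend.
member_id = b"st2api-host-1-1234"

# JSON has no bytes type, so the API-facing value is the decoded string.
assert member_id.decode("utf-8") == "st2api-host-1-1234"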
diff --git a/st2api/tests/unit/controllers/v1/test_timers.py b/st2api/tests/unit/controllers/v1/test_timers.py
index 492c231b10..cb57844539 100644
--- a/st2api/tests/unit/controllers/v1/test_timers.py
+++ b/st2api/tests/unit/controllers/v1/test_timers.py
@@ -17,20 +17,29 @@
import st2common.services.triggers as trigger_service
-with mock.patch.object(trigger_service, 'create_trigger_type_db', mock.MagicMock()):
+with mock.patch.object(trigger_service, "create_trigger_type_db", mock.MagicMock()):
from st2api.controllers.v1.timers import TimersHolder
from st2common.models.system.common import ResourceReference
from st2tests.base import DbTestCase
from st2tests.fixturesloader import FixturesLoader
-from st2common.constants.triggers import INTERVAL_TIMER_TRIGGER_REF, DATE_TIMER_TRIGGER_REF
+from st2common.constants.triggers import (
+ INTERVAL_TIMER_TRIGGER_REF,
+ DATE_TIMER_TRIGGER_REF,
+)
from st2common.constants.triggers import CRON_TIMER_TRIGGER_REF
from st2tests.api import FunctionalTest
-PACK = 'timers'
+PACK = "timers"
FIXTURES = {
- 'triggers': ['cron1.yaml', 'date1.yaml', 'interval1.yaml', 'interval2.yaml', 'interval3.yaml']
+ "triggers": [
+ "cron1.yaml",
+ "date1.yaml",
+ "interval1.yaml",
+ "interval2.yaml",
+ "interval3.yaml",
+ ]
}
@@ -43,23 +52,28 @@ def setUpClass(cls):
loader = FixturesLoader()
TestTimersHolder.MODELS = loader.load_fixtures(
- fixtures_pack=PACK, fixtures_dict=FIXTURES)['triggers']
+ fixtures_pack=PACK, fixtures_dict=FIXTURES
+ )["triggers"]
loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=FIXTURES)
def test_add_trigger(self):
holder = TimersHolder()
for _, model in TestTimersHolder.MODELS.items():
holder.add_trigger(
- ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']),
- trigger=model
+ ref=ResourceReference.to_string_reference(
+ pack=model["pack"], name=model["name"]
+ ),
+ trigger=model,
)
self.assertEqual(len(holder._timers), 5)
def test_remove_trigger(self):
holder = TimersHolder()
- model = TestTimersHolder.MODELS.get('cron1.yaml', None)
+ model = TestTimersHolder.MODELS.get("cron1.yaml", None)
self.assertIsNotNone(model)
- ref = ResourceReference.to_string_reference(pack=model['pack'], name=model['name'])
+ ref = ResourceReference.to_string_reference(
+ pack=model["pack"], name=model["name"]
+ )
holder.add_trigger(ref, model)
self.assertEqual(len(holder._timers), 1)
holder.remove_trigger(ref, model)
@@ -69,8 +83,10 @@ def test_get_all(self):
holder = TimersHolder()
for _, model in TestTimersHolder.MODELS.items():
holder.add_trigger(
- ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']),
- trigger=model
+ ref=ResourceReference.to_string_reference(
+ pack=model["pack"], name=model["name"]
+ ),
+ trigger=model,
)
self.assertEqual(len(holder.get_all()), 5)
@@ -78,8 +94,10 @@ def test_get_all_filters_filter_by_type(self):
holder = TimersHolder()
for _, model in TestTimersHolder.MODELS.items():
holder.add_trigger(
- ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']),
- trigger=model
+ ref=ResourceReference.to_string_reference(
+ pack=model["pack"], name=model["name"]
+ ),
+ trigger=model,
)
self.assertEqual(len(holder.get_all(timer_type=INTERVAL_TIMER_TRIGGER_REF)), 3)
self.assertEqual(len(holder.get_all(timer_type=DATE_TIMER_TRIGGER_REF)), 1)
@@ -95,20 +113,23 @@ def setUpClass(cls):
loader = FixturesLoader()
TestTimersController.MODELS = loader.save_fixtures_to_db(
- fixtures_pack=PACK, fixtures_dict=FIXTURES)['triggers']
+ fixtures_pack=PACK, fixtures_dict=FIXTURES
+ )["triggers"]
def test_timerscontroller_get_one_with_id(self):
- model = TestTimersController.MODELS['interval1.yaml']
+ model = TestTimersController.MODELS["interval1.yaml"]
get_resp = self._do_get_one(model.id)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['parameters'], model['parameters'])
+ self.assertEqual(get_resp.json["parameters"], model["parameters"])
def test_timerscontroller_get_one_with_ref(self):
- model = TestTimersController.MODELS['interval1.yaml']
- ref = ResourceReference.to_string_reference(pack=model['pack'], name=model['name'])
+ model = TestTimersController.MODELS["interval1.yaml"]
+ ref = ResourceReference.to_string_reference(
+ pack=model["pack"], name=model["name"]
+ )
get_resp = self._do_get_one(ref)
self.assertEqual(get_resp.status_int, 200)
- self.assertEqual(get_resp.json['parameters'], model['parameters'])
+ self.assertEqual(get_resp.json["parameters"], model["parameters"])
def _do_get_one(self, timer_id, expect_errors=False):
- return self.app.get('/v1/timers/%s' % timer_id, expect_errors=expect_errors)
+ return self.app.get("/v1/timers/%s" % timer_id, expect_errors=expect_errors)
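ResourceReference.to_string_reference, used repeatedly in this file, joins a pack and a name into the canonical pack.name reference string. A hedged sketch of the behavior the tests rely on (not the real st2common implementation):

def to_string_reference(pack, name):
    # Sketch: the canonical reference is "<pack>.<name>"; both parts required.
    if not pack or not name:
        raise ValueError("both pack and name are required")
    return "%s.%s" % (pack, name)


model = {"pack": "timers", "name": "interval1"}  # hypothetical fixture values
assert to_string_reference(pack=model["pack"], name=model["name"]) == "timers.interval1"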
diff --git a/st2api/tests/unit/controllers/v1/test_traces.py b/st2api/tests/unit/controllers/v1/test_traces.py
index 0ce16a2a29..79bbdad6ae 100644
--- a/st2api/tests/unit/controllers/v1/test_traces.py
+++ b/st2api/tests/unit/controllers/v1/test_traces.py
@@ -19,23 +19,24 @@
from st2tests.api import FunctionalTest
from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase
-FIXTURES_PACK = 'traces'
+FIXTURES_PACK = "traces"
TEST_MODELS = {
- 'traces': [
- 'trace_empty.yaml',
- 'trace_one_each.yaml',
- 'trace_multiple_components.yaml'
+ "traces": [
+ "trace_empty.yaml",
+ "trace_one_each.yaml",
+ "trace_multiple_components.yaml",
]
}
-class TracesControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/traces'
+class TracesControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/traces"
controller_cls = TracesController
- include_attribute_field_name = 'trace_tag'
- exclude_attribute_field_name = 'start_timestamp'
+ include_attribute_field_name = "trace_tag"
+ exclude_attribute_field_name = "start_timestamp"
models = None
trace1 = None
@@ -45,112 +46,145 @@ class TracesControllerTestCase(FunctionalTest,
@classmethod
def setUpClass(cls):
super(TracesControllerTestCase, cls).setUpClass()
- cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
- cls.trace1 = cls.models['traces']['trace_empty.yaml']
- cls.trace2 = cls.models['traces']['trace_one_each.yaml']
- cls.trace3 = cls.models['traces']['trace_multiple_components.yaml']
+ cls.models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+ )
+ cls.trace1 = cls.models["traces"]["trace_empty.yaml"]
+ cls.trace2 = cls.models["traces"]["trace_one_each.yaml"]
+ cls.trace3 = cls.models["traces"]["trace_multiple_components.yaml"]
def test_get_all_and_minus_one(self):
- resp = self.app.get('/v1/traces')
+ resp = self.app.get("/v1/traces")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.')
+ self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.")
# Note: traces are returned sorted by start_timestamp in descending order by default
- retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json]
- self.assertEqual(retrieved_trace_tags,
- [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag],
- 'Incorrect traces retrieved.')
-
- resp = self.app.get('/v1/traces/?limit=-1')
+ retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json]
+ self.assertEqual(
+ retrieved_trace_tags,
+ [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag],
+ "Incorrect traces retrieved.",
+ )
+
+ resp = self.app.get("/v1/traces/?limit=-1")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.')
+ self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.")
# Note: traces are returned sorted by start_timestamp in descending order by default
- retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json]
- self.assertEqual(retrieved_trace_tags,
- [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag],
- 'Incorrect traces retrieved.')
+ retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json]
+ self.assertEqual(
+ retrieved_trace_tags,
+ [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag],
+ "Incorrect traces retrieved.",
+ )
def test_get_all_ascending_and_descending(self):
- resp = self.app.get('/v1/traces?sort_asc=True')
+ resp = self.app.get("/v1/traces?sort_asc=True")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.')
+ self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.")
- retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json]
- self.assertEqual(retrieved_trace_tags,
- [self.trace1.trace_tag, self.trace2.trace_tag, self.trace3.trace_tag],
- 'Incorrect traces retrieved.')
+ retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json]
+ self.assertEqual(
+ retrieved_trace_tags,
+ [self.trace1.trace_tag, self.trace2.trace_tag, self.trace3.trace_tag],
+ "Incorrect traces retrieved.",
+ )
- resp = self.app.get('/v1/traces?sort_desc=True')
+ resp = self.app.get("/v1/traces?sort_desc=True")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.')
+ self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.")
- retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json]
- self.assertEqual(retrieved_trace_tags,
- [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag],
- 'Incorrect traces retrieved.')
+ retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json]
+ self.assertEqual(
+ retrieved_trace_tags,
+ [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag],
+ "Incorrect traces retrieved.",
+ )
def test_get_all_limit(self):
- resp = self.app.get('/v1/traces?limit=1')
+ resp = self.app.get("/v1/traces?limit=1")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 1, '/v1/traces did not return all traces.')
+ self.assertEqual(len(resp.json), 1, "/v1/traces did not return all traces.")
- retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json]
- self.assertEqual(retrieved_trace_tags,
- [self.trace3.trace_tag], 'Incorrect traces retrieved.')
+ retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json]
+ self.assertEqual(
+ retrieved_trace_tags, [self.trace3.trace_tag], "Incorrect traces retrieved."
+ )
def test_get_all_limit_negative_number(self):
- resp = self.app.get('/v1/traces?limit=-22', expect_errors=True)
+ resp = self.app.get("/v1/traces?limit=-22", expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
def test_get_by_id(self):
- resp = self.app.get('/v1/traces/%s' % self.trace1.id)
+ resp = self.app.get("/v1/traces/%s" % self.trace1.id)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(resp.json['id'], str(self.trace1.id),
- 'Incorrect trace retrieved.')
+ self.assertEqual(
+ resp.json["id"], str(self.trace1.id), "Incorrect trace retrieved."
+ )
def test_query_by_trace_tag(self):
- resp = self.app.get('/v1/traces?trace_tag=test-trace-1')
+ resp = self.app.get("/v1/traces?trace_tag=test-trace-1")
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 1, '/v1/traces?trace_tag=x did not return correct trace.')
+ self.assertEqual(
+ len(resp.json), 1, "/v1/traces?trace_tag=x did not return correct trace."
+ )
- self.assertEqual(resp.json[0]['trace_tag'], self.trace1['trace_tag'],
- 'Correct trace not returned.')
+ self.assertEqual(
+ resp.json[0]["trace_tag"],
+ self.trace1["trace_tag"],
+ "Correct trace not returned.",
+ )
def test_query_by_action_execution(self):
- execution_id = self.trace3['action_executions'][0].object_id
- resp = self.app.get('/v1/traces?execution=%s' % execution_id)
+ execution_id = self.trace3["action_executions"][0].object_id
+ resp = self.app.get("/v1/traces?execution=%s" % execution_id)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 1,
- '/v1/traces?execution=x did not return correct trace.')
- self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'],
- 'Correct trace not returned.')
+ self.assertEqual(
+ len(resp.json), 1, "/v1/traces?execution=x did not return correct trace."
+ )
+ self.assertEqual(
+ resp.json[0]["trace_tag"],
+ self.trace3["trace_tag"],
+ "Correct trace not returned.",
+ )
def test_query_by_rule(self):
- rule_id = self.trace3['rules'][0].object_id
- resp = self.app.get('/v1/traces?rule=%s' % rule_id)
+ rule_id = self.trace3["rules"][0].object_id
+ resp = self.app.get("/v1/traces?rule=%s" % rule_id)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 1, '/v1/traces?rule=x did not return correct trace.')
- self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'],
- 'Correct trace not returned.')
+ self.assertEqual(
+ len(resp.json), 1, "/v1/traces?rule=x did not return correct trace."
+ )
+ self.assertEqual(
+ resp.json[0]["trace_tag"],
+ self.trace3["trace_tag"],
+ "Correct trace not returned.",
+ )
def test_query_by_trigger_instance(self):
- trigger_instance_id = self.trace3['trigger_instances'][0].object_id
- resp = self.app.get('/v1/traces?trigger_instance=%s' % trigger_instance_id)
+ trigger_instance_id = self.trace3["trigger_instances"][0].object_id
+ resp = self.app.get("/v1/traces?trigger_instance=%s" % trigger_instance_id)
self.assertEqual(resp.status_int, 200)
- self.assertEqual(len(resp.json), 1,
- '/v1/traces?trigger_instance=x did not return correct trace.')
- self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'],
- 'Correct trace not returned.')
+ self.assertEqual(
+ len(resp.json),
+ 1,
+ "/v1/traces?trigger_instance=x did not return correct trace.",
+ )
+ self.assertEqual(
+ resp.json[0]["trace_tag"],
+ self.trace3["trace_tag"],
+ "Correct trace not returned.",
+ )
def _insert_mock_models(self):
- trace_ids = [trace['id'] for trace in self.models['traces'].values()]
+ trace_ids = [trace["id"] for trace in self.models["traces"].values()]
return trace_ids
def _delete_mock_models(self, object_ids):
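Several assertions in this file depend on the API's default ordering, noted in the comments above: traces come back sorted by start_timestamp descending (newest first), and ?sort_asc=True flips that. The expected orderings in miniature, with hypothetical timestamps:

# Hypothetical traces, oldest first.
traces = [
    {"trace_tag": "trace-1", "start_timestamp": 1},
    {"trace_tag": "trace-2", "start_timestamp": 2},
    {"trace_tag": "trace-3", "start_timestamp": 3},
]

# Default ordering: descending start_timestamp, i.e. newest first.
newest_first = sorted(traces, key=lambda t: t["start_timestamp"], reverse=True)
assert [t["trace_tag"] for t in newest_first] == ["trace-3", "trace-2", "trace-1"]

# ?sort_asc=True returns oldest first instead.
oldest_first = sorted(traces, key=lambda t: t["start_timestamp"])
assert [t["trace_tag"] for t in oldest_first] == ["trace-1", "trace-2", "trace-3"]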
diff --git a/st2api/tests/unit/controllers/v1/test_triggerinstances.py b/st2api/tests/unit/controllers/v1/test_triggerinstances.py
index 0d81de723d..2a4149707c 100644
--- a/st2api/tests/unit/controllers/v1/test_triggerinstances.py
+++ b/st2api/tests/unit/controllers/v1/test_triggerinstances.py
@@ -31,13 +31,14 @@
http_client = six.moves.http_client
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
-class TriggerInstanceTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/triggerinstances'
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
+class TriggerInstanceTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/triggerinstances"
controller_cls = TriggerInstanceController
- include_attribute_field_name = 'trigger'
- exclude_attribute_field_name = 'payload'
+ include_attribute_field_name = "trigger"
+ exclude_attribute_field_name = "payload"
@classmethod
def setUpClass(cls):
@@ -47,74 +48,84 @@ def setUpClass(cls):
cls._setupTriggerInstance()
def test_get_all(self):
- resp = self.app.get('/v1/triggerinstances')
+ resp = self.app.get("/v1/triggerinstances")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertEqual(len(resp.json), self.triggerinstance_count, 'Get all failure.')
+ self.assertEqual(len(resp.json), self.triggerinstance_count, "Get all failure.")
def test_get_all_limit(self):
limit = 1
- resp = self.app.get('/v1/triggerinstances?limit=%d' % limit)
+ resp = self.app.get("/v1/triggerinstances?limit=%d" % limit)
self.assertEqual(resp.status_int, http_client.OK)
- self.assertEqual(len(resp.json), limit, 'Get all failure. Length doesn\'t match limit.')
+ self.assertEqual(
+ len(resp.json), limit, "Get all failure. Length doesn't match limit."
+ )
def test_get_all_limit_negative_number(self):
limit = -22
- resp = self.app.get('/v1/triggerinstances?limit=%d' % limit, expect_errors=True)
+ resp = self.app.get("/v1/triggerinstances?limit=%d" % limit, expect_errors=True)
self.assertEqual(resp.status_int, 400)
- self.assertEqual(resp.json['faultstring'],
- u'Limit, "-22" specified, must be a positive number.')
+ self.assertEqual(
+ resp.json["faultstring"],
+ 'Limit, "-22" specified, must be a positive number.',
+ )
def test_get_all_filter_by_trigger(self):
- trigger = 'dummy_pack_1.st2.test.trigger0'
- resp = self.app.get('/v1/triggerinstances?trigger=%s' % trigger)
+ trigger = "dummy_pack_1.st2.test.trigger0"
+ resp = self.app.get("/v1/triggerinstances?trigger=%s" % trigger)
self.assertEqual(resp.status_int, http_client.OK)
- self.assertEqual(len(resp.json), 1, 'Get all failure. Must get only one such instance.')
+ self.assertEqual(
+ len(resp.json), 1, "Get all failure. Must get only one such instance."
+ )
def test_get_all_filter_by_timestamp(self):
- resp = self.app.get('/v1/triggerinstances')
+ resp = self.app.get("/v1/triggerinstances")
self.assertEqual(resp.status_int, http_client.OK)
- timestamp_largest = resp.json[0]['occurrence_time']
- timestamp_middle = resp.json[1]['occurrence_time']
+ timestamp_largest = resp.json[0]["occurrence_time"]
+ timestamp_middle = resp.json[1]["occurrence_time"]
dt = isotime.parse(timestamp_largest)
dt = dt + datetime.timedelta(seconds=1)
timestamp_largest = isotime.format(dt, offset=False)
- resp = self.app.get('/v1/triggerinstances?timestamp_gt=%s' % timestamp_largest)
+ resp = self.app.get("/v1/triggerinstances?timestamp_gt=%s" % timestamp_largest)
# Since we sort trigger instances by time (latest first), the previous
# get should return no trigger instances.
self.assertEqual(len(resp.json), 0)
- resp = self.app.get('/v1/triggerinstances?timestamp_lt=%s' % (timestamp_middle))
+ resp = self.app.get("/v1/triggerinstances?timestamp_lt=%s" % (timestamp_middle))
self.assertEqual(len(resp.json), 1)
def test_get_all_trigger_type_ref_filtering(self):
# 1. Invalid / nonexistent trigger type ref
- resp = self.app.get('/v1/triggerinstances?trigger_type=foo.bar.invalid')
+ resp = self.app.get("/v1/triggerinstances?trigger_type=foo.bar.invalid")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 0)
# 2. Valid trigger type ref with corresponding trigger instances
- resp = self.app.get('/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype0')
+ resp = self.app.get(
+ "/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype0"
+ )
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
# 3. Valid trigger type ref with no corresponding trigger instances
- resp = self.app.get('/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype3')
+ resp = self.app.get(
+ "/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype3"
+ )
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 0)
def test_reemit_trigger_instance(self):
- resp = self.app.get('/v1/triggerinstances')
+ resp = self.app.get("/v1/triggerinstances")
self.assertEqual(resp.status_int, http_client.OK)
- instance_id = resp.json[0]['id']
- resp = self.app.post('/v1/triggerinstances/%s/re_emit' % instance_id)
+ instance_id = resp.json[0]["id"]
+ resp = self.app.post("/v1/triggerinstances/%s/re_emit" % instance_id)
self.assertEqual(resp.status_int, http_client.OK)
- resent_message = resp.json['message']
- resent_payload = resp.json['payload']
+ resent_message = resp.json["message"]
+ resent_payload = resp.json["payload"]
self.assertIn(instance_id, resent_message)
- self.assertIn('__context', resent_payload)
- self.assertEqual(resent_payload['__context']['original_id'], instance_id)
+ self.assertIn("__context", resent_payload)
+ self.assertEqual(resent_payload["__context"]["original_id"], instance_id)
def test_get_one(self):
triggerinstance_id = str(self.triggerinstance_1.id)
@@ -133,79 +144,78 @@ def test_get_one(self):
self.assertEqual(self._get_id(resp), triggerinstance_id)
def test_get_one_fail(self):
- resp = self._do_get_one('1')
+ resp = self._do_get_one("1")
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
@classmethod
def _setupTriggerTypes(cls):
TRIGGERTYPE_0 = {
- 'name': 'st2.test.triggertype0',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
- 'parameters_schema': {}
+ "name": "st2.test.triggertype0",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
+ "parameters_schema": {},
}
TRIGGERTYPE_1 = {
- 'name': 'st2.test.triggertype1',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
+ "name": "st2.test.triggertype1",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
}
TRIGGERTYPE_2 = {
- 'name': 'st2.test.triggertype2',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
- 'parameters_schema': {'param1': {'type': 'object'}}
+ "name": "st2.test.triggertype2",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
+ "parameters_schema": {"param1": {"type": "object"}},
}
TRIGGERTYPE_3 = {
- 'name': 'st2.test.triggertype3',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
- 'parameters_schema': {'param1': {'type': 'object'}}
+ "name": "st2.test.triggertype3",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
+ "parameters_schema": {"param1": {"type": "object"}},
}
- cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_0, expect_errors=False)
- cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_1, expect_errors=False)
- cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_2, expect_errors=False)
- cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_3, expect_errors=False)
+ cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_0, expect_errors=False)
+ cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_1, expect_errors=False)
+ cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_2, expect_errors=False)
+ cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_3, expect_errors=False)
@classmethod
def _setupTriggers(cls):
TRIGGER_0 = {
- 'name': 'st2.test.trigger0',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'type': 'dummy_pack_1.st2.test.triggertype0',
- 'parameters': {}
+ "name": "st2.test.trigger0",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "type": "dummy_pack_1.st2.test.triggertype0",
+ "parameters": {},
}
TRIGGER_1 = {
- 'name': 'st2.test.trigger1',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'type': 'dummy_pack_1.st2.test.triggertype1',
- 'parameters': {}
+ "name": "st2.test.trigger1",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "type": "dummy_pack_1.st2.test.triggertype1",
+ "parameters": {},
}
TRIGGER_2 = {
- 'name': 'st2.test.trigger2',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'type': 'dummy_pack_1.st2.test.triggertype2',
- 'parameters': {
- 'param1': {
- 'foo': 'bar'
- }
- }
+ "name": "st2.test.trigger2",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "type": "dummy_pack_1.st2.test.triggertype2",
+ "parameters": {"param1": {"foo": "bar"}},
}
- cls.app.post_json('/v1/triggers', TRIGGER_0, expect_errors=False)
- cls.app.post_json('/v1/triggers', TRIGGER_1, expect_errors=False)
- cls.app.post_json('/v1/triggers', TRIGGER_2, expect_errors=False)
+ cls.app.post_json("/v1/triggers", TRIGGER_0, expect_errors=False)
+ cls.app.post_json("/v1/triggers", TRIGGER_1, expect_errors=False)
+ cls.app.post_json("/v1/triggers", TRIGGER_2, expect_errors=False)
def _insert_mock_models(self):
- return [self.triggerinstance_1['id'], self.triggerinstance_2['id'],
- self.triggerinstance_3['id']]
+ return [
+ self.triggerinstance_1["id"],
+ self.triggerinstance_2["id"],
+ self.triggerinstance_3["id"],
+ ]
def _delete_mock_models(self, object_ids):
return None
@@ -214,17 +224,20 @@ def _delete_mock_models(self, object_ids):
def _setupTriggerInstance(cls):
cls.triggerinstance_count = 0
cls.triggerinstance_1 = cls._create_trigger_instance(
- trigger_ref='dummy_pack_1.st2.test.trigger0',
- payload={'tp1': 1, 'tp2': 2, 'tp3': 3},
- seconds=1)
+ trigger_ref="dummy_pack_1.st2.test.trigger0",
+ payload={"tp1": 1, "tp2": 2, "tp3": 3},
+ seconds=1,
+ )
cls.triggerinstance_2 = cls._create_trigger_instance(
- trigger_ref='dummy_pack_1.st2.test.trigger1',
- payload={'tp1': 'a', 'tp2': 'b', 'tp3': 'c'},
- seconds=2)
+ trigger_ref="dummy_pack_1.st2.test.trigger1",
+ payload={"tp1": "a", "tp2": "b", "tp3": "c"},
+ seconds=2,
+ )
cls.triggerinstance_3 = cls._create_trigger_instance(
- trigger_ref='dummy_pack_1.st2.test.trigger2',
- payload={'tp1': None, 'tp2': None, 'tp3': None},
- seconds=3)
+ trigger_ref="dummy_pack_1.st2.test.trigger2",
+ payload={"tp1": None, "tp2": None, "tp3": None},
+ seconds=3,
+ )
@classmethod
def _create_trigger_instance(cls, trigger_ref, payload, seconds):
@@ -244,7 +257,9 @@ def _create_trigger_instance(cls, trigger_ref, payload, seconds):
@staticmethod
def _get_id(resp):
- return resp.json['id']
+ return resp.json["id"]
def _do_get_one(self, triggerinstance_id):
- return self.app.get('/v1/triggerinstances/%s' % triggerinstance_id, expect_errors=True)
+ return self.app.get(
+ "/v1/triggerinstances/%s" % triggerinstance_id, expect_errors=True
+ )
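The timestamp-filter test above nudges the newest occurrence_time forward by one second before querying ?timestamp_gt=, which guarantees an empty result regardless of clock precision. The same date arithmetic with the standard library (st2common's isotime pairs an ISO-8601 parser with a formatter; plain datetime is used here for a self-contained sketch):

import datetime

# Hypothetical newest occurrence time, as an ISO-8601 string.
timestamp_largest = "2021-03-02T10:00:00.000000Z"

dt = datetime.datetime.strptime(timestamp_largest, "%Y-%m-%dT%H:%M:%S.%fZ")
dt = dt + datetime.timedelta(seconds=1)

# Everything in the database is now strictly older than this value,
# so a timestamp_gt filter on it must match nothing.
assert dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ") == "2021-03-02T10:00:01.000000Z"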
diff --git a/st2api/tests/unit/controllers/v1/test_triggers.py b/st2api/tests/unit/controllers/v1/test_triggers.py
index d3526e624a..5067c7674f 100644
--- a/st2api/tests/unit/controllers/v1/test_triggers.py
+++ b/st2api/tests/unit/controllers/v1/test_triggers.py
@@ -22,57 +22,52 @@
http_client = six.moves.http_client
TRIGGER_0 = {
- 'name': 'st2.test.trigger0',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'type': 'dummy_pack_1.st2.test.triggertype0',
- 'parameters': {}
+ "name": "st2.test.trigger0",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "type": "dummy_pack_1.st2.test.triggertype0",
+ "parameters": {},
}
TRIGGER_1 = {
- 'name': 'st2.test.trigger1',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'type': 'dummy_pack_1.st2.test.triggertype1',
- 'parameters': {}
+ "name": "st2.test.trigger1",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "type": "dummy_pack_1.st2.test.triggertype1",
+ "parameters": {},
}
TRIGGER_2 = {
- 'name': 'st2.test.trigger2',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'type': 'dummy_pack_1.st2.test.triggertype2',
- 'parameters': {
- 'param1': {
- 'foo': 'bar'
- }
- }
+ "name": "st2.test.trigger2",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "type": "dummy_pack_1.st2.test.triggertype2",
+ "parameters": {"param1": {"foo": "bar"}},
}
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class TestTriggerController(FunctionalTest):
-
@classmethod
def setUpClass(cls):
super(TestTriggerController, cls).setUpClass()
cls._setupTriggerTypes()
def test_get_all(self):
- resp = self.app.get('/v1/triggers')
+ resp = self.app.get("/v1/triggers")
self.assertEqual(resp.status_int, http_client.OK)
# TriggerType without parameters will register a trigger
# with the same name.
- self.assertEqual(len(resp.json), 2, 'Get all failure. %s' % resp.json)
+ self.assertEqual(len(resp.json), 2, "Get all failure. %s" % resp.json)
post_resp = self._do_post(TRIGGER_0)
trigger_id_0 = self._get_trigger_id(post_resp)
post_resp = self._do_post(TRIGGER_1)
trigger_id_1 = self._get_trigger_id(post_resp)
- resp = self.app.get('/v1/triggers')
+ resp = self.app.get("/v1/triggers")
self.assertEqual(resp.status_int, http_client.OK)
# TriggerType without parameters will register a trigger
# with the same name. So here we see 4 instead of 2.
- self.assertEqual(len(resp.json), 4, 'Get all failure.')
+ self.assertEqual(len(resp.json), 4, "Get all failure.")
self._do_delete(trigger_id_0)
self._do_delete(trigger_id_1)
@@ -85,7 +80,7 @@ def test_get_one(self):
self._do_delete(trigger_id)
def test_get_one_fail(self):
- resp = self._do_get_one('1')
+ resp = self._do_get_one("1")
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
def test_post(self):
@@ -106,13 +101,15 @@ def test_post_duplicate(self):
# The id is the same in both cases.
post_resp_2 = self._do_post(TRIGGER_1)
self.assertEqual(post_resp_2.status_int, http_client.CREATED)
- self.assertEqual(self._get_trigger_id(post_resp), self._get_trigger_id(post_resp_2))
+ self.assertEqual(
+ self._get_trigger_id(post_resp), self._get_trigger_id(post_resp_2)
+ )
self._do_delete(self._get_trigger_id(post_resp))
def test_put(self):
post_resp = self._do_post(TRIGGER_1)
update_input = post_resp.json
- update_input['description'] = 'updated description.'
+ update_input["description"] = "updated description."
put_resp = self._do_put(self._get_trigger_id(post_resp), update_input)
self.assertEqual(put_resp.status_int, http_client.OK)
self._do_delete(self._get_trigger_id(put_resp))
@@ -133,41 +130,43 @@ def test_delete(self):
@classmethod
def _setupTriggerTypes(cls):
TRIGGERTYPE_0 = {
- 'name': 'st2.test.triggertype0',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
- 'parameters_schema': {}
+ "name": "st2.test.triggertype0",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
+ "parameters_schema": {},
}
TRIGGERTYPE_1 = {
- 'name': 'st2.test.triggertype1',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
+ "name": "st2.test.triggertype1",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
}
TRIGGERTYPE_2 = {
- 'name': 'st2.test.triggertype2',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
- 'parameters_schema': {'param1': {'type': 'object'}}
+ "name": "st2.test.triggertype2",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
+ "parameters_schema": {"param1": {"type": "object"}},
}
- cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_0, expect_errors=False)
- cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_1, expect_errors=False)
- cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_2, expect_errors=False)
+ cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_0, expect_errors=False)
+ cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_1, expect_errors=False)
+ cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_2, expect_errors=False)
@staticmethod
def _get_trigger_id(resp):
- return resp.json['id']
+ return resp.json["id"]
def _do_get_one(self, trigger_id):
- return self.app.get('/v1/triggers/%s' % trigger_id, expect_errors=True)
+ return self.app.get("/v1/triggers/%s" % trigger_id, expect_errors=True)
def _do_post(self, trigger):
- return self.app.post_json('/v1/triggers', trigger, expect_errors=True)
+ return self.app.post_json("/v1/triggers", trigger, expect_errors=True)
def _do_put(self, trigger_id, trigger):
- return self.app.put_json('/v1/triggers/%s' % trigger_id, trigger, expect_errors=True)
+ return self.app.put_json(
+ "/v1/triggers/%s" % trigger_id, trigger, expect_errors=True
+ )
def _do_delete(self, trigger_id):
- return self.app.delete('/v1/triggers/%s' % trigger_id)
+ return self.app.delete("/v1/triggers/%s" % trigger_id)
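test_post_duplicate above encodes an idempotent-create contract: POSTing the same trigger twice returns 201 both times with the same id (the triggertypes controller in the next file answers 409 with a conflict-id instead). A sketch of the idempotent variant over a hypothetical in-memory store:

import hashlib

_store = {}  # hypothetical store: "pack.name" ref -> document id


def create_trigger(trigger):
    # Idempotent create: a duplicate POST returns 201 with the existing id.
    ref = "%s.%s" % (trigger["pack"], trigger["name"])
    if ref not in _store:
        _store[ref] = hashlib.sha1(ref.encode("utf-8")).hexdigest()[:24]
    return 201, _store[ref]


status_1, id_1 = create_trigger({"pack": "dummy_pack_1", "name": "st2.test.trigger1"})
status_2, id_2 = create_trigger({"pack": "dummy_pack_1", "name": "st2.test.trigger1"})
assert (status_1, status_2) == (201, 201)
assert id_1 == id_2  # same id in both cases, as the test asserts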
diff --git a/st2api/tests/unit/controllers/v1/test_triggertypes.py b/st2api/tests/unit/controllers/v1/test_triggertypes.py
index c7848f5c2d..414fc34360 100644
--- a/st2api/tests/unit/controllers/v1/test_triggertypes.py
+++ b/st2api/tests/unit/controllers/v1/test_triggertypes.py
@@ -23,33 +23,34 @@
http_client = six.moves.http_client
TRIGGER_0 = {
- 'name': 'st2.test.triggertype0',
- 'pack': 'dummy_pack_1',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
- 'parameters_schema': {}
+ "name": "st2.test.triggertype0",
+ "pack": "dummy_pack_1",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
+ "parameters_schema": {},
}
TRIGGER_1 = {
- 'name': 'st2.test.triggertype1',
- 'pack': 'dummy_pack_2',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
+ "name": "st2.test.triggertype1",
+ "pack": "dummy_pack_2",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
}
TRIGGER_2 = {
- 'name': 'st2.test.triggertype3',
- 'pack': 'dummy_pack_3',
- 'description': 'test trigger',
- 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None},
- 'parameters_schema': {'param1': {'type': 'object'}}
+ "name": "st2.test.triggertype3",
+ "pack": "dummy_pack_3",
+ "description": "test trigger",
+ "payload_schema": {"tp1": None, "tp2": None, "tp3": None},
+ "parameters_schema": {"param1": {"type": "object"}},
}
-class TriggerTypeControllerTestCase(FunctionalTest,
- APIControllerWithIncludeAndExcludeFilterTestCase):
- get_all_path = '/v1/triggertypes'
+class TriggerTypeControllerTestCase(
+ FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase
+):
+ get_all_path = "/v1/triggertypes"
controller_cls = TriggerTypeController
- include_attribute_field_name = 'payload_schema'
- exclude_attribute_field_name = 'parameters_schema'
+ include_attribute_field_name = "payload_schema"
+ exclude_attribute_field_name = "parameters_schema"
@classmethod
def setUpClass(cls):
@@ -71,19 +72,19 @@ def test_get_all(self):
trigger_id_0 = self.__get_trigger_id(post_resp)
post_resp = self.__do_post(TRIGGER_1)
trigger_id_1 = self.__get_trigger_id(post_resp)
- resp = self.app.get('/v1/triggertypes')
+ resp = self.app.get("/v1/triggertypes")
self.assertEqual(resp.status_int, http_client.OK)
- self.assertEqual(len(resp.json), 2, 'Get all failure.')
+ self.assertEqual(len(resp.json), 2, "Get all failure.")
# ?pack query filter
- resp = self.app.get('/v1/triggertypes?pack=doesnt-exist-invalid')
+ resp = self.app.get("/v1/triggertypes?pack=doesnt-exist-invalid")
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 0)
- resp = self.app.get('/v1/triggertypes?pack=%s' % (TRIGGER_0['pack']))
+ resp = self.app.get("/v1/triggertypes?pack=%s" % (TRIGGER_0["pack"]))
self.assertEqual(resp.status_int, http_client.OK)
self.assertEqual(len(resp.json), 1)
- self.assertEqual(resp.json[0]['pack'], TRIGGER_0['pack'])
+ self.assertEqual(resp.json[0]["pack"], TRIGGER_0["pack"])
self.__do_delete(trigger_id_0)
self.__do_delete(trigger_id_1)
@@ -97,7 +98,7 @@ def test_get_one(self):
self.__do_delete(trigger_id)
def test_get_one_fail(self):
- resp = self.__do_get_one('1')
+ resp = self.__do_get_one("1")
self.assertEqual(resp.status_int, http_client.NOT_FOUND)
def test_post(self):
@@ -116,13 +117,13 @@ def test_post_duplicate(self):
self.assertEqual(post_resp.status_int, http_client.CREATED)
post_resp_2 = self.__do_post(TRIGGER_1)
self.assertEqual(post_resp_2.status_int, http_client.CONFLICT)
- self.assertEqual(post_resp_2.json['conflict-id'], org_id)
+ self.assertEqual(post_resp_2.json["conflict-id"], org_id)
self.__do_delete(org_id)
def test_put(self):
post_resp = self.__do_post(TRIGGER_1)
update_input = post_resp.json
- update_input['description'] = 'updated description.'
+ update_input["description"] = "updated description."
put_resp = self.__do_put(self.__get_trigger_id(post_resp), update_input)
self.assertEqual(put_resp.status_int, http_client.OK)
self.__do_delete(self.__get_trigger_id(put_resp))
@@ -151,16 +152,18 @@ def _do_delete(self, trigger_id):
@staticmethod
def __get_trigger_id(resp):
- return resp.json['id']
+ return resp.json["id"]
def __do_get_one(self, trigger_id):
- return self.app.get('/v1/triggertypes/%s' % trigger_id, expect_errors=True)
+ return self.app.get("/v1/triggertypes/%s" % trigger_id, expect_errors=True)
def __do_post(self, trigger):
- return self.app.post_json('/v1/triggertypes', trigger, expect_errors=True)
+ return self.app.post_json("/v1/triggertypes", trigger, expect_errors=True)
def __do_put(self, trigger_id, trigger):
- return self.app.put_json('/v1/triggertypes/%s' % trigger_id, trigger, expect_errors=True)
+ return self.app.put_json(
+ "/v1/triggertypes/%s" % trigger_id, trigger, expect_errors=True
+ )
def __do_delete(self, trigger_id):
- return self.app.delete('/v1/triggertypes/%s' % trigger_id)
+ return self.app.delete("/v1/triggertypes/%s" % trigger_id)
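The webhooks file below also shows the inverse of the exploding seen earlier in this patch: when no magic trailing comma is present and the whole literal fits within the line limit, black collapses it onto one line. For example:

# Before: a nested dict spread over five lines.
#     WEBHOOK_1 = {
#         'action': 'closed',
#         'pull_request': {
#             'merged': True
#         }
#     }
# After: no trailing comma and it fits in 88 columns, so black collapses it.
WEBHOOK_1 = {"action": "closed", "pull_request": {"merged": True}}
print(WEBHOOK_1)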
diff --git a/st2api/tests/unit/controllers/v1/test_webhooks.py b/st2api/tests/unit/controllers/v1/test_webhooks.py
index 487830a092..e8fedc673c 100644
--- a/st2api/tests/unit/controllers/v1/test_webhooks.py
+++ b/st2api/tests/unit/controllers/v1/test_webhooks.py
@@ -21,7 +21,7 @@
import st2common.services.triggers as trigger_service
-with mock.patch.object(trigger_service, 'create_trigger_type_db', mock.MagicMock()):
+with mock.patch.object(trigger_service, "create_trigger_type_db", mock.MagicMock()):
from st2api.controllers.v1.webhooks import WebhooksController, HooksHolder
from st2common.constants.triggers import WEBHOOK_TRIGGER_TYPES
@@ -34,28 +34,20 @@
http_client = six.moves.http_client
-WEBHOOK_1 = {
- 'action': 'closed',
- 'pull_request': {
- 'merged': True
- }
-}
+WEBHOOK_1 = {"action": "closed", "pull_request": {"merged": True}}
ST2_WEBHOOK = {
- 'trigger': 'git.pr-merged',
- 'payload': {
- 'value_str': 'string!',
- 'value_int': 12345
- }
+ "trigger": "git.pr-merged",
+ "payload": {"value_str": "string!", "value_int": 12345},
}
WEBHOOK_DATA = {
- 'value_str': 'test string 1',
- 'value_int': 987654,
+ "value_str": "test string 1",
+ "value_int": 987654,
}
# 1. Trigger which references a system webhook trigger type
-DUMMY_TRIGGER_DB = TriggerDB(name='pr-merged', pack='git')
+DUMMY_TRIGGER_DB = TriggerDB(name="pr-merged", pack="git")
DUMMY_TRIGGER_DB.type = list(WEBHOOK_TRIGGER_TYPES.keys())[0]
@@ -63,34 +55,24 @@
DUMMY_TRIGGER_DICT = vars(DUMMY_TRIGGER_API)
# 2. Custom TriggerType object
-DUMMY_TRIGGER_TYPE_DB = TriggerTypeDB(name='pr-merged', pack='git')
+DUMMY_TRIGGER_TYPE_DB = TriggerTypeDB(name="pr-merged", pack="git")
DUMMY_TRIGGER_TYPE_DB.payload_schema = {
- 'type': 'object',
- 'properties': {
- 'body': {
- 'properties': {
- 'value_str': {
- 'type': 'string',
- 'required': True
- },
- 'value_int': {
- 'type': 'integer',
- 'required': True
- }
+ "type": "object",
+ "properties": {
+ "body": {
+ "properties": {
+ "value_str": {"type": "string", "required": True},
+ "value_int": {"type": "integer", "required": True},
}
}
- }
+ },
}
# 2. Custom TriggerType object
-DUMMY_TRIGGER_TYPE_DB_2 = TriggerTypeDB(name='pr-merged', pack='git')
+DUMMY_TRIGGER_TYPE_DB_2 = TriggerTypeDB(name="pr-merged", pack="git")
DUMMY_TRIGGER_TYPE_DB_2.payload_schema = {
- 'type': 'object',
- 'properties': {
- 'body': {
- 'type': 'array'
- }
- }
+ "type": "object",
+ "properties": {"body": {"type": "array"}},
}
@@ -100,190 +82,244 @@ def setUp(self):
cfg.CONF.system.validate_trigger_payload = True
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock(
- return_value=True))
- @mock.patch.object(HooksHolder, 'get_all', mock.MagicMock(
- return_value=[DUMMY_TRIGGER_DICT]))
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ HooksHolder, "get_all", mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT])
+ )
def test_get_all(self):
- get_resp = self.app.get('/v1/webhooks', expect_errors=False)
+ get_resp = self.app.get("/v1/webhooks", expect_errors=False)
self.assertEqual(get_resp.status_int, http_client.OK)
self.assertEqual(get_resp.json, [DUMMY_TRIGGER_DICT])
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock(
- return_value=True))
- @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock(
- return_value=[DUMMY_TRIGGER_DICT]))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ HooksHolder,
+ "get_triggers_for_hook",
+ mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_post(self, dispatch_mock):
- post_resp = self.__do_post('git', WEBHOOK_1, expect_errors=False)
+ post_resp = self.__do_post("git", WEBHOOK_1, expect_errors=False)
self.assertEqual(post_resp.status_int, http_client.ACCEPTED)
- self.assertTrue(dispatch_mock.call_args[1]['trace_context'].trace_tag)
-
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock(
- return_value=True))
- @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock(
- return_value=[DUMMY_TRIGGER_DICT]))
- @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock(
- return_value=DUMMY_TRIGGER_TYPE_DB))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ self.assertTrue(dispatch_mock.call_args[1]["trace_context"].trace_tag)
+
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ HooksHolder,
+ "get_triggers_for_hook",
+ mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]),
+ )
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_post_with_trace(self, dispatch_mock):
- post_resp = self.__do_post('git', WEBHOOK_1, expect_errors=False,
- headers={'St2-Trace-Tag': 'tag1'})
+ post_resp = self.__do_post(
+ "git", WEBHOOK_1, expect_errors=False, headers={"St2-Trace-Tag": "tag1"}
+ )
self.assertEqual(post_resp.status_int, http_client.ACCEPTED)
- self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1')
+ self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1")
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
def test_post_hook_not_registered(self):
- post_resp = self.__do_post('foo', WEBHOOK_1, expect_errors=True)
+ post_resp = self.__do_post("foo", WEBHOOK_1, expect_errors=True)
self.assertEqual(post_resp.status_int, http_client.NOT_FOUND)
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock(
- return_value=DUMMY_TRIGGER_TYPE_DB))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_st2_webhook_success(self, dispatch_mock):
- post_resp = self.__do_post('st2', ST2_WEBHOOK)
+ post_resp = self.__do_post("st2", ST2_WEBHOOK)
self.assertEqual(post_resp.status_int, http_client.ACCEPTED)
- self.assertTrue(dispatch_mock.call_args[1]['trace_context'].trace_tag)
+ self.assertTrue(dispatch_mock.call_args[1]["trace_context"].trace_tag)
- post_resp = self.__do_post('st2/', ST2_WEBHOOK)
+ post_resp = self.__do_post("st2/", ST2_WEBHOOK)
self.assertEqual(post_resp.status_int, http_client.ACCEPTED)
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock(
- return_value=DUMMY_TRIGGER_TYPE_DB))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_st2_webhook_failure_payload_validation_failed(self, dispatch_mock):
- data = {
- 'trigger': 'git.pr-merged',
- 'payload': 'invalid'
- }
- post_resp = self.__do_post('st2', data, expect_errors=True)
+ data = {"trigger": "git.pr-merged", "payload": "invalid"}
+ post_resp = self.__do_post("st2", data, expect_errors=True)
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
- expected_msg = 'Trigger payload validation failed'
- self.assertIn(expected_msg, post_resp.json['faultstring'])
+ expected_msg = "Trigger payload validation failed"
+ self.assertIn(expected_msg, post_resp.json["faultstring"])
expected_msg = "'invalid' is not of type 'object'"
- self.assertIn(expected_msg, post_resp.json['faultstring'])
-
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock(
- return_value=DUMMY_TRIGGER_TYPE_DB))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ self.assertIn(expected_msg, post_resp.json["faultstring"])
+
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_st2_webhook_with_trace(self, dispatch_mock):
- post_resp = self.__do_post('st2', ST2_WEBHOOK, headers={'St2-Trace-Tag': 'tag1'})
+ post_resp = self.__do_post(
+ "st2", ST2_WEBHOOK, headers={"St2-Trace-Tag": "tag1"}
+ )
self.assertEqual(post_resp.status_int, http_client.ACCEPTED)
- self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1')
+ self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1")
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
def test_st2_webhook_body_missing_trigger(self):
- post_resp = self.__do_post('st2', {'payload': {}}, expect_errors=True)
- self.assertIn('Trigger not specified.', post_resp)
+ post_resp = self.__do_post("st2", {"payload": {}}, expect_errors=True)
+ self.assertIn("Trigger not specified.", post_resp)
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock(
- return_value=True))
- @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock(
- return_value=[DUMMY_TRIGGER_DICT]))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ HooksHolder,
+ "get_triggers_for_hook",
+ mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_json_request_body(self, dispatch_mock):
# 1. Send JSON using application/json content type
data = WEBHOOK_1
- post_resp = self.__do_post('git', data,
- headers={'St2-Trace-Tag': 'tag1'})
+ post_resp = self.__do_post("git", data, headers={"St2-Trace-Tag": "tag1"})
self.assertEqual(post_resp.status_int, http_client.ACCEPTED)
- self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'],
- 'application/json')
- self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data)
- self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1')
+ self.assertEqual(
+ dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"],
+ "application/json",
+ )
+ self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data)
+ self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1")
# 2. Send JSON using application/json + charset content type
data = WEBHOOK_1
- headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'application/json; charset=utf-8'}
- post_resp = self.__do_post('git', data,
- headers=headers)
+ headers = {
+ "St2-Trace-Tag": "tag1",
+ "Content-Type": "application/json; charset=utf-8",
+ }
+ post_resp = self.__do_post("git", data, headers=headers)
self.assertEqual(post_resp.status_int, http_client.ACCEPTED)
- self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'],
- 'application/json; charset=utf-8')
- self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data)
- self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1')
+ self.assertEqual(
+ dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"],
+ "application/json; charset=utf-8",
+ )
+ self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data)
+ self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1")
# 3. JSON content type, invalid JSON body
- data = 'invalid'
- headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'application/json'}
- post_resp = self.app.post('/v1/webhooks/git', data, headers=headers,
- expect_errors=True)
+ data = "invalid"
+ headers = {"St2-Trace-Tag": "tag1", "Content-Type": "application/json"}
+ post_resp = self.app.post(
+ "/v1/webhooks/git", data, headers=headers, expect_errors=True
+ )
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
- self.assertIn('Failed to parse request body', post_resp)
-
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock(
- return_value=True))
- @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock(
- return_value=[DUMMY_TRIGGER_DICT]))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ self.assertIn("Failed to parse request body", post_resp)
+
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ HooksHolder,
+ "get_triggers_for_hook",
+ mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_form_encoded_request_body(self, dispatch_mock):
# Send request body as form urlencoded data
if six.PY3:
- data = {b'form': [b'test']}
+ data = {b"form": [b"test"]}
else:
- data = {'form': ['test']}
+ data = {"form": ["test"]}
headers = {
- 'Content-Type': 'application/x-www-form-urlencoded',
- 'St2-Trace-Tag': 'tag1'
+ "Content-Type": "application/x-www-form-urlencoded",
+ "St2-Trace-Tag": "tag1",
}
- self.app.post('/v1/webhooks/git', data, headers=headers)
- self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'],
- 'application/x-www-form-urlencoded')
- self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data)
- self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1')
+ self.app.post("/v1/webhooks/git", data, headers=headers)
+ self.assertEqual(
+ dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"],
+ "application/x-www-form-urlencoded",
+ )
+ self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data)
+ self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1")
def test_unsupported_content_type(self):
# Invalid / unsupported content type - should throw
data = WEBHOOK_1
- headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'foo/invalid'}
- post_resp = self.app.post('/v1/webhooks/git', json.dumps(data), headers=headers,
- expect_errors=True)
+ headers = {"St2-Trace-Tag": "tag1", "Content-Type": "foo/invalid"}
+ post_resp = self.app.post(
+ "/v1/webhooks/git", json.dumps(data), headers=headers, expect_errors=True
+ )
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
- self.assertIn('Failed to parse request body', post_resp)
- self.assertIn('Unsupported Content-Type', post_resp)
-
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock(
- return_value=True))
- @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock(
- return_value=[DUMMY_TRIGGER_DICT]))
- @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock(
- return_value=DUMMY_TRIGGER_TYPE_DB_2))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ self.assertIn("Failed to parse request body", post_resp)
+ self.assertIn("Unsupported Content-Type", post_resp)
+
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ HooksHolder,
+ "get_triggers_for_hook",
+ mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]),
+ )
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB_2),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_custom_webhook_array_input_type(self, _):
- post_resp = self.__do_post('sample', [{'foo': 'bar'}])
+ post_resp = self.__do_post("sample", [{"foo": "bar"}])
self.assertEqual(post_resp.status_int, http_client.ACCEPTED)
- self.assertEqual(post_resp.json, [{'foo': 'bar'}])
+ self.assertEqual(post_resp.json, [{"foo": "bar"}])
def test_st2_webhook_array_webhook_array_input_type_not_valid(self):
- post_resp = self.__do_post('st2', [{'foo': 'bar'}], expect_errors=True)
+ post_resp = self.__do_post("st2", [{"foo": "bar"}], expect_errors=True)
self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST)
- self.assertEqual(post_resp.json['faultstring'],
- 'Webhook body needs to be an object, got: array')
+ self.assertEqual(
+ post_resp.json["faultstring"],
+ "Webhook body needs to be an object, got: array",
+ )
def test_leading_trailing_slashes(self):
        # Ideally the test should set up fixtures in the DB. However, the triggerwatcher
@@ -296,52 +332,65 @@ def test_leading_trailing_slashes(self):
# require hacking into the test app and force dependency on pecan internals.
# TLDR; sorry for the ghetto test. Not sure how else to test this as a unit test.
def get_webhook_trigger(name, url):
- trigger = TriggerDB(name=name, pack='test')
+ trigger = TriggerDB(name=name, pack="test")
trigger.type = list(WEBHOOK_TRIGGER_TYPES.keys())[0]
- trigger.parameters = {'url': url}
+ trigger.parameters = {"url": url}
return trigger
test_triggers = [
- get_webhook_trigger('no_slash', 'no_slash'),
- get_webhook_trigger('with_leading_slash', '/with_leading_slash'),
- get_webhook_trigger('with_trailing_slash', '/with_trailing_slash/'),
- get_webhook_trigger('with_leading_trailing_slash', '/with_leading_trailing_slash/'),
- get_webhook_trigger('with_mixed_slash', '/with/mixed/slash/')
+ get_webhook_trigger("no_slash", "no_slash"),
+ get_webhook_trigger("with_leading_slash", "/with_leading_slash"),
+ get_webhook_trigger("with_trailing_slash", "/with_trailing_slash/"),
+ get_webhook_trigger(
+ "with_leading_trailing_slash", "/with_leading_trailing_slash/"
+ ),
+ get_webhook_trigger("with_mixed_slash", "/with/mixed/slash/"),
]
controller = WebhooksController()
for trigger in test_triggers:
controller.add_trigger(trigger)
- self.assertTrue(controller._is_valid_hook('no_slash'))
- self.assertFalse(controller._is_valid_hook('/no_slash'))
- self.assertTrue(controller._is_valid_hook('with_leading_slash'))
- self.assertTrue(controller._is_valid_hook('with_trailing_slash'))
- self.assertTrue(controller._is_valid_hook('with_leading_trailing_slash'))
- self.assertTrue(controller._is_valid_hook('with/mixed/slash'))
-
- @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock(
- return_value=True))
- @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock(
- return_value=True))
- @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock(
- return_value=[DUMMY_TRIGGER_DICT]))
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ self.assertTrue(controller._is_valid_hook("no_slash"))
+ self.assertFalse(controller._is_valid_hook("/no_slash"))
+ self.assertTrue(controller._is_valid_hook("with_leading_slash"))
+ self.assertTrue(controller._is_valid_hook("with_trailing_slash"))
+ self.assertTrue(controller._is_valid_hook("with_leading_trailing_slash"))
+ self.assertTrue(controller._is_valid_hook("with/mixed/slash"))
+
+ @mock.patch.object(
+ TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True)
+ )
+ @mock.patch.object(
+ HooksHolder,
+ "get_triggers_for_hook",
+ mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]),
+ )
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_authentication_headers_should_be_removed(self, dispatch_mock):
headers = {
- 'Content-Type': 'application/x-www-form-urlencoded',
- 'St2-Api-Key': 'foobar',
- 'X-Auth-Token': 'deadbeaf',
- 'Cookie': 'foo=bar'
+ "Content-Type": "application/x-www-form-urlencoded",
+ "St2-Api-Key": "foobar",
+ "X-Auth-Token": "deadbeaf",
+ "Cookie": "foo=bar",
}
- self.app.post('/v1/webhooks/git', WEBHOOK_1, headers=headers)
- self.assertNotIn('St2-Api-Key', dispatch_mock.call_args[1]['payload']['headers'])
- self.assertNotIn('X-Auth-Token', dispatch_mock.call_args[1]['payload']['headers'])
- self.assertNotIn('Cookie', dispatch_mock.call_args[1]['payload']['headers'])
+ self.app.post("/v1/webhooks/git", WEBHOOK_1, headers=headers)
+ self.assertNotIn(
+ "St2-Api-Key", dispatch_mock.call_args[1]["payload"]["headers"]
+ )
+ self.assertNotIn(
+ "X-Auth-Token", dispatch_mock.call_args[1]["payload"]["headers"]
+ )
+ self.assertNotIn("Cookie", dispatch_mock.call_args[1]["payload"]["headers"])
def __do_post(self, hook, webhook, expect_errors=False, headers=None):
- return self.app.post_json('/v1/webhooks/' + hook,
- params=webhook,
- expect_errors=expect_errors,
- headers=headers)
+ return self.app.post_json(
+ "/v1/webhooks/" + hook,
+ params=webhook,
+ expect_errors=expect_errors,
+ headers=headers,
+ )
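
Note: most of the churn in this file is black reflowing stacked mock.patch.object decorators; the behavior is unchanged. For readers less familiar with the pattern itself, a self-contained sketch using only the standard library (the Service class here is hypothetical):

    import unittest
    from unittest import mock


    class Service:
        def fetch(self):  # pretend this hits the network
            raise RuntimeError("should be patched out in tests")


    class ServiceTest(unittest.TestCase):
        # The decorator swaps the attribute for the given MagicMock for the
        # duration of the test, exactly like the stacked patches above.
        @mock.patch.object(Service, "fetch", mock.MagicMock(return_value=[{"id": 1}]))
        def test_fetch(self):
            self.assertEqual(Service().fetch(), [{"id": 1}])


    if __name__ == "__main__":
        unittest.main()
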
diff --git a/st2api/tests/unit/controllers/v1/test_workflow_inspection.py b/st2api/tests/unit/controllers/v1/test_workflow_inspection.py
index 3b45421d79..91e251fe9d 100644
--- a/st2api/tests/unit/controllers/v1/test_workflow_inspection.py
+++ b/st2api/tests/unit/controllers/v1/test_workflow_inspection.py
@@ -22,13 +22,17 @@
from st2tests.api import FunctionalTest
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
-PACKS = [TEST_PACK_PATH, st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core']
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
+PACKS = [
+ TEST_PACK_PATH,
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
+]
class WorkflowInspectionControllerTest(FunctionalTest, st2tests.WorkflowTestCase):
-
@classmethod
def setUpClass(cls):
super(WorkflowInspectionControllerTest, cls).setUpClass()
@@ -39,8 +43,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -48,14 +51,14 @@ def setUpClass(cls):
def _do_post(self, wf_def, expect_errors=False):
return self.app.post(
- '/v1/workflows/inspect',
+ "/v1/workflows/inspect",
wf_def,
expect_errors=expect_errors,
- content_type='text/plain'
+ content_type="text/plain",
)
def test_inspection(self):
- wf_file = 'sequential.yaml'
+ wf_file = "sequential.yaml"
wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file)
wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta)
@@ -65,48 +68,48 @@ def test_inspection(self):
self.assertListEqual(response.json, expected_errors)
def test_inspection_return_errors(self):
- wf_file = 'fail-inspection.yaml'
+ wf_file = "fail-inspection.yaml"
wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file)
wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta)
expected_errors = [
{
- 'type': 'content',
- 'message': 'The action "std.noop" is not registered in the database.',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action',
- 'spec_path': 'tasks.task3.action'
+ "type": "content",
+ "message": 'The action "std.noop" is not registered in the database.',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action",
+ "spec_path": "tasks.task3.action",
},
{
- 'type': 'context',
- 'language': 'yaql',
- 'expression': '<% ctx().foobar %>',
- 'message': 'Variable "foobar" is referenced before assignment.',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input',
- 'spec_path': 'tasks.task1.input',
+ "type": "context",
+ "language": "yaql",
+ "expression": "<% ctx().foobar %>",
+ "message": 'Variable "foobar" is referenced before assignment.',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input",
+ "spec_path": "tasks.task1.input",
},
{
- 'type': 'expression',
- 'language': 'yaql',
- 'expression': '<% <% succeeded() %>',
- 'message': (
- 'Parse error: unexpected \'<\' at '
- 'position 0 of expression \'<% succeeded()\''
+ "type": "expression",
+ "language": "yaql",
+ "expression": "<% <% succeeded() %>",
+ "message": (
+ "Parse error: unexpected '<' at "
+ "position 0 of expression '<% succeeded()'"
),
- 'schema_path': (
- r'properties.tasks.patternProperties.^\w+$.'
- 'properties.next.items.properties.when'
+ "schema_path": (
+ r"properties.tasks.patternProperties.^\w+$."
+ "properties.next.items.properties.when"
),
- 'spec_path': 'tasks.task2.next[0].when'
+ "spec_path": "tasks.task2.next[0].when",
},
{
- 'type': 'syntax',
- 'message': (
- '[{\'cmd\': \'echo <% ctx().macro %>\'}] is '
- 'not valid under any of the given schemas'
+ "type": "syntax",
+ "message": (
+ "[{'cmd': 'echo <% ctx().macro %>'}] is "
+ "not valid under any of the given schemas"
),
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf',
- 'spec_path': 'tasks.task2.input'
- }
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf",
+ "spec_path": "tasks.task2.input",
+ },
]
response = self._do_post(wf_def, expect_errors=False)
diff --git a/st2api/tests/unit/test_validation_utils.py b/st2api/tests/unit/test_validation_utils.py
index eaf1cd75a5..bad17b22a5 100644
--- a/st2api/tests/unit/test_validation_utils.py
+++ b/st2api/tests/unit/test_validation_utils.py
@@ -19,9 +19,7 @@
from st2api.validation import validate_rbac_is_correctly_configured
from st2tests import config as tests_config
-__all__ = [
- 'ValidationUtilsTestCase'
-]
+__all__ = ["ValidationUtilsTestCase"]
class ValidationUtilsTestCase(unittest2.TestCase):
@@ -34,26 +32,34 @@ def test_validate_rbac_is_correctly_configured_succcess(self):
self.assertTrue(result)
def test_validate_rbac_is_correctly_configured_auth_not_enabled(self):
- cfg.CONF.set_override(group='rbac', name='enable', override=True)
- cfg.CONF.set_override(group='auth', name='enable', override=False)
+ cfg.CONF.set_override(group="rbac", name="enable", override=True)
+ cfg.CONF.set_override(group="auth", name="enable", override=False)
- expected_msg = ('Authentication is not enabled. RBAC only works when authentication is '
- 'enabled. You can either enable authentication or disable RBAC.')
- self.assertRaisesRegexp(ValueError, expected_msg,
- validate_rbac_is_correctly_configured)
+ expected_msg = (
+ "Authentication is not enabled. RBAC only works when authentication is "
+ "enabled. You can either enable authentication or disable RBAC."
+ )
+ self.assertRaisesRegexp(
+ ValueError, expected_msg, validate_rbac_is_correctly_configured
+ )
def test_validate_rbac_is_correctly_configured_non_default_backend_set(self):
- cfg.CONF.set_override(group='rbac', name='enable', override=True)
- cfg.CONF.set_override(group='rbac', name='backend', override='invalid')
- cfg.CONF.set_override(group='auth', name='enable', override=True)
-
- expected_msg = ('You have enabled RBAC, but RBAC backend is not set to "default".')
- self.assertRaisesRegexp(ValueError, expected_msg,
- validate_rbac_is_correctly_configured)
-
- def test_validate_rbac_is_correctly_configured_default_backend_available_success(self):
- cfg.CONF.set_override(group='rbac', name='enable', override=True)
- cfg.CONF.set_override(group='rbac', name='backend', override='default')
- cfg.CONF.set_override(group='auth', name='enable', override=True)
+ cfg.CONF.set_override(group="rbac", name="enable", override=True)
+ cfg.CONF.set_override(group="rbac", name="backend", override="invalid")
+ cfg.CONF.set_override(group="auth", name="enable", override=True)
+
+ expected_msg = (
+ 'You have enabled RBAC, but RBAC backend is not set to "default".'
+ )
+ self.assertRaisesRegexp(
+ ValueError, expected_msg, validate_rbac_is_correctly_configured
+ )
+
+ def test_validate_rbac_is_correctly_configured_default_backend_available_success(
+ self,
+ ):
+ cfg.CONF.set_override(group="rbac", name="enable", override=True)
+ cfg.CONF.set_override(group="rbac", name="backend", override="default")
+ cfg.CONF.set_override(group="auth", name="enable", override=True)
result = validate_rbac_is_correctly_configured()
self.assertTrue(result)
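
Note: these tests drive configuration through oslo.config overrides. A minimal standalone sketch of that mechanism, assuming oslo.config is installed (the rbac/enable names mirror the tests above):

    from oslo_config import cfg

    CONF = cfg.CONF
    CONF.register_opts([cfg.BoolOpt("enable", default=False)], group="rbac")

    # set_override takes precedence over defaults and config files, which is
    # what lets each test case force a specific configuration.
    CONF.set_override(group="rbac", name="enable", override=True)
    assert CONF.rbac.enable is True

    # clear_override restores the default between tests.
    CONF.clear_override(group="rbac", name="enable")
    assert CONF.rbac.enable is False
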
diff --git a/st2auth/dist_utils.py b/st2auth/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2auth/dist_utils.py
+++ b/st2auth/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
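
Note: get_version_string is nothing more than a regex over a package's __init__.py. A usage sketch, assuming it is run from the st2auth component directory where dist_utils.py is importable as a top-level module (as setup.py below does):

    from dist_utils import get_version_string

    # Scans the file for a line of the form __version__ = "..." and returns
    # the captured value; with the st2auth/__init__.py below this is "3.5dev".
    print(get_version_string("st2auth/__init__.py"))
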
diff --git a/st2auth/setup.py b/st2auth/setup.py
index f77ee72f03..c6e266472b 100644
--- a/st2auth/setup.py
+++ b/st2auth/setup.py
@@ -22,9 +22,9 @@
from dist_utils import apply_vagrant_workaround
from st2auth import __version__
-ST2_COMPONENT = 'st2auth'
+ST2_COMPONENT = "st2auth"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
@@ -33,23 +33,21 @@
setup(
name=ST2_COMPONENT,
version=__version__,
- description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="{} StackStorm event-driven automation platform component".format(
+ ST2_COMPONENT
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- scripts=[
- 'bin/st2auth'
- ],
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ scripts=["bin/st2auth"],
entry_points={
- 'st2auth.sso.backends': [
- 'noop = st2auth.sso.noop:NoOpSingleSignOnBackend'
- ]
- }
+ "st2auth.sso.backends": ["noop = st2auth.sso.noop:NoOpSingleSignOnBackend"]
+ },
)
diff --git a/st2auth/st2auth/__init__.py b/st2auth/st2auth/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/st2auth/st2auth/__init__.py
+++ b/st2auth/st2auth/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/st2auth/st2auth/app.py b/st2auth/st2auth/app.py
index 3398104c6c..b9b8f7d595 100644
--- a/st2auth/st2auth/app.py
+++ b/st2auth/st2auth/app.py
@@ -36,34 +36,38 @@
def setup_app(config=None):
config = config or {}
- LOG.info('Creating st2auth: %s as OpenAPI app.', VERSION_STRING)
+ LOG.info("Creating st2auth: %s as OpenAPI app.", VERSION_STRING)
- is_gunicorn = config.get('is_gunicorn', False)
+ is_gunicorn = config.get("is_gunicorn", False)
if is_gunicorn:
# NOTE: We only want to perform this logic in the WSGI worker
st2auth_config.register_opts()
capabilities = {
- 'name': 'auth',
- 'listen_host': cfg.CONF.auth.host,
- 'listen_port': cfg.CONF.auth.port,
- 'listen_ssl': cfg.CONF.auth.use_ssl,
- 'type': 'active'
+ "name": "auth",
+ "listen_host": cfg.CONF.auth.host,
+ "listen_port": cfg.CONF.auth.port,
+ "listen_ssl": cfg.CONF.auth.use_ssl,
+ "type": "active",
}
# This should be called in gunicorn case because we only want
        # workers to connect to db, rabbitmq, etc. In standalone HTTP
# server case, this setup would have already occurred.
- common_setup(service='auth', config=st2auth_config, setup_db=True,
- register_mq_exchanges=False,
- register_signal_handlers=True,
- register_internal_trigger_types=False,
- run_migrations=False,
- service_registry=True,
- capabilities=capabilities,
- config_args=config.get('config_args', None))
+ common_setup(
+ service="auth",
+ config=st2auth_config,
+ setup_db=True,
+ register_mq_exchanges=False,
+ register_signal_handlers=True,
+ register_internal_trigger_types=False,
+ run_migrations=False,
+ service_registry=True,
+ capabilities=capabilities,
+ config_args=config.get("config_args", None),
+ )
# pysaml2 uses subprocess communicate which calls communicate_with_poll
- if cfg.CONF.auth.sso and cfg.CONF.auth.sso_backend == 'saml2':
+ if cfg.CONF.auth.sso and cfg.CONF.auth.sso_backend == "saml2":
use_select_poll_workaround(nose_only=False)
# Additional pre-run time checks
@@ -71,10 +75,8 @@ def setup_app(config=None):
router = Router(debug=cfg.CONF.auth.debug, is_gunicorn=is_gunicorn)
- spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2')
- transforms = {
- '^/auth/v1/': ['/', '/v1/']
- }
+ spec = spec_loader.load_spec("st2common", "openapi.yaml.j2")
+ transforms = {"^/auth/v1/": ["/", "/v1/"]}
router.add_spec(spec, transforms=transforms)
app = router.as_wsgi
@@ -83,8 +85,8 @@ def setup_app(config=None):
app = ErrorHandlingMiddleware(app)
app = CorsMiddleware(app)
app = LoggingMiddleware(app, router)
- app = ResponseInstrumentationMiddleware(app, router, service_name='auth')
+ app = ResponseInstrumentationMiddleware(app, router, service_name="auth")
app = RequestIDMiddleware(app)
- app = RequestInstrumentationMiddleware(app, router, service_name='auth')
+ app = RequestInstrumentationMiddleware(app, router, service_name="auth")
return app
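
Note the wrapping order in setup_app: the router is wrapped innermost-first, so the last wrapper applied (RequestInstrumentationMiddleware) is the first to see a request. A generic sketch of that WSGI onion (the Tag middleware is a stand-in, not an st2 class):

    def app(environ, start_response):
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"hello\n"]


    class Tag:
        """Toy middleware that records the order in which layers run."""

        def __init__(self, app, name):
            self.app, self.name = app, name

        def __call__(self, environ, start_response):
            environ.setdefault("layers", []).append(self.name)
            return self.app(environ, start_response)


    # Applied like setup_app: inner wrapper first, outer wrapper last.
    wrapped = Tag(Tag(app, "inner"), "outer")
    environ = {}
    wrapped(environ, lambda status, headers: None)
    print(environ["layers"])  # ['outer', 'inner'] -- outermost runs first
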
diff --git a/st2auth/st2auth/backends/__init__.py b/st2auth/st2auth/backends/__init__.py
index 64d3275af5..a626f0d082 100644
--- a/st2auth/st2auth/backends/__init__.py
+++ b/st2auth/st2auth/backends/__init__.py
@@ -22,14 +22,11 @@
from st2common import log as logging
from st2common.util import driver_loader
-__all__ = [
- 'get_available_backends',
- 'get_backend_instance'
-]
+__all__ = ["get_available_backends", "get_backend_instance"]
LOG = logging.getLogger(__name__)
-BACKENDS_NAMESPACE = 'st2auth.backends.backend'
+BACKENDS_NAMESPACE = "st2auth.backends.backend"
def get_available_backends():
@@ -43,8 +40,10 @@ def get_backend_instance(name):
try:
kwargs = json.loads(backend_kwargs)
except ValueError as e:
- raise ValueError('Failed to JSON parse backend settings for backend "%s": %s' %
- (name, six.text_type(e)))
+ raise ValueError(
+ 'Failed to JSON parse backend settings for backend "%s": %s'
+ % (name, six.text_type(e))
+ )
else:
kwargs = {}
@@ -55,9 +54,11 @@ def get_backend_instance(name):
except Exception as e:
tb_msg = traceback.format_exc()
class_name = cls.__name__
- msg = ('Failed to instantiate auth backend "%s" (class %s) with backend settings '
- '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e)))
- msg += '\n\n' + tb_msg
+ msg = (
+ 'Failed to instantiate auth backend "%s" (class %s) with backend settings '
+ '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e))
+ )
+ msg += "\n\n" + tb_msg
exc_cls = type(e)
raise exc_cls(msg)
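
Note: get_backend_instance resolves the backend class through the st2auth.backends.backend entry-point namespace defined above. st2common.util.driver_loader is internal to st2, but the underlying mechanism is standard plugin loading; a stevedore-based sketch (the backend name and kwargs are hypothetical):

    from stevedore.driver import DriverManager

    # Looks up the class registered under the entry-point name and, with
    # invoke_on_load, instantiates it with the decoded backend_kwargs.
    manager = DriverManager(
        namespace="st2auth.backends.backend",
        name="flat_file",  # hypothetical backend name for this sketch
        invoke_on_load=True,
        invoke_kwds={"file_path": "/etc/st2/htpasswd"},
    )
    # The resulting object satisfies the BaseAuthenticationBackend
    # interface shown in backends/base.py below.
    backend = manager.driver
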
diff --git a/st2auth/st2auth/backends/base.py b/st2auth/st2auth/backends/base.py
index 0246729c1a..4d32e51860 100644
--- a/st2auth/st2auth/backends/base.py
+++ b/st2auth/st2auth/backends/base.py
@@ -19,9 +19,7 @@
from st2auth.backends.constants import AuthBackendCapability
-__all__ = [
- 'BaseAuthenticationBackend'
-]
+__all__ = ["BaseAuthenticationBackend"]
@six.add_metaclass(abc.ABCMeta)
@@ -31,9 +29,7 @@ class BaseAuthenticationBackend(object):
"""
# Capabilities offered by the auth backend
- CAPABILITIES = (
- AuthBackendCapability.CAN_AUTHENTICATE_USER
- )
+ CAPABILITIES = AuthBackendCapability.CAN_AUTHENTICATE_USER
@abc.abstractmethod
def authenticate(self, username, password):
@@ -47,7 +43,7 @@ def get_user(self, username):
:rtype: ``dict``
"""
- raise NotImplementedError('get_user() not implemented for this backend')
+ raise NotImplementedError("get_user() not implemented for this backend")
def get_user_groups(self, username):
"""
@@ -57,4 +53,4 @@ def get_user_groups(self, username):
:rtype: ``list`` of ``str``
"""
- raise NotImplementedError('get_groups() not implemented for this backend')
+ raise NotImplementedError("get_groups() not implemented for this backend")
diff --git a/st2auth/st2auth/backends/constants.py b/st2auth/st2auth/backends/constants.py
index 6cb990c64d..b50625e745 100644
--- a/st2auth/st2auth/backends/constants.py
+++ b/st2auth/st2auth/backends/constants.py
@@ -19,17 +19,15 @@
from st2common.util.enum import Enum
-__all__ = [
- 'AuthBackendCapability'
-]
+__all__ = ["AuthBackendCapability"]
class AuthBackendCapability(Enum):
# This auth backend can authenticate a user.
- CAN_AUTHENTICATE_USER = 'can_authenticate_user'
+ CAN_AUTHENTICATE_USER = "can_authenticate_user"
# Auth backend can provide additional information about a particular user.
- HAS_USER_INFORMATION = 'has_user_info'
+ HAS_USER_INFORMATION = "has_user_info"
    # Auth backend can provide group membership information for a particular user.
- HAS_GROUP_INFORMATION = 'has_groups_info'
+ HAS_GROUP_INFORMATION = "has_groups_info"
diff --git a/st2auth/st2auth/cmd/api.py b/st2auth/st2auth/cmd/api.py
index d1fd7605bd..4c52f2649c 100644
--- a/st2auth/st2auth/cmd/api.py
+++ b/st2auth/st2auth/cmd/api.py
@@ -14,6 +14,7 @@
# limitations under the License.
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import eventlet
@@ -27,15 +28,14 @@
from st2common.service_setup import setup as common_setup
from st2common.service_setup import teardown as common_teardown
from st2auth import config
+
config.register_opts()
from st2auth import app
from st2auth.validation import validate_auth_backend_is_correctly_configured
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
@@ -43,15 +43,23 @@
def _setup():
capabilities = {
- 'name': 'auth',
- 'listen_host': cfg.CONF.auth.host,
- 'listen_port': cfg.CONF.auth.port,
- 'listen_ssl': cfg.CONF.auth.use_ssl,
- 'type': 'active'
+ "name": "auth",
+ "listen_host": cfg.CONF.auth.host,
+ "listen_port": cfg.CONF.auth.port,
+ "listen_ssl": cfg.CONF.auth.use_ssl,
+ "type": "active",
}
- common_setup(service='auth', config=config, setup_db=True, register_mq_exchanges=False,
- register_signal_handlers=True, register_internal_trigger_types=False,
- run_migrations=False, service_registry=True, capabilities=capabilities)
+ common_setup(
+ service="auth",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=False,
+ register_signal_handlers=True,
+ register_internal_trigger_types=False,
+ run_migrations=False,
+ service_registry=True,
+ capabilities=capabilities,
+ )
# Additional pre-run time checks
validate_auth_backend_is_correctly_configured()
@@ -74,14 +82,18 @@ def _run_server():
socket = eventlet.listen((host, port))
if use_ssl:
- socket = eventlet.wrap_ssl(socket,
- certfile=cert_file_path,
- keyfile=key_file_path,
- server_side=True)
+ socket = eventlet.wrap_ssl(
+ socket, certfile=cert_file_path, keyfile=key_file_path, server_side=True
+ )
LOG.info('ST2 Auth API running in "%s" auth mode', cfg.CONF.auth.mode)
- LOG.info('(PID=%s) ST2 Auth API is serving on %s://%s:%s.', os.getpid(),
- 'https' if use_ssl else 'http', host, port)
+ LOG.info(
+ "(PID=%s) ST2 Auth API is serving on %s://%s:%s.",
+ os.getpid(),
+ "https" if use_ssl else "http",
+ host,
+ port,
+ )
wsgi.server(socket, app.setup_app(), log=LOG, log_output=False)
return 0
@@ -98,7 +110,7 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except Exception:
- LOG.exception('(PID=%s) ST2 Auth API quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) ST2 Auth API quit due to exception.", os.getpid())
return 1
finally:
_teardown()
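
Note: _run_server above is a thin eventlet WSGI loop with optional TLS. Stripped to its essentials (the cert paths are placeholders; port 9100 is the st2auth default from config.py below):

    import eventlet
    from eventlet import wsgi


    def app(environ, start_response):
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"ok\n"]


    sock = eventlet.listen(("127.0.0.1", 9100))
    # When use_ssl is set, the listening socket is wrapped before serving:
    # sock = eventlet.wrap_ssl(sock, certfile="mycert.crt", keyfile="mycert.key",
    #                          server_side=True)
    wsgi.server(sock, app)
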
diff --git a/st2auth/st2auth/config.py b/st2auth/st2auth/config.py
index 00cfa2aca7..dee0d2d064 100644
--- a/st2auth/st2auth/config.py
+++ b/st2auth/st2auth/config.py
@@ -28,8 +28,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
@@ -50,47 +53,61 @@ def _register_app_opts():
auth_opts = [
cfg.StrOpt(
- 'host', default='127.0.0.1',
- help='Host on which the service should listen on.'),
+ "host",
+ default="127.0.0.1",
+ help="Host on which the service should listen on.",
+ ),
cfg.IntOpt(
- 'port', default=9100,
- help='Port on which the service should listen on.'),
- cfg.BoolOpt(
- 'use_ssl', default=False,
- help='Specify to enable SSL / TLS mode'),
+ "port", default=9100, help="Port on which the service should listen on."
+ ),
+ cfg.BoolOpt("use_ssl", default=False, help="Specify to enable SSL / TLS mode"),
cfg.StrOpt(
- 'cert', default='/etc/apache2/ssl/mycert.crt',
- help='Path to the SSL certificate file. Only used when "use_ssl" is specified.'),
+ "cert",
+ default="/etc/apache2/ssl/mycert.crt",
+ help='Path to the SSL certificate file. Only used when "use_ssl" is specified.',
+ ),
cfg.StrOpt(
- 'key', default='/etc/apache2/ssl/mycert.key',
- help='Path to the SSL private key file. Only used when "use_ssl" is specified.'),
+ "key",
+ default="/etc/apache2/ssl/mycert.key",
+ help='Path to the SSL private key file. Only used when "use_ssl" is specified.',
+ ),
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.auth.conf',
- help='Path to the logging config.'),
- cfg.BoolOpt(
- 'debug', default=False,
- help='Specify to enable debug mode.'),
+ "logging",
+ default="/etc/st2/logging.auth.conf",
+ help="Path to the logging config.",
+ ),
+ cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."),
cfg.StrOpt(
- 'mode', default=DEFAULT_MODE,
- help='Authentication mode (%s)' % (','.join(VALID_MODES))),
+ "mode",
+ default=DEFAULT_MODE,
+ help="Authentication mode (%s)" % (",".join(VALID_MODES)),
+ ),
cfg.StrOpt(
- 'backend', default=DEFAULT_BACKEND,
- help='Authentication backend to use in a standalone mode. Available '
- 'backends: %s.' % (', '.join(available_backends))),
+ "backend",
+ default=DEFAULT_BACKEND,
+ help="Authentication backend to use in a standalone mode. Available "
+ "backends: %s." % (", ".join(available_backends)),
+ ),
cfg.StrOpt(
- 'backend_kwargs', default=None,
- help='JSON serialized arguments which are passed to the authentication '
- 'backend in a standalone mode.'),
+ "backend_kwargs",
+ default=None,
+ help="JSON serialized arguments which are passed to the authentication "
+ "backend in a standalone mode.",
+ ),
cfg.BoolOpt(
- 'sso', default=False,
- help='Enable Single Sign On for GUI if true.'),
+ "sso", default=False, help="Enable Single Sign On for GUI if true."
+ ),
cfg.StrOpt(
- 'sso_backend', default=DEFAULT_SSO_BACKEND,
- help='Single Sign On backend to use when SSO is enabled. Available '
- 'backends: noop, saml2.'),
+ "sso_backend",
+ default=DEFAULT_SSO_BACKEND,
+ help="Single Sign On backend to use when SSO is enabled. Available "
+ "backends: noop, saml2.",
+ ),
cfg.StrOpt(
- 'sso_backend_kwargs', default=None,
- help='JSON serialized arguments which are passed to the SSO backend.')
+ "sso_backend_kwargs",
+ default=None,
+ help="JSON serialized arguments which are passed to the SSO backend.",
+ ),
]
- cfg.CONF.register_cli_opts(auth_opts, group='auth')
+ cfg.CONF.register_cli_opts(auth_opts, group="auth")
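
Note: because these options are registered with register_cli_opts under the auth group, oslo.config also exposes them as command-line flags. A minimal sketch, assuming oslo.config's standard behavior of prefixing grouped CLI options with the group name:

    from oslo_config import cfg

    opts = [
        cfg.StrOpt("host", default="127.0.0.1", help="Host to listen on."),
        cfg.IntOpt("port", default=9100, help="Port to listen on."),
    ]
    cfg.CONF.register_cli_opts(opts, group="auth")

    # Parsing argv fills in cfg.CONF.auth.*, falling back to the defaults.
    cfg.CONF(args=["--auth-host", "0.0.0.0"], project="sketch")
    print(cfg.CONF.auth.host, cfg.CONF.auth.port)  # 0.0.0.0 9100
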
diff --git a/st2auth/st2auth/controllers/v1/auth.py b/st2auth/st2auth/controllers/v1/auth.py
index f0042632e9..c77546141f 100644
--- a/st2auth/st2auth/controllers/v1/auth.py
+++ b/st2auth/st2auth/controllers/v1/auth.py
@@ -29,8 +29,8 @@
HANDLER_MAPPINGS = {
- 'proxy': handlers.ProxyAuthHandler,
- 'standalone': handlers.StandaloneAuthHandler
+ "proxy": handlers.ProxyAuthHandler,
+ "standalone": handlers.StandaloneAuthHandler,
}
LOG = logging.getLogger(__name__)
@@ -38,17 +38,17 @@
class TokenValidationController(object):
def post(self, request):
- token = getattr(request, 'token', None)
+ token = getattr(request, "token", None)
if not token:
- raise exc.HTTPBadRequest('Token is not provided.')
+ raise exc.HTTPBadRequest("Token is not provided.")
try:
- return {'valid': auth_utils.validate_token(token) is not None}
+ return {"valid": auth_utils.validate_token(token) is not None}
except (TokenNotFoundError, TokenExpiredError):
- return {'valid': False}
+ return {"valid": False}
except Exception:
- msg = 'Unexpected error occurred while verifying token.'
+ msg = "Unexpected error occurred while verifying token."
LOG.exception(msg)
raise exc.HTTPInternalServerError(msg)
@@ -60,30 +60,32 @@ def __init__(self):
try:
self.handler = HANDLER_MAPPINGS[cfg.CONF.auth.mode]()
except KeyError:
- raise ParamException("%s is not a valid auth mode" %
- cfg.CONF.auth.mode)
+ raise ParamException("%s is not a valid auth mode" % cfg.CONF.auth.mode)
def post(self, request, **kwargs):
headers = {}
- if 'x-forwarded-for' in kwargs:
- headers['x-forwarded-for'] = kwargs.pop('x-forwarded-for')
+ if "x-forwarded-for" in kwargs:
+ headers["x-forwarded-for"] = kwargs.pop("x-forwarded-for")
- authorization = kwargs.pop('authorization', None)
+ authorization = kwargs.pop("authorization", None)
if authorization:
- authorization = tuple(authorization.split(' '))
-
- token = self.handler.handle_auth(request=request, headers=headers,
- remote_addr=kwargs.pop('remote_addr', None),
- remote_user=kwargs.pop('remote_user', None),
- authorization=authorization,
- **kwargs)
+ authorization = tuple(authorization.split(" "))
+
+ token = self.handler.handle_auth(
+ request=request,
+ headers=headers,
+ remote_addr=kwargs.pop("remote_addr", None),
+ remote_user=kwargs.pop("remote_user", None),
+ authorization=authorization,
+ **kwargs,
+ )
return process_successful_response(token=token)
def process_successful_response(token):
resp = Response(json=token, status=http_client.CREATED)
    # NOTE: gunicorn fails and throws an error if header value is not a string (e.g. if it's None)
- resp.headers['X-API-URL'] = api_utils.get_base_public_api_url()
+ resp.headers["X-API-URL"] = api_utils.get_base_public_api_url()
return resp
diff --git a/st2auth/st2auth/controllers/v1/sso.py b/st2auth/st2auth/controllers/v1/sso.py
index f25effe681..ef1096462c 100644
--- a/st2auth/st2auth/controllers/v1/sso.py
+++ b/st2auth/st2auth/controllers/v1/sso.py
@@ -32,7 +32,6 @@
class IdentityProviderCallbackController(object):
-
def __init__(self):
self.st2_auth_handler = handlers.ProxyAuthHandler()
@@ -40,16 +39,21 @@ def post(self, response, **kwargs):
try:
verified_user = SSO_BACKEND.verify_response(response)
- st2_auth_token_create_request = {'user': verified_user['username'], 'ttl': None}
+ st2_auth_token_create_request = {
+ "user": verified_user["username"],
+ "ttl": None,
+ }
st2_auth_token = self.st2_auth_handler.handle_auth(
request=st2_auth_token_create_request,
- remote_addr=verified_user['referer'],
- remote_user=verified_user['username'],
- headers={}
+ remote_addr=verified_user["referer"],
+ remote_user=verified_user["username"],
+ headers={},
)
- return process_successful_authn_response(verified_user['referer'], st2_auth_token)
+ return process_successful_authn_response(
+ verified_user["referer"], st2_auth_token
+ )
except NotImplementedError as e:
return process_failure_response(http_client.INTERNAL_SERVER_ERROR, e)
except auth_exc.SSOVerificationError as e:
@@ -59,7 +63,6 @@ def post(self, response, **kwargs):
class SingleSignOnRequestController(object):
-
def get(self, referer):
try:
response = router.Response(status=http_client.TEMPORARY_REDIRECT)
@@ -76,15 +79,15 @@ class SingleSignOnController(object):
callback = IdentityProviderCallbackController()
def _get_sso_enabled_config(self):
- return {'enabled': cfg.CONF.auth.sso}
+ return {"enabled": cfg.CONF.auth.sso}
def get(self):
try:
result = self._get_sso_enabled_config()
return process_successful_response(http_client.OK, result)
except Exception:
- LOG.exception('Error encountered while getting SSO configuration.')
- result = {'enabled': False}
+ LOG.exception("Error encountered while getting SSO configuration.")
+ result = {"enabled": False}
return process_successful_response(http_client.OK, result)
@@ -107,23 +110,23 @@ def get(self):
def process_successful_authn_response(referer, token):
token_json = {
- 'id': str(token.id),
- 'user': token.user,
- 'token': token.token,
- 'expiry': str(token.expiry),
- 'service': False,
- 'metadata': {}
+ "id": str(token.id),
+ "user": token.user,
+ "token": token.token,
+ "expiry": str(token.expiry),
+ "service": False,
+ "metadata": {},
}
body = CALLBACK_SUCCESS_RESPONSE_BODY % referer
resp = router.Response(body=body)
- resp.headers['Content-Type'] = 'text/html'
+ resp.headers["Content-Type"] = "text/html"
resp.set_cookie(
- 'st2-auth-token',
+ "st2-auth-token",
value=urllib.parse.quote(json.dumps(token_json)),
expires=datetime.timedelta(seconds=60),
- overwrite=True
+ overwrite=True,
)
return resp
@@ -135,7 +138,7 @@ def process_successful_response(status_code, json_body):
def process_failure_response(status_code, exception):
LOG.error(str(exception))
- json_body = {'faultstring': str(exception)}
+ json_body = {"faultstring": str(exception)}
return router.Response(status_code=status_code, json_body=json_body)
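
Note: process_successful_authn_response above stores the token as URL-quoted JSON in a short-lived st2-auth-token cookie. The encode/decode round trip, as a sketch (the field values are samples):

    import json
    import urllib.parse

    token_json = {"id": "abc123", "user": "stanley", "service": False, "metadata": {}}

    # What the controller puts into the cookie value...
    cookie_value = urllib.parse.quote(json.dumps(token_json))

    # ...and how a client recovers the original structure.
    assert json.loads(urllib.parse.unquote(cookie_value)) == token_json
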
diff --git a/st2auth/st2auth/handlers.py b/st2auth/st2auth/handlers.py
index 59d74085cf..f6540bcda7 100644
--- a/st2auth/st2auth/handlers.py
+++ b/st2auth/st2auth/handlers.py
@@ -35,13 +35,22 @@
LOG = logging.getLogger(__name__)
-def abort_request(status_code=http_client.UNAUTHORIZED, message='Invalid or missing credentials'):
+def abort_request(
+ status_code=http_client.UNAUTHORIZED, message="Invalid or missing credentials"
+):
return abort(status_code, message)
class AuthHandlerBase(object):
- def handle_auth(self, request, headers=None, remote_addr=None,
- remote_user=None, authorization=None, **kwargs):
+ def handle_auth(
+ self,
+ request,
+ headers=None,
+ remote_addr=None,
+ remote_user=None,
+ authorization=None,
+ **kwargs,
+ ):
raise NotImplementedError()
def _create_token_for_user(self, username, ttl=None):
@@ -49,80 +58,90 @@ def _create_token_for_user(self, username, ttl=None):
return TokenAPI.from_model(tokendb)
def _get_username_for_request(self, username, request):
- impersonate_user = getattr(request, 'user', None)
+ impersonate_user = getattr(request, "user", None)
if impersonate_user is not None:
# check this is a service account
try:
if not User.get_by_name(username).is_service:
- message = "Current user is not a service and cannot " \
- "request impersonated tokens"
- abort_request(status_code=http_client.BAD_REQUEST,
- message=message)
+ message = (
+ "Current user is not a service and cannot "
+ "request impersonated tokens"
+ )
+ abort_request(status_code=http_client.BAD_REQUEST, message=message)
return
username = impersonate_user
except (UserNotFoundError, StackStormDBObjectNotFoundError):
- message = "Could not locate user %s" % \
- (impersonate_user)
- abort_request(status_code=http_client.BAD_REQUEST,
- message=message)
+ message = "Could not locate user %s" % (impersonate_user)
+ abort_request(status_code=http_client.BAD_REQUEST, message=message)
return
else:
- impersonate_user = getattr(request, 'impersonate_user', None)
- nickname_origin = getattr(request, 'nickname_origin', None)
+ impersonate_user = getattr(request, "impersonate_user", None)
+ nickname_origin = getattr(request, "nickname_origin", None)
if impersonate_user is not None:
try:
# check this is a service account
if not User.get_by_name(username).is_service:
raise NotServiceUserError()
- username = User.get_by_nickname(impersonate_user,
- nickname_origin).name
+ username = User.get_by_nickname(
+ impersonate_user, nickname_origin
+ ).name
except NotServiceUserError:
- message = "Current user is not a service and cannot " \
- "request impersonated tokens"
- abort_request(status_code=http_client.BAD_REQUEST,
- message=message)
+ message = (
+ "Current user is not a service and cannot "
+ "request impersonated tokens"
+ )
+ abort_request(status_code=http_client.BAD_REQUEST, message=message)
return
except (UserNotFoundError, StackStormDBObjectNotFoundError):
- message = "Could not locate user %s@%s" % \
- (impersonate_user, nickname_origin)
- abort_request(status_code=http_client.BAD_REQUEST,
- message=message)
+ message = "Could not locate user %s@%s" % (
+ impersonate_user,
+ nickname_origin,
+ )
+ abort_request(status_code=http_client.BAD_REQUEST, message=message)
return
except NoNicknameOriginProvidedError:
- message = "Nickname origin is not provided for nickname '%s'" % \
- impersonate_user
- abort_request(status_code=http_client.BAD_REQUEST,
- message=message)
+ message = (
+ "Nickname origin is not provided for nickname '%s'"
+ % impersonate_user
+ )
+ abort_request(status_code=http_client.BAD_REQUEST, message=message)
return
except AmbiguousUserError:
- message = "%s@%s matched more than one username" % \
- (impersonate_user, nickname_origin)
- abort_request(status_code=http_client.BAD_REQUEST,
- message=message)
+ message = "%s@%s matched more than one username" % (
+ impersonate_user,
+ nickname_origin,
+ )
+ abort_request(status_code=http_client.BAD_REQUEST, message=message)
return
return username
class ProxyAuthHandler(AuthHandlerBase):
- def handle_auth(self, request, headers=None, remote_addr=None,
- remote_user=None, authorization=None, **kwargs):
- remote_addr = headers.get('x-forwarded-for',
- remote_addr)
- extra = {'remote_addr': remote_addr}
+ def handle_auth(
+ self,
+ request,
+ headers=None,
+ remote_addr=None,
+ remote_user=None,
+ authorization=None,
+ **kwargs,
+ ):
+ remote_addr = headers.get("x-forwarded-for", remote_addr)
+ extra = {"remote_addr": remote_addr}
if remote_user:
- ttl = getattr(request, 'ttl', None)
+ ttl = getattr(request, "ttl", None)
username = self._get_username_for_request(remote_user, request)
try:
- token = self._create_token_for_user(username=username,
- ttl=ttl)
+ token = self._create_token_for_user(username=username, ttl=ttl)
except TTLTooLargeException as e:
- abort_request(status_code=http_client.BAD_REQUEST,
- message=six.text_type(e))
+ abort_request(
+ status_code=http_client.BAD_REQUEST, message=six.text_type(e)
+ )
return token
- LOG.audit('Access denied to anonymous user.', extra=extra)
+ LOG.audit("Access denied to anonymous user.", extra=extra)
abort_request()
@@ -131,77 +150,91 @@ def __init__(self, *args, **kwargs):
self._auth_backend = get_auth_backend_instance(name=cfg.CONF.auth.backend)
super(StandaloneAuthHandler, self).__init__(*args, **kwargs)
- def handle_auth(self, request, headers=None, remote_addr=None, remote_user=None,
- authorization=None, **kwargs):
+ def handle_auth(
+ self,
+ request,
+ headers=None,
+ remote_addr=None,
+ remote_user=None,
+ authorization=None,
+ **kwargs,
+ ):
auth_backend = self._auth_backend.__class__.__name__
- extra = {'auth_backend': auth_backend, 'remote_addr': remote_addr}
+ extra = {"auth_backend": auth_backend, "remote_addr": remote_addr}
if not authorization:
- LOG.audit('Authorization header not provided', extra=extra)
+ LOG.audit("Authorization header not provided", extra=extra)
abort_request()
return
auth_type, auth_value = authorization
- if auth_type.lower() not in ['basic']:
- extra['auth_type'] = auth_type
- LOG.audit('Unsupported authorization type: %s' % (auth_type), extra=extra)
+ if auth_type.lower() not in ["basic"]:
+ extra["auth_type"] = auth_type
+ LOG.audit("Unsupported authorization type: %s" % (auth_type), extra=extra)
abort_request()
return
try:
auth_value = base64.b64decode(auth_value)
except Exception:
- LOG.audit('Invalid authorization header', extra=extra)
+ LOG.audit("Invalid authorization header", extra=extra)
abort_request()
return
- split = auth_value.split(b':', 1)
+ split = auth_value.split(b":", 1)
if len(split) != 2:
- LOG.audit('Invalid authorization header', extra=extra)
+ LOG.audit("Invalid authorization header", extra=extra)
abort_request()
return
username, password = split
if six.PY3 and isinstance(username, six.binary_type):
- username = username.decode('utf-8')
+ username = username.decode("utf-8")
if six.PY3 and isinstance(password, six.binary_type):
- password = password.decode('utf-8')
+ password = password.decode("utf-8")
result = self._auth_backend.authenticate(username=username, password=password)
if result is True:
- ttl = getattr(request, 'ttl', None)
+ ttl = getattr(request, "ttl", None)
username = self._get_username_for_request(username, request)
try:
token = self._create_token_for_user(username=username, ttl=ttl)
except TTLTooLargeException as e:
- abort_request(status_code=http_client.BAD_REQUEST,
- message=six.text_type(e))
+ abort_request(
+ status_code=http_client.BAD_REQUEST, message=six.text_type(e)
+ )
return
# If remote group sync is enabled, sync the remote groups with local StackStorm roles
- if cfg.CONF.rbac.sync_remote_groups and cfg.CONF.rbac.backend != 'noop':
- LOG.debug('Retrieving auth backend groups for user "%s"' % (username),
- extra=extra)
+ if cfg.CONF.rbac.sync_remote_groups and cfg.CONF.rbac.backend != "noop":
+ LOG.debug(
+ 'Retrieving auth backend groups for user "%s"' % (username),
+ extra=extra,
+ )
try:
user_groups = self._auth_backend.get_user_groups(username=username)
except (NotImplementedError, AttributeError):
- LOG.debug('Configured auth backend doesn\'t expose user group membership '
- 'information, skipping sync...')
+ LOG.debug(
+ "Configured auth backend doesn't expose user group membership "
+ "information, skipping sync..."
+ )
return token
if not user_groups:
# No groups, return early
return token
- extra['username'] = username
- extra['user_groups'] = user_groups
+ extra["username"] = username
+ extra["user_groups"] = user_groups
- LOG.debug('Found "%s" groups for user "%s"' % (len(user_groups), username),
- extra=extra)
+ LOG.debug(
+ 'Found "%s" groups for user "%s"' % (len(user_groups), username),
+ extra=extra,
+ )
user_db = UserDB(name=username)
@@ -212,14 +245,19 @@ def handle_auth(self, request, headers=None, remote_addr=None, remote_user=None,
syncer.sync(user_db=user_db, groups=user_groups)
except Exception:
# Note: Failed sync is not fatal
- LOG.exception('Failed to synchronize remote groups for user "%s"' % (username),
- extra=extra)
+ LOG.exception(
+ 'Failed to synchronize remote groups for user "%s"'
+ % (username),
+ extra=extra,
+ )
else:
- LOG.debug('Successfully synchronized groups for user "%s"' % (username),
- extra=extra)
+ LOG.debug(
+ 'Successfully synchronized groups for user "%s"' % (username),
+ extra=extra,
+ )
return token
return token
- LOG.audit('Invalid credentials provided', extra=extra)
+ LOG.audit("Invalid credentials provided", extra=extra)
abort_request()
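
Not part of the patch: the hunks above only re-wrap this handler for black, so its behaviour is unchanged. For reference, a minimal standalone sketch (illustrative names) of the Basic credential parsing that StandaloneAuthHandler.handle_auth performs:

import base64

def parse_basic_credentials(auth_value):
    # The value is base64("username:password"). Splitting on the *first*
    # colon only (maxsplit=1) means the password itself may contain colons,
    # which test_password_contains_colon below exercises.
    decoded = base64.b64decode(auth_value)
    split = decoded.split(b":", 1)
    if len(split) != 2:
        raise ValueError("Invalid authorization header")
    username, password = split
    return username.decode("utf-8"), password.decode("utf-8")

# A password containing a colon survives the split.
assert parse_basic_credentials(base64.b64encode(b"user:pa:ss")) == ("user", "pa:ss")
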
diff --git a/st2auth/st2auth/sso/__init__.py b/st2auth/st2auth/sso/__init__.py
index 5839059ed9..b6d0df930a 100644
--- a/st2auth/st2auth/sso/__init__.py
+++ b/st2auth/st2auth/sso/__init__.py
@@ -25,15 +25,11 @@
from st2common.util import driver_loader
-__all__ = [
- 'get_available_backends',
- 'get_backend_instance',
- 'get_sso_backend'
-]
+__all__ = ["get_available_backends", "get_backend_instance", "get_sso_backend"]
LOG = logging.getLogger(__name__)
-BACKENDS_NAMESPACE = 'st2auth.sso.backends'
+BACKENDS_NAMESPACE = "st2auth.sso.backends"
def get_available_backends():
@@ -41,7 +37,9 @@ def get_available_backends():
def get_backend_instance(name):
- sso_backend_cls = driver_loader.get_backend_driver(namespace=BACKENDS_NAMESPACE, name=name)
+ sso_backend_cls = driver_loader.get_backend_driver(
+ namespace=BACKENDS_NAMESPACE, name=name
+ )
kwargs = {}
sso_backend_kwargs = cfg.CONF.auth.sso_backend_kwargs
@@ -51,8 +49,8 @@ def get_backend_instance(name):
kwargs = json.loads(sso_backend_kwargs)
except ValueError as e:
raise ValueError(
- 'Failed to JSON parse backend settings for backend "%s": %s' %
- (name, six.text_type(e))
+ 'Failed to JSON parse backend settings for backend "%s": %s'
+ % (name, six.text_type(e))
)
try:
@@ -60,9 +58,11 @@ def get_backend_instance(name):
except Exception as e:
tb_msg = traceback.format_exc()
class_name = sso_backend_cls.__name__
- msg = ('Failed to instantiate SSO backend "%s" (class %s) with backend settings '
- '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e)))
- msg += '\n\n' + tb_msg
+ msg = (
+ 'Failed to instantiate SSO backend "%s" (class %s) with backend settings '
+ '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e))
+ )
+ msg += "\n\n" + tb_msg
exc_cls = type(e)
raise exc_cls(msg)
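
Not part of the patch: a self-contained sketch of how get_backend_instance treats the JSON-encoded auth.sso_backend_kwargs option. FakeBackend stands in for whatever class driver_loader resolves:

import json

class FakeBackend(object):
    def __init__(self, **kwargs):
        self.kwargs = kwargs

def instantiate_backend(backend_cls, sso_backend_kwargs, name="fake"):
    kwargs = {}
    if sso_backend_kwargs:
        try:
            # json.JSONDecodeError subclasses ValueError, so this mirrors
            # the error handling in the hunk above.
            kwargs = json.loads(sso_backend_kwargs)
        except ValueError as e:
            raise ValueError(
                'Failed to JSON parse backend settings for backend "%s": %s' % (name, e)
            )
    return backend_cls(**kwargs)

backend = instantiate_backend(FakeBackend, '{"entity_id": "st2"}')
assert backend.kwargs == {"entity_id": "st2"}
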
diff --git a/st2auth/st2auth/sso/base.py b/st2auth/st2auth/sso/base.py
index c96782aba2..5e11199818 100644
--- a/st2auth/st2auth/sso/base.py
+++ b/st2auth/st2auth/sso/base.py
@@ -16,9 +16,7 @@
import six
-__all__ = [
- 'BaseSingleSignOnBackend'
-]
+__all__ = ["BaseSingleSignOnBackend"]
@six.add_metaclass(abc.ABCMeta)
@@ -32,5 +30,7 @@ def get_request_redirect_url(self, referer):
raise NotImplementedError(msg)
def verify_response(self, response):
- msg = 'The function "verify_response" is not implemented in the base SSO backend.'
+ msg = (
+ 'The function "verify_response" is not implemented in the base SSO backend.'
+ )
raise NotImplementedError(msg)
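
Not part of the patch: a hypothetical concrete backend showing the two hooks the base class leaves unimplemented. The stand-in base class mirrors the hunk above so the sketch runs on its own; a real backend would import BaseSingleSignOnBackend from st2auth.sso.base instead:

class BaseSingleSignOnBackend(object):
    # Stand-in for st2auth.sso.base.BaseSingleSignOnBackend
    def get_request_redirect_url(self, referer):
        raise NotImplementedError

    def verify_response(self, response):
        raise NotImplementedError

class StaticSingleSignOnBackend(BaseSingleSignOnBackend):
    """Hypothetical backend that always redirects to one fixed IdP URL."""

    def __init__(self, redirect_url="https://idp.example.com/login"):
        self._redirect_url = redirect_url

    def get_request_redirect_url(self, referer):
        # A real backend would encode `referer` into the SAML/OIDC request.
        return self._redirect_url

    def verify_response(self, response):
        # A real backend validates the IdP response; the callback controller
        # expects a dict with "referer" and "username" (see the test_sso
        # mocks further below).
        return {"referer": "https://127.0.0.1", "username": "stanley"}

backend = StaticSingleSignOnBackend()
assert backend.get_request_redirect_url("https://127.0.0.1").startswith("https://")
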
diff --git a/st2auth/st2auth/sso/noop.py b/st2auth/st2auth/sso/noop.py
index 6cacb5e7e9..6699e084f3 100644
--- a/st2auth/st2auth/sso/noop.py
+++ b/st2auth/st2auth/sso/noop.py
@@ -17,13 +17,11 @@
from st2auth.sso.base import BaseSingleSignOnBackend
-__all__ = [
- 'NoOpSingleSignOnBackend'
-]
+__all__ = ["NoOpSingleSignOnBackend"]
NOT_IMPLEMENTED_MESSAGE = (
'The default "noop" SSO backend is not a proper implementation. '
- 'Please refer to the enterprise version for configuring SSO.'
+ "Please refer to the enterprise version for configuring SSO."
)
diff --git a/st2auth/st2auth/validation.py b/st2auth/st2auth/validation.py
index 924ad390f2..ccea906062 100644
--- a/st2auth/st2auth/validation.py
+++ b/st2auth/st2auth/validation.py
@@ -19,26 +19,28 @@
from st2auth.backends import get_backend_instance as get_auth_backend_instance
from st2auth.backends.constants import AuthBackendCapability
-__all__ = [
- 'validate_auth_backend_is_correctly_configured'
-]
+__all__ = ["validate_auth_backend_is_correctly_configured"]
def validate_auth_backend_is_correctly_configured():
# 1. Verify correct mode is specified
if cfg.CONF.auth.mode not in VALID_MODES:
- msg = ('Invalid auth mode "%s" specified in the config. Valid modes are: %s' %
- (cfg.CONF.auth.mode, ', '.join(VALID_MODES)))
+ msg = 'Invalid auth mode "%s" specified in the config. Valid modes are: %s' % (
+ cfg.CONF.auth.mode,
+ ", ".join(VALID_MODES),
+ )
raise ValueError(msg)
# 2. Verify that auth backend used by the user exposes group information
if cfg.CONF.rbac.enable and cfg.CONF.rbac.sync_remote_groups:
auth_backend = get_auth_backend_instance(name=cfg.CONF.auth.backend)
- capabilies = getattr(auth_backend, 'CAPABILITIES', ())
+ capabilies = getattr(auth_backend, "CAPABILITIES", ())
if AuthBackendCapability.HAS_GROUP_INFORMATION not in capabilies:
- msg = ('Configured auth backend doesn\'t expose user group information. Disable '
- 'remote group synchronization or use a different backend which exposes '
- 'user group membership information.')
+ msg = (
+ "Configured auth backend doesn't expose user group information. Disable "
+ "remote group synchronization or use a different backend which exposes "
+ "user group membership information."
+ )
raise ValueError(msg)
return True
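
Not part of the patch: a minimal sketch of the capability probe used in check 2 above. HAS_GROUP_INFORMATION is a stand-in for the AuthBackendCapability member, whose actual value is not shown here:

class DummyBackend(object):
    CAPABILITIES = ()  # advertises no capabilities, like flat_file

HAS_GROUP_INFORMATION = "has_group_information"  # stand-in for the enum member

def supports_group_sync(backend):
    # getattr with a default tolerates backends that predate CAPABILITIES.
    capabilities = getattr(backend, "CAPABILITIES", ())
    return HAS_GROUP_INFORMATION in capabilities

assert supports_group_sync(DummyBackend()) is False
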
diff --git a/st2auth/st2auth/wsgi.py b/st2auth/st2auth/wsgi.py
index 2fb9bee07a..16a44e64f3 100644
--- a/st2auth/st2auth/wsgi.py
+++ b/st2auth/st2auth/wsgi.py
@@ -16,6 +16,7 @@
import os
from st2common.util.monkey_patch import monkey_patch
+
# Note: We need to perform monkey patching in the worker. If we do it in
# the master process (gunicorn_config.py), it breaks tons of things
# including shutdown
@@ -28,8 +29,11 @@
from st2auth import app
config = {
- 'is_gunicorn': True,
- 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')]
+ "is_gunicorn": True,
+ "config_args": [
+ "--config-file",
+ os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"),
+ ],
}
application = app.setup_app(config)
diff --git a/st2auth/tests/base.py b/st2auth/tests/base.py
index e3bc2e1a05..dc63c1094e 100644
--- a/st2auth/tests/base.py
+++ b/st2auth/tests/base.py
@@ -20,7 +20,6 @@
class FunctionalTest(DbTestCase):
-
@classmethod
def setUpClass(cls, **kwargs):
super(FunctionalTest, cls).setUpClass()
diff --git a/st2auth/tests/unit/controllers/v1/test_sso.py b/st2auth/tests/unit/controllers/v1/test_sso.py
index 81d9dcea1d..2b6edb1f83 100644
--- a/st2auth/tests/unit/controllers/v1/test_sso.py
+++ b/st2auth/tests/unit/controllers/v1/test_sso.py
@@ -13,6 +13,7 @@
# limitations under the License.
import st2tests.config as tests_config
+
tests_config.parse_args()
import json
@@ -28,110 +29,125 @@
from tests.base import FunctionalTest
-SSO_V1_PATH = '/v1/sso'
-SSO_REQUEST_V1_PATH = SSO_V1_PATH + '/request'
-SSO_CALLBACK_V1_PATH = SSO_V1_PATH + '/callback'
-MOCK_REFERER = 'https://127.0.0.1'
-MOCK_USER = 'stanley'
+SSO_V1_PATH = "/v1/sso"
+SSO_REQUEST_V1_PATH = SSO_V1_PATH + "/request"
+SSO_CALLBACK_V1_PATH = SSO_V1_PATH + "/callback"
+MOCK_REFERER = "https://127.0.0.1"
+MOCK_USER = "stanley"
class TestSingleSignOnController(FunctionalTest):
-
def test_sso_enabled(self):
- cfg.CONF.set_override(group='auth', name='sso', override=True)
+ cfg.CONF.set_override(group="auth", name="sso", override=True)
response = self.app.get(SSO_V1_PATH, expect_errors=False)
self.assertTrue(response.status_code, http_client.OK)
- self.assertDictEqual(response.json, {'enabled': True})
+ self.assertDictEqual(response.json, {"enabled": True})
def test_sso_disabled(self):
- cfg.CONF.set_override(group='auth', name='sso', override=False)
+ cfg.CONF.set_override(group="auth", name="sso", override=False)
response = self.app.get(SSO_V1_PATH, expect_errors=False)
self.assertTrue(response.status_code, http_client.OK)
- self.assertDictEqual(response.json, {'enabled': False})
+ self.assertDictEqual(response.json, {"enabled": False})
@mock.patch.object(
sso_api_controller.SingleSignOnController,
- '_get_sso_enabled_config',
- mock.MagicMock(side_effect=KeyError('foobar')))
+ "_get_sso_enabled_config",
+ mock.MagicMock(side_effect=KeyError("foobar")),
+ )
def test_unknown_exception(self):
- cfg.CONF.set_override(group='auth', name='sso', override=True)
+ cfg.CONF.set_override(group="auth", name="sso", override=True)
response = self.app.get(SSO_V1_PATH, expect_errors=False)
self.assertTrue(response.status_code, http_client.OK)
- self.assertDictEqual(response.json, {'enabled': False})
- self.assertTrue(sso_api_controller.SingleSignOnController._get_sso_enabled_config.called)
+ self.assertDictEqual(response.json, {"enabled": False})
+ self.assertTrue(
+ sso_api_controller.SingleSignOnController._get_sso_enabled_config.called
+ )
class TestSingleSignOnRequestController(FunctionalTest):
-
@mock.patch.object(
sso_api_controller.SSO_BACKEND,
- 'get_request_redirect_url',
- mock.MagicMock(side_effect=Exception('fooobar')))
+ "get_request_redirect_url",
+ mock.MagicMock(side_effect=Exception("fooobar")),
+ )
def test_default_backend_unknown_exception(self):
- expected_error = {'faultstring': 'Internal Server Error'}
+ expected_error = {"faultstring": "Internal Server Error"}
response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=True)
self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR)
self.assertDictEqual(response.json, expected_error)
def test_default_backend_not_implemented(self):
- expected_error = {'faultstring': noop.NOT_IMPLEMENTED_MESSAGE}
+ expected_error = {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE}
response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=True)
self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR)
self.assertDictEqual(response.json, expected_error)
@mock.patch.object(
sso_api_controller.SSO_BACKEND,
- 'get_request_redirect_url',
- mock.MagicMock(return_value='https://127.0.0.1'))
+ "get_request_redirect_url",
+ mock.MagicMock(return_value="https://127.0.0.1"),
+ )
def test_idp_redirect(self):
response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=False)
self.assertTrue(response.status_code, http_client.TEMPORARY_REDIRECT)
- self.assertEqual(response.location, 'https://127.0.0.1')
+ self.assertEqual(response.location, "https://127.0.0.1")
class TestIdentityProviderCallbackController(FunctionalTest):
-
@mock.patch.object(
sso_api_controller.SSO_BACKEND,
- 'verify_response',
- mock.MagicMock(side_effect=Exception('fooobar')))
+ "verify_response",
+ mock.MagicMock(side_effect=Exception("fooobar")),
+ )
def test_default_backend_unknown_exception(self):
- expected_error = {'faultstring': 'Internal Server Error'}
- response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True)
+ expected_error = {"faultstring": "Internal Server Error"}
+ response = self.app.post_json(
+ SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True
+ )
self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR)
self.assertDictEqual(response.json, expected_error)
def test_default_backend_not_implemented(self):
- expected_error = {'faultstring': noop.NOT_IMPLEMENTED_MESSAGE}
- response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True)
+ expected_error = {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE}
+ response = self.app.post_json(
+ SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True
+ )
self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR)
self.assertDictEqual(response.json, expected_error)
@mock.patch.object(
sso_api_controller.SSO_BACKEND,
- 'verify_response',
- mock.MagicMock(return_value={'referer': MOCK_REFERER, 'username': MOCK_USER}))
+ "verify_response",
+ mock.MagicMock(return_value={"referer": MOCK_REFERER, "username": MOCK_USER}),
+ )
def test_idp_callback(self):
expected_body = sso_api_controller.CALLBACK_SUCCESS_RESPONSE_BODY % MOCK_REFERER
- response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=False)
+ response = self.app.post_json(
+ SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=False
+ )
self.assertTrue(response.status_code, http_client.OK)
- self.assertEqual(expected_body, response.body.decode('utf-8'))
+ self.assertEqual(expected_body, response.body.decode("utf-8"))
- set_cookies_list = [h for h in response.headerlist if h[0] == 'Set-Cookie']
+ set_cookies_list = [h for h in response.headerlist if h[0] == "Set-Cookie"]
self.assertEqual(len(set_cookies_list), 1)
- self.assertIn('st2-auth-token', set_cookies_list[0][1])
+ self.assertIn("st2-auth-token", set_cookies_list[0][1])
- cookie = urllib.parse.unquote(set_cookies_list[0][1]).split('=')
- st2_auth_token = json.loads(cookie[1].split(';')[0])
- self.assertIn('token', st2_auth_token)
- self.assertEqual(st2_auth_token['user'], MOCK_USER)
+ cookie = urllib.parse.unquote(set_cookies_list[0][1]).split("=")
+ st2_auth_token = json.loads(cookie[1].split(";")[0])
+ self.assertIn("token", st2_auth_token)
+ self.assertEqual(st2_auth_token["user"], MOCK_USER)
@mock.patch.object(
sso_api_controller.SSO_BACKEND,
- 'verify_response',
- mock.MagicMock(side_effect=auth_exc.SSOVerificationError('Verification Failed')))
+ "verify_response",
+ mock.MagicMock(
+ side_effect=auth_exc.SSOVerificationError("Verification Failed")
+ ),
+ )
def test_idp_callback_verification_failed(self):
- expected_error = {'faultstring': 'Verification Failed'}
- response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True)
+ expected_error = {"faultstring": "Verification Failed"}
+ response = self.app.post_json(
+ SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True
+ )
self.assertTrue(response.status_code, http_client.UNAUTHORIZED)
self.assertDictEqual(response.json, expected_error)
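
Not part of the patch: the decorator layout black settles on for mock.patch.object once the call no longer fits on one line (one argument per line, trailing comma). unittest.mock stands in here for the mock package these tests import:

import unittest
from unittest import mock

class Backend(object):
    def verify_response(self, response):
        return {"username": "stanley"}

class ExampleTestCase(unittest.TestCase):
    @mock.patch.object(
        Backend,
        "verify_response",
        mock.MagicMock(side_effect=Exception("fooobar")),
    )
    def test_verify_response_failure(self):
        # The third argument replaces the attribute for the duration of the
        # test, so calling it raises the configured side effect.
        with self.assertRaises(Exception):
            Backend().verify_response({"foo": "bar"})

if __name__ == "__main__":
    unittest.main()
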
diff --git a/st2auth/tests/unit/controllers/v1/test_token.py b/st2auth/tests/unit/controllers/v1/test_token.py
index ab5f12342b..cd90a6cef1 100644
--- a/st2auth/tests/unit/controllers/v1/test_token.py
+++ b/st2auth/tests/unit/controllers/v1/test_token.py
@@ -29,25 +29,25 @@
from st2common.persistence.auth import User, Token, ApiKey
-USERNAME = ''.join(random.choice(string.ascii_lowercase) for i in range(10))
-TOKEN_DEFAULT_PATH = '/tokens'
-TOKEN_V1_PATH = '/v1/tokens'
-TOKEN_VERIFY_PATH = '/v1/tokens/validate'
+USERNAME = "".join(random.choice(string.ascii_lowercase) for i in range(10))
+TOKEN_DEFAULT_PATH = "/tokens"
+TOKEN_V1_PATH = "/v1/tokens"
+TOKEN_VERIFY_PATH = "/v1/tokens/validate"
class TestTokenController(FunctionalTest):
-
@classmethod
def setUpClass(cls, **kwargs):
- kwargs['extra_environ'] = {
- 'REMOTE_USER': USERNAME
- }
+ kwargs["extra_environ"] = {"REMOTE_USER": USERNAME}
super(TestTokenController, cls).setUpClass(**kwargs)
def test_token_model(self):
dt = date_utils.get_datetime_utc_now()
- tk1 = TokenAPI(user='stanley', token=uuid.uuid4().hex,
- expiry=isotime.format(dt, offset=False))
+ tk1 = TokenAPI(
+ user="stanley",
+ token=uuid.uuid4().hex,
+ expiry=isotime.format(dt, offset=False),
+ )
tkdb1 = TokenAPI.to_model(tk1)
self.assertIsNotNone(tkdb1)
self.assertIsInstance(tkdb1, TokenDB)
@@ -64,7 +64,7 @@ def test_token_model(self):
def test_token_model_null_token(self):
dt = date_utils.get_datetime_utc_now()
- tk = TokenAPI(user='stanley', token=None, expiry=isotime.format(dt))
+ tk = TokenAPI(user="stanley", token=None, expiry=isotime.format(dt))
self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk))
def test_token_model_null_user(self):
@@ -73,191 +73,215 @@ def test_token_model_null_user(self):
self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk))
def test_token_model_null_expiry(self):
- tk = TokenAPI(user='stanley', token=uuid.uuid4().hex, expiry=None)
+ tk = TokenAPI(user="stanley", token=uuid.uuid4().hex, expiry=None)
self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk))
def _test_token_post(self, path=TOKEN_V1_PATH):
ttl = cfg.CONF.auth.token_ttl
timestamp = date_utils.get_datetime_utc_now()
response = self.app.post_json(path, {}, expect_errors=False)
- expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl)
+ expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(
+ seconds=ttl
+ )
expected_expiry = date_utils.add_utc_tz(expected_expiry)
self.assertEqual(response.status_int, 201)
- self.assertIsNotNone(response.json['token'])
- self.assertEqual(response.json['user'], USERNAME)
- actual_expiry = isotime.parse(response.json['expiry'])
+ self.assertIsNotNone(response.json["token"])
+ self.assertEqual(response.json["user"], USERNAME)
+ actual_expiry = isotime.parse(response.json["expiry"])
self.assertLess(timestamp, actual_expiry)
self.assertLess(actual_expiry, expected_expiry)
return response
def test_token_post_unauthorized(self):
- response = self.app.post_json(TOKEN_V1_PATH, {}, expect_errors=True, extra_environ={
- 'REMOTE_USER': ''
- })
+ response = self.app.post_json(
+ TOKEN_V1_PATH, {}, expect_errors=True, extra_environ={"REMOTE_USER": ""}
+ )
self.assertEqual(response.status_int, 401)
+ @mock.patch.object(User, "get_by_name", mock.MagicMock(side_effect=Exception()))
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(side_effect=Exception()))
- @mock.patch.object(
- User, 'add_or_update',
- mock.Mock(return_value=UserDB(name=USERNAME)))
+ User, "add_or_update", mock.Mock(return_value=UserDB(name=USERNAME))
+ )
def test_token_post_new_user(self):
self._test_token_post()
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_post_existing_user(self):
self._test_token_post()
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_post_success_x_api_url_header_value(self):
# auth.api_url option is explicitly set
- cfg.CONF.set_override('api_url', override='https://example.com', group='auth')
+ cfg.CONF.set_override("api_url", override="https://example.com", group="auth")
resp = self._test_token_post()
- self.assertEqual(resp.headers['X-API-URL'], 'https://example.com')
+ self.assertEqual(resp.headers["X-API-URL"], "https://example.com")
         # auth.api_url option is not set, so the URL is inferred from the listen host and port
- cfg.CONF.set_override('api_url', override=None, group='auth')
+ cfg.CONF.set_override("api_url", override=None, group="auth")
resp = self._test_token_post()
- self.assertEqual(resp.headers['X-API-URL'], 'http://127.0.0.1:9101')
+ self.assertEqual(resp.headers["X-API-URL"], "http://127.0.0.1:9101")
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_post_default_url_path(self):
self._test_token_post(path=TOKEN_DEFAULT_PATH)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_post_set_ttl(self):
timestamp = date_utils.add_utc_tz(date_utils.get_datetime_utc_now())
- response = self.app.post_json(TOKEN_V1_PATH, {'ttl': 60}, expect_errors=False)
- expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=60)
+ response = self.app.post_json(TOKEN_V1_PATH, {"ttl": 60}, expect_errors=False)
+ expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(
+ seconds=60
+ )
self.assertEqual(response.status_int, 201)
- actual_expiry = isotime.parse(response.json['expiry'])
+ actual_expiry = isotime.parse(response.json["expiry"])
self.assertLess(timestamp, actual_expiry)
self.assertLess(actual_expiry, expected_expiry)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_post_no_data_in_body_text_plain_context_type_used(self):
- response = self.app.post(TOKEN_V1_PATH, expect_errors=False, content_type='text/plain')
+ response = self.app.post(
+ TOKEN_V1_PATH, expect_errors=False, content_type="text/plain"
+ )
self.assertEqual(response.status_int, 201)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_post_set_ttl_over_policy(self):
ttl = cfg.CONF.auth.token_ttl
- response = self.app.post_json(TOKEN_V1_PATH, {'ttl': ttl + 60}, expect_errors=True)
- self.assertEqual(response.status_int, 400)
- message = 'TTL specified %s is greater than max allowed %s.' % (
- ttl + 60, ttl
+ response = self.app.post_json(
+ TOKEN_V1_PATH, {"ttl": ttl + 60}, expect_errors=True
)
- self.assertEqual(response.json['faultstring'], message)
+ self.assertEqual(response.status_int, 400)
+ message = "TTL specified %s is greater than max allowed %s." % (ttl + 60, ttl)
+ self.assertEqual(response.json["faultstring"], message)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_post_set_bad_ttl(self):
- response = self.app.post_json(TOKEN_V1_PATH, {'ttl': -1}, expect_errors=True)
+ response = self.app.post_json(TOKEN_V1_PATH, {"ttl": -1}, expect_errors=True)
self.assertEqual(response.status_int, 400)
- response = self.app.post_json(TOKEN_V1_PATH, {'ttl': 0}, expect_errors=True)
+ response = self.app.post_json(TOKEN_V1_PATH, {"ttl": 0}, expect_errors=True)
self.assertEqual(response.status_int, 400)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_get_unauthorized(self):
# Create a new token.
response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False)
         # Verify the token. 401 is expected because an API key or token is not provided in the header.
- data = {'token': str(response.json['token'])}
+ data = {"token": str(response.json["token"])}
response = self.app.post_json(TOKEN_VERIFY_PATH, data, expect_errors=True)
self.assertEqual(response.status_int, 401)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_get_unauthorized_bad_api_key(self):
# Create a new token.
response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False)
# Verify the token. 401 is expected because the API key is bad.
- headers = {'St2-Api-Key': 'foobar'}
- data = {'token': str(response.json['token'])}
- response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True)
+ headers = {"St2-Api-Key": "foobar"}
+ data = {"token": str(response.json["token"])}
+ response = self.app.post_json(
+ TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True
+ )
self.assertEqual(response.status_int, 401)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_get_unauthorized_bad_token(self):
# Create a new token.
response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False)
# Verify the token. 401 is expected because the token is bad.
- headers = {'X-Auth-Token': 'foobar'}
- data = {'token': str(response.json['token'])}
- response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True)
+ headers = {"X-Auth-Token": "foobar"}
+ data = {"token": str(response.json["token"])}
+ response = self.app.post_json(
+ TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True
+ )
self.assertEqual(response.status_int, 401)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
@mock.patch.object(
- ApiKey, 'get',
- mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash='foobar')))
+ ApiKey,
+ "get",
+ mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash="foobar")),
+ )
def test_token_get_auth_with_api_key(self):
# Create a new token.
response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False)
# Verify the token. Use an API key to authenticate with the st2 auth get token endpoint.
- headers = {'St2-Api-Key': 'foobar'}
- data = {'token': str(response.json['token'])}
- response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True)
+ headers = {"St2-Api-Key": "foobar"}
+ data = {"token": str(response.json["token"])}
+ response = self.app.post_json(
+ TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True
+ )
self.assertEqual(response.status_int, 200)
- self.assertTrue(response.json['valid'])
+ self.assertTrue(response.json["valid"])
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
def test_token_get_auth_with_token(self):
# Create a new token.
response = self.app.post_json(TOKEN_V1_PATH, {}, expect_errors=False)
# Verify the token. Use a token to authenticate with the st2 auth get token endpoint.
- headers = {'X-Auth-Token': str(response.json['token'])}
- data = {'token': str(response.json['token'])}
- response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True)
+ headers = {"X-Auth-Token": str(response.json["token"])}
+ data = {"token": str(response.json["token"])}
+ response = self.app.post_json(
+ TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True
+ )
self.assertEqual(response.status_int, 200)
- self.assertTrue(response.json['valid'])
+ self.assertTrue(response.json["valid"])
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name=USERNAME)))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME))
+ )
@mock.patch.object(
- ApiKey, 'get',
- mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash='foobar')))
+ ApiKey,
+ "get",
+ mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash="foobar")),
+ )
@mock.patch.object(
- Token, 'get',
+ Token,
+ "get",
mock.MagicMock(
return_value=TokenDB(
- user=USERNAME, token='12345',
- expiry=date_utils.get_datetime_utc_now() - datetime.timedelta(minutes=1))))
+ user=USERNAME,
+ token="12345",
+ expiry=date_utils.get_datetime_utc_now()
+ - datetime.timedelta(minutes=1),
+ )
+ ),
+ )
def test_token_get_unauthorized_bad_ttl(self):
         # Verify the token. The response is 200 with valid=False because the token has expired.
- headers = {'St2-Api-Key': 'foobar'}
- data = {'token': '12345'}
- response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=False)
+ headers = {"St2-Api-Key": "foobar"}
+ data = {"token": "12345"}
+ response = self.app.post_json(
+ TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=False
+ )
self.assertEqual(response.status_int, 200)
- self.assertFalse(response.json['valid'])
+ self.assertFalse(response.json["valid"])
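
Not part of the patch: a self-contained sketch of the expiry-window assertion _test_token_post builds. The timestamp is taken before the request and expected_expiry after the response, so any expiry the server stamped while handling the request must fall strictly between them:

import datetime

def assert_expiry_in_window(timestamp, actual_expiry, ttl):
    # Computed after the response, so this is an upper bound on "now + ttl"
    # as seen by the server while it handled the request.
    expected_expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=ttl)
    assert timestamp < actual_expiry < expected_expiry

now = datetime.datetime.utcnow()
assert_expiry_in_window(now, now + datetime.timedelta(seconds=30), ttl=60)
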
diff --git a/st2auth/tests/unit/test_auth_backends.py b/st2auth/tests/unit/test_auth_backends.py
index 96856e8a3e..b367e328da 100644
--- a/st2auth/tests/unit/test_auth_backends.py
+++ b/st2auth/tests/unit/test_auth_backends.py
@@ -25,4 +25,4 @@
class AuthenticationBackendsTestCase(unittest2.TestCase):
def test_flat_file_backend_is_available_by_default(self):
available_backends = get_available_backends()
- self.assertIn('flat_file', available_backends)
+ self.assertIn("flat_file", available_backends)
diff --git a/st2auth/tests/unit/test_handlers.py b/st2auth/tests/unit/test_handlers.py
index a3627019d8..cf00e642a6 100644
--- a/st2auth/tests/unit/test_handlers.py
+++ b/st2auth/tests/unit/test_handlers.py
@@ -30,25 +30,23 @@
from st2tests.mocks.auth import MockRequest
from st2tests.mocks.auth import get_mock_backend
-__all__ = [
- 'AuthHandlerTestCase'
-]
+__all__ = ["AuthHandlerTestCase"]
-@mock.patch('st2auth.handlers.get_auth_backend_instance', get_mock_backend)
+@mock.patch("st2auth.handlers.get_auth_backend_instance", get_mock_backend)
class AuthHandlerTestCase(CleanDbTestCase):
def setUp(self):
super(AuthHandlerTestCase, self).setUp()
- cfg.CONF.auth.backend = 'mock'
+ cfg.CONF.auth.backend = "mock"
def test_proxy_handler(self):
h = handlers.ProxyAuthHandler()
request = {}
token = h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user='test_proxy_handler')
- self.assertEqual(token.user, 'test_proxy_handler')
+ request, headers={}, remote_addr=None, remote_user="test_proxy_handler"
+ )
+ self.assertEqual(token.user, "test_proxy_handler")
def test_standalone_bad_auth_type(self):
h = handlers.StandaloneAuthHandler()
@@ -56,8 +54,12 @@ def test_standalone_bad_auth_type(self):
with self.assertRaises(exc.HTTPUnauthorized):
h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('complex', DUMMY_CREDS))
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("complex", DUMMY_CREDS),
+ )
def test_standalone_no_auth(self):
h = handlers.StandaloneAuthHandler()
@@ -65,8 +67,12 @@ def test_standalone_no_auth(self):
with self.assertRaises(exc.HTTPUnauthorized):
h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=None)
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=None,
+ )
def test_standalone_bad_auth_value(self):
h = handlers.StandaloneAuthHandler()
@@ -74,109 +80,159 @@ def test_standalone_bad_auth_value(self):
with self.assertRaises(exc.HTTPUnauthorized):
h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', 'gobblegobble'))
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", "gobblegobble"),
+ )
def test_standalone_handler(self):
h = handlers.StandaloneAuthHandler()
request = {}
token = h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
- self.assertEqual(token.user, 'auser')
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
+ self.assertEqual(token.user, "auser")
def test_standalone_handler_ttl(self):
h = handlers.StandaloneAuthHandler()
token1 = h.handle_auth(
- MockRequest(23), headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
+ MockRequest(23),
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
token2 = h.handle_auth(
- MockRequest(2300), headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
- self.assertEqual(token1.user, 'auser')
+ MockRequest(2300),
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
+ self.assertEqual(token1.user, "auser")
self.assertNotEqual(token1.expiry, token2.expiry)
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name='auser')))
+ User, "get_by_name", mock.MagicMock(return_value=UserDB(name="auser"))
+ )
def test_standalone_for_user_not_service(self):
h = handlers.StandaloneAuthHandler()
request = MockRequest(60)
- request.user = 'anotheruser'
+ request.user = "anotheruser"
with self.assertRaises(exc.HTTPBadRequest):
h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name='auser', is_service=True)))
+ User,
+ "get_by_name",
+ mock.MagicMock(return_value=UserDB(name="auser", is_service=True)),
+ )
def test_standalone_for_user_service(self):
h = handlers.StandaloneAuthHandler()
request = MockRequest(60)
- request.user = 'anotheruser'
+ request.user = "anotheruser"
token = h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
- self.assertEqual(token.user, 'anotheruser')
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
+ self.assertEqual(token.user, "anotheruser")
def test_standalone_for_user_not_found(self):
h = handlers.StandaloneAuthHandler()
request = MockRequest(60)
- request.user = 'anotheruser'
+ request.user = "anotheruser"
with self.assertRaises(exc.HTTPBadRequest):
h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
def test_standalone_impersonate_user_not_found(self):
h = handlers.StandaloneAuthHandler()
request = MockRequest(60)
- request.impersonate_user = 'anotheruser'
+ request.impersonate_user = "anotheruser"
with self.assertRaises(exc.HTTPBadRequest):
h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
@mock.patch.object(
- User, 'get_by_name',
- mock.MagicMock(return_value=UserDB(name='auser', is_service=True)))
+ User,
+ "get_by_name",
+ mock.MagicMock(return_value=UserDB(name="auser", is_service=True)),
+ )
@mock.patch.object(
- User, 'get_by_nickname',
- mock.MagicMock(return_value=UserDB(name='anotheruser', is_service=True)))
+ User,
+ "get_by_nickname",
+ mock.MagicMock(return_value=UserDB(name="anotheruser", is_service=True)),
+ )
def test_standalone_impersonate_user_with_nick_origin(self):
h = handlers.StandaloneAuthHandler()
request = MockRequest(60)
- request.impersonate_user = 'anotheruser'
- request.nickname_origin = 'slack'
+ request.impersonate_user = "anotheruser"
+ request.nickname_origin = "slack"
token = h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
- self.assertEqual(token.user, 'anotheruser')
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
+ self.assertEqual(token.user, "anotheruser")
def test_standalone_impersonate_user_no_origin(self):
h = handlers.StandaloneAuthHandler()
request = MockRequest(60)
- request.impersonate_user = '@anotheruser'
+ request.impersonate_user = "@anotheruser"
with self.assertRaises(exc.HTTPBadRequest):
h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=('basic', DUMMY_CREDS))
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=("basic", DUMMY_CREDS),
+ )
def test_password_contains_colon(self):
h = handlers.StandaloneAuthHandler()
request = MockRequest(60)
- authorization = ('Basic', base64.b64encode(b'username:password:password'))
+ authorization = ("Basic", base64.b64encode(b"username:password:password"))
token = h.handle_auth(
- request, headers={}, remote_addr=None,
- remote_user=None, authorization=authorization)
- self.assertEqual(token.user, 'username')
+ request,
+ headers={},
+ remote_addr=None,
+ remote_user=None,
+ authorization=authorization,
+ )
+ self.assertEqual(token.user, "username")
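
Not part of the patch: how the ("basic", <base64 bytes>) authorization tuples these tests pass are built. DUMMY_CREDS in the real tests is an equivalent pre-encoded value:

import base64

def make_authorization(username, password):
    # Mirrors a parsed "Authorization: Basic <value>" header: the handler
    # later b64-decodes the value and splits it on the first colon.
    creds = ("%s:%s" % (username, password)).encode("utf-8")
    return ("basic", base64.b64encode(creds))

auth_type, auth_value = make_authorization("auser", "apassword")
assert auth_type == "basic"
assert base64.b64decode(auth_value) == b"auser:apassword"
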
diff --git a/st2auth/tests/unit/test_validation_utils.py b/st2auth/tests/unit/test_validation_utils.py
index 21ab5e26b5..213e106625 100644
--- a/st2auth/tests/unit/test_validation_utils.py
+++ b/st2auth/tests/unit/test_validation_utils.py
@@ -19,9 +19,7 @@
from st2auth.validation import validate_auth_backend_is_correctly_configured
from st2tests import config as tests_config
-__all__ = [
- 'ValidationUtilsTestCase'
-]
+__all__ = ["ValidationUtilsTestCase"]
class ValidationUtilsTestCase(unittest2.TestCase):
@@ -34,22 +32,31 @@ def test_validate_auth_backend_is_correctly_configured_success(self):
self.assertTrue(result)
def test_validate_auth_backend_is_correctly_configured_invalid_backend(self):
- cfg.CONF.set_override(group='auth', name='mode', override='invalid')
- expected_msg = ('Invalid auth mode "invalid" specified in the config. '
- 'Valid modes are: proxy, standalone')
- self.assertRaisesRegexp(ValueError, expected_msg,
- validate_auth_backend_is_correctly_configured)
-
- def test_validate_auth_backend_is_correctly_configured_backend_doesnt_expose_groups(self):
+ cfg.CONF.set_override(group="auth", name="mode", override="invalid")
+ expected_msg = (
+ 'Invalid auth mode "invalid" specified in the config. '
+ "Valid modes are: proxy, standalone"
+ )
+ self.assertRaisesRegexp(
+ ValueError, expected_msg, validate_auth_backend_is_correctly_configured
+ )
+
+ def test_validate_auth_backend_is_correctly_configured_backend_doesnt_expose_groups(
+ self,
+ ):
         # Flat file backend doesn't expose user group membership information, i.e. it
         # doesn't provide the "has group info" capability
- cfg.CONF.set_override(group='auth', name='backend', override='flat_file')
- cfg.CONF.set_override(group='auth', name='backend_kwargs',
- override='{"file_path": "dummy"}')
- cfg.CONF.set_override(group='rbac', name='enable', override=True)
- cfg.CONF.set_override(group='rbac', name='sync_remote_groups', override=True)
-
- expected_msg = ('Configured auth backend doesn\'t expose user group information. Disable '
- 'remote group synchronization or')
- self.assertRaisesRegexp(ValueError, expected_msg,
- validate_auth_backend_is_correctly_configured)
+ cfg.CONF.set_override(group="auth", name="backend", override="flat_file")
+ cfg.CONF.set_override(
+ group="auth", name="backend_kwargs", override='{"file_path": "dummy"}'
+ )
+ cfg.CONF.set_override(group="rbac", name="enable", override=True)
+ cfg.CONF.set_override(group="rbac", name="sync_remote_groups", override=True)
+
+ expected_msg = (
+ "Configured auth backend doesn't expose user group information. Disable "
+ "remote group synchronization or"
+ )
+ self.assertRaisesRegexp(
+ ValueError, expected_msg, validate_auth_backend_is_correctly_configured
+ )
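
Not part of the patch: assertRaisesRegexp is the deprecated spelling of assertRaisesRegex, and both re.search the pattern against the exception text, which is why the tests above only need to match a prefix of the full message:

import unittest

class ExampleTestCase(unittest.TestCase):
    def test_prefix_match(self):
        # re.search semantics: a distinctive prefix of the message is enough.
        with self.assertRaisesRegex(ValueError, "Invalid auth mode"):
            raise ValueError('Invalid auth mode "invalid" specified in the config.')

if __name__ == "__main__":
    unittest.main()
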
diff --git a/st2client/Makefile b/st2client/Makefile
index 9d6cf70a66..e17db7e4f6 100644
--- a/st2client/Makefile
+++ b/st2client/Makefile
@@ -9,7 +9,7 @@ RELEASE=1
COMPONENTS := st2client
.PHONY: rpm
-rpm:
+rpm:
$(PY3) setup.py bdist_rpm --python=$(PY3)
mkdir -p $(RPM_ROOT)/RPMS/noarch
cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm
diff --git a/st2client/dist_utils.py b/st2client/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2client/dist_utils.py
+++ b/st2client/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
     Read the __version__ string from an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
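
Not part of the patch: a runnable sketch of the "#egg=" extraction that _get_link performs for VCS requirement lines; the two regular expressions are the ones from the hunk above:

import re

def egg_name(line):
    # First try "#egg=<name>" followed by "&", "|" or "@" (extra fragments),
    # then fall back to "#egg=<name>" at the end of the line.
    req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
    if not req_name:
        req_name = re.findall(".*#egg=(.+?)$", line)
    else:
        req_name = req_name[0]
    if not req_name:
        raise ValueError('Line "%s" is missing "#egg="' % (line))
    return req_name[0]

assert egg_name("git+https://github.com/StackStorm/st2.git#egg=st2common") == "st2common"
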
diff --git a/st2client/setup.py b/st2client/setup.py
index 916b282301..b318aed359 100644
--- a/st2client/setup.py
+++ b/st2client/setup.py
@@ -26,10 +26,10 @@
check_pip_version()
-ST2_COMPONENT = 'st2client'
+ST2_COMPONENT = "st2client"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
-README_FILE = os.path.join(BASE_DIR, 'README.rst')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
+README_FILE = os.path.join(BASE_DIR, "README.rst")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
apply_vagrant_workaround()
@@ -40,43 +40,41 @@
setup(
name=ST2_COMPONENT,
version=__version__,
- description=('Python client library and CLI for the StackStorm (st2) event-driven '
- 'automation platform.'),
+ description=(
+ "Python client library and CLI for the StackStorm (st2) event-driven "
+ "automation platform."
+ ),
long_description=readme,
- author='StackStorm',
- author_email='info@stackstorm.com',
- url='https://stackstorm.com/',
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ url="https://stackstorm.com/",
classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'Intended Audience :: Information Technology',
- 'Intended Audience :: Developers',
- 'Intended Audience :: System Administrators',
- 'License :: OSI Approved :: Apache Software License',
- 'Operating System :: POSIX :: Linux',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7'
+ "Development Status :: 5 - Production/Stable",
+ "Intended Audience :: Information Technology",
+ "Intended Audience :: Developers",
+ "Intended Audience :: System Administrators",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: POSIX :: Linux",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 2.7",
],
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- entry_points={
- 'console_scripts': [
- 'st2 = st2client.shell:main'
- ]
- },
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ entry_points={"console_scripts": ["st2 = st2client.shell:main"]},
project_urls={
- 'Pack Exchange': 'https://exchange.stackstorm.org',
- 'Repository': 'https://github.com/StackStorm/st2',
- 'Documentation': 'https://docs.stackstorm.com',
- 'Community': 'https://stackstorm.com/community-signup',
- 'Questions': 'https://forum.stackstorm.com/',
- 'Donate': 'https://funding.communitybridge.org/projects/stackstorm',
- 'News/Blog': 'https://stackstorm.com/blog',
- 'Security': 'https://docs.stackstorm.com/latest/security.html',
- 'Bug Reports': 'https://github.com/StackStorm/st2/issues',
- }
+ "Pack Exchange": "https://exchange.stackstorm.org",
+ "Repository": "https://github.com/StackStorm/st2",
+ "Documentation": "https://docs.stackstorm.com",
+ "Community": "https://stackstorm.com/community-signup",
+ "Questions": "https://forum.stackstorm.com/",
+ "Donate": "https://funding.communitybridge.org/projects/stackstorm",
+ "News/Blog": "https://stackstorm.com/blog",
+ "Security": "https://docs.stackstorm.com/latest/security.html",
+ "Bug Reports": "https://github.com/StackStorm/st2/issues",
+ },
)
diff --git a/st2client/st2client/__init__.py b/st2client/st2client/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/st2client/st2client/__init__.py
+++ b/st2client/st2client/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/st2client/st2client/base.py b/st2client/st2client/base.py
index e435540726..54b7a91b14 100644
--- a/st2client/st2client/base.py
+++ b/st2client/st2client/base.py
@@ -38,9 +38,7 @@
from st2client.utils.date import parse as parse_isotime
from st2client.utils.misc import merge_dicts
-__all__ = [
- 'BaseCLIApp'
-]
+__all__ = ["BaseCLIApp"]
# Fix for "os.getlogin()) OSError: [Errno 2] No such file or directory"
os.getlogin = lambda: pwd.getpwuid(os.getuid())[0]
@@ -51,14 +49,14 @@
TOKEN_EXPIRATION_GRACE_PERIOD_SECONDS = 15
CONFIG_OPTION_TO_CLIENT_KWARGS_MAP = {
- 'base_url': ['general', 'base_url'],
- 'auth_url': ['auth', 'url'],
- 'stream_url': ['stream', 'url'],
- 'api_url': ['api', 'url'],
- 'api_version': ['general', 'api_version'],
- 'api_key': ['credentials', 'api_key'],
- 'cacert': ['general', 'cacert'],
- 'debug': ['cli', 'debug']
+ "base_url": ["general", "base_url"],
+ "auth_url": ["auth", "url"],
+ "stream_url": ["stream", "url"],
+ "api_url": ["api", "url"],
+ "api_version": ["general", "api_version"],
+ "api_key": ["credentials", "api_key"],
+ "cacert": ["general", "cacert"],
+ "debug": ["cli", "debug"],
}
@@ -74,7 +72,7 @@ class BaseCLIApp(object):
SKIP_AUTH_CLASSES = []
def get_client(self, args, debug=False):
- ST2_CLI_SKIP_CONFIG = os.environ.get('ST2_CLI_SKIP_CONFIG', 0)
+ ST2_CLI_SKIP_CONFIG = os.environ.get("ST2_CLI_SKIP_CONFIG", 0)
ST2_CLI_SKIP_CONFIG = int(ST2_CLI_SKIP_CONFIG)
skip_config = args.skip_config
@@ -82,12 +80,19 @@ def get_client(self, args, debug=False):
# Note: Options provided as the CLI argument have the highest precedence
# Precedence order: cli arguments > environment variables > rc file variables
- cli_options = ['base_url', 'auth_url', 'api_url', 'stream_url', 'api_version', 'cacert']
+ cli_options = [
+ "base_url",
+ "auth_url",
+ "api_url",
+ "stream_url",
+ "api_version",
+ "cacert",
+ ]
cli_options = {opt: getattr(args, opt, None) for opt in cli_options}
if cli_options.get("cacert", None) is not None:
- if cli_options["cacert"].lower() in ['true', '1', 't', 'y', 'yes']:
+ if cli_options["cacert"].lower() in ["true", "1", "t", "y", "yes"]:
cli_options["cacert"] = True
- elif cli_options["cacert"].lower() in ['false', '0', 'f', 'no']:
+ elif cli_options["cacert"].lower() in ["false", "0", "f", "no"]:
cli_options["cacert"] = False
config_file_options = self._get_config_file_options(args=args)
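
Not part of the patch: a sketch of how the precedence chain in the comment above (cli arguments > environment variables > rc file variables) falls out of merging in that order. merge_non_none is a stand-in; the exact None-handling of st2client.utils.misc.merge_dicts is an assumption here:

def merge_non_none(base, override):
    # Later merges win; None means "not provided" and leaves base untouched.
    for key, value in override.items():
        if value is not None:
            base[key] = value
    return base

kwargs = {}
kwargs = merge_non_none(kwargs, {"api_url": "http://from-rc-file:9101"})
kwargs = merge_non_none(kwargs, {"api_url": "http://from-cli:9101", "cacert": None})
assert kwargs == {"api_url": "http://from-cli:9101"}
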
@@ -98,20 +103,22 @@ def get_client(self, args, debug=False):
kwargs = merge_dicts(kwargs, config_file_options)
kwargs = merge_dicts(kwargs, cli_options)
- kwargs['debug'] = debug
+ kwargs["debug"] = debug
client = Client(**kwargs)
if skip_config:
# Config parsing is skipped
- self.LOG.info('Skipping parsing CLI config')
+ self.LOG.info("Skipping parsing CLI config")
return client
# Ok to use config at this point
rc_config = get_config()
# Silence SSL warnings
- silence_ssl_warnings = rc_config.get('general', {}).get('silence_ssl_warnings', False)
+ silence_ssl_warnings = rc_config.get("general", {}).get(
+ "silence_ssl_warnings", False
+ )
if silence_ssl_warnings:
# pylint: disable=no-member
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
@@ -127,34 +134,45 @@ def get_client(self, args, debug=False):
# We also skip automatic authentication if token is provided via the environment variable
# or as a command line argument
- env_var_token = os.environ.get('ST2_AUTH_TOKEN', None)
- cli_argument_token = getattr(args, 'token', None)
- env_var_api_key = os.environ.get('ST2_API_KEY', None)
- cli_argument_api_key = getattr(args, 'api_key', None)
- if env_var_token or cli_argument_token or env_var_api_key or cli_argument_api_key:
+ env_var_token = os.environ.get("ST2_AUTH_TOKEN", None)
+ cli_argument_token = getattr(args, "token", None)
+ env_var_api_key = os.environ.get("ST2_API_KEY", None)
+ cli_argument_api_key = getattr(args, "api_key", None)
+ if (
+ env_var_token
+ or cli_argument_token
+ or env_var_api_key
+ or cli_argument_api_key
+ ):
return client
# If credentials are provided in the CLI config use them and try to authenticate
- credentials = rc_config.get('credentials', {})
- username = credentials.get('username', None)
- password = credentials.get('password', None)
- cache_token = rc_config.get('cli', {}).get('cache_token', False)
+ credentials = rc_config.get("credentials", {})
+ username = credentials.get("username", None)
+ password = credentials.get("password", None)
+ cache_token = rc_config.get("cli", {}).get("cache_token", False)
if username:
             # Credentials are provided, try to authenticate against the API
try:
- token = self._get_auth_token(client=client, username=username, password=password,
- cache_token=cache_token)
+ token = self._get_auth_token(
+ client=client,
+ username=username,
+ password=password,
+ cache_token=cache_token,
+ )
except requests.exceptions.ConnectionError as e:
- self.LOG.warn('Auth API server is not available, skipping authentication.')
+ self.LOG.warn(
+ "Auth API server is not available, skipping authentication."
+ )
self.LOG.exception(e)
return client
except Exception as e:
- print('Failed to authenticate with credentials provided in the config.')
+ print("Failed to authenticate with credentials provided in the config.")
raise e
client.token = token
# TODO: Hack, refactor when splitting out the client
- os.environ['ST2_AUTH_TOKEN'] = token
+ os.environ["ST2_AUTH_TOKEN"] = token
return client
@@ -166,9 +184,12 @@ def _get_config_file_options(self, args, validate_config_permissions=False):
:rtype: ``dict``
"""
rc_options = self._parse_config_file(
- args=args, validate_config_permissions=validate_config_permissions)
+ args=args, validate_config_permissions=validate_config_permissions
+ )
result = {}
- for kwarg_name, (section, option) in six.iteritems(CONFIG_OPTION_TO_CLIENT_KWARGS_MAP):
+ for kwarg_name, (section, option) in six.iteritems(
+ CONFIG_OPTION_TO_CLIENT_KWARGS_MAP
+ ):
result[kwarg_name] = rc_options.get(section, {}).get(option, None)
return result
@@ -176,10 +197,12 @@ def _get_config_file_options(self, args, validate_config_permissions=False):
def _parse_config_file(self, args, validate_config_permissions=False):
config_file_path = self._get_config_file_path(args=args)
- parser = CLIConfigParser(config_file_path=config_file_path,
- validate_config_exists=False,
- validate_config_permissions=validate_config_permissions,
- log=self.LOG)
+ parser = CLIConfigParser(
+ config_file_path=config_file_path,
+ validate_config_exists=False,
+ validate_config_permissions=validate_config_permissions,
+ log=self.LOG,
+ )
result = parser.parse()
return result
@@ -189,7 +212,7 @@ def _get_config_file_path(self, args):
:rtype: ``str``
"""
- path = os.environ.get('ST2_CONFIG_FILE', ST2_CONFIG_PATH)
+ path = os.environ.get("ST2_CONFIG_FILE", ST2_CONFIG_PATH)
if args.config_file:
path = args.config_file
@@ -212,15 +235,16 @@ def _get_auth_token(self, client, username, password, cache_token):
:rtype: ``str``
"""
if cache_token:
- token = self._get_cached_auth_token(client=client, username=username,
- password=password)
+ token = self._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
else:
token = None
if not token:
# Token is either expired or not available
- token_obj = self._authenticate_and_retrieve_auth_token(client=client,
- username=username,
- password=password)
+ token_obj = self._authenticate_and_retrieve_auth_token(
+ client=client, username=username, password=password
+ )
self._cache_auth_token(token_obj=token_obj)
token = token_obj.token
@@ -243,10 +267,12 @@ def _get_cached_auth_token(self, client, username, password):
if not os.access(ST2_CONFIG_DIRECTORY, os.R_OK):
# We don't have read access to the file with a cached token
- message = ('Unable to retrieve cached token from "%s" (user %s doesn\'t have read '
- 'access to the parent directory). Subsequent requests won\'t use a '
- 'cached token meaning they may be slower.' % (cached_token_path,
- os.getlogin()))
+ message = (
+ 'Unable to retrieve cached token from "%s" (user %s doesn\'t have read '
+ "access to the parent directory). Subsequent requests won't use a "
+ "cached token meaning they may be slower."
+ % (cached_token_path, os.getlogin())
+ )
self.LOG.warn(message)
return None
@@ -255,9 +281,11 @@ def _get_cached_auth_token(self, client, username, password):
if not os.access(cached_token_path, os.R_OK):
# We don't have read access to the file with a cached token
- message = ('Unable to retrieve cached token from "%s" (user %s doesn\'t have read '
- 'access to this file). Subsequent requests won\'t use a cached token '
- 'meaning they may be slower.' % (cached_token_path, os.getlogin()))
+ message = (
+ 'Unable to retrieve cached token from "%s" (user %s doesn\'t have read '
+ "access to this file). Subsequent requests won't use a cached token "
+ "meaning they may be slower." % (cached_token_path, os.getlogin())
+ )
self.LOG.warn(message)
return None
@@ -267,9 +295,11 @@ def _get_cached_auth_token(self, client, username, password):
if others_st_mode >= 2:
# Every user has access to this file which is dangerous
- message = ('Permissions (%s) for cached token file "%s" are too permissive. Please '
- 'restrict the permissions and make sure only your own user can read '
- 'from or write to the file.' % (file_st_mode, cached_token_path))
+ message = (
+ 'Permissions (%s) for cached token file "%s" are too permissive. Please '
+ "restrict the permissions and make sure only your own user can read "
+ "from or write to the file." % (file_st_mode, cached_token_path)
+ )
self.LOG.warn(message)
with open(cached_token_path) as fp:
@@ -278,16 +308,20 @@ def _get_cached_auth_token(self, client, username, password):
try:
data = json.loads(data)
- token = data['token']
- expire_timestamp = data['expire_timestamp']
+ token = data["token"]
+ expire_timestamp = data["expire_timestamp"]
except Exception as e:
- msg = ('File "%s" with cached token is corrupted or invalid (%s). Please delete '
- ' this file' % (cached_token_path, six.text_type(e)))
+ msg = (
+ 'File "%s" with cached token is corrupted or invalid (%s). Please delete '
+                "this file" % (cached_token_path, six.text_type(e))
+ )
raise ValueError(msg)
now = int(time.time())
if (expire_timestamp - TOKEN_EXPIRATION_GRACE_PERIOD_SECONDS) < now:
- self.LOG.debug('Cached token from file "%s" has expired' % (cached_token_path))
+ self.LOG.debug(
+ 'Cached token from file "%s" has expired' % (cached_token_path)
+ )
# Token has expired
return None
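
The grace period makes a token count as expired slightly before its real expiry time, so a request issued with an almost-expired token doesn't fail in flight. The check reduces to:

    import time

    TOKEN_EXPIRATION_GRACE_PERIOD_SECONDS = 15  # illustrative value

    def is_cached_token_expired(expire_timestamp, now=None):
        now = int(time.time()) if now is None else now
        return (expire_timestamp - TOKEN_EXPIRATION_GRACE_PERIOD_SECONDS) < now

    # A token expiring 10 seconds from now is already treated as expired.
    assert is_cached_token_expired(int(time.time()) + 10)
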
@@ -312,19 +346,25 @@ def _cache_auth_token(self, token_obj):
if not os.access(ST2_CONFIG_DIRECTORY, os.W_OK):
# We don't have write access to the file with a cached token
- message = ('Unable to write token to "%s" (user %s doesn\'t have write '
- 'access to the parent directory). Subsequent requests won\'t use a '
- 'cached token meaning they may be slower.' % (cached_token_path,
- os.getlogin()))
+ message = (
+ 'Unable to write token to "%s" (user %s doesn\'t have write '
+ "access to the parent directory). Subsequent requests won't use a "
+ "cached token meaning they may be slower."
+ % (cached_token_path, os.getlogin())
+ )
self.LOG.warn(message)
return None
- if os.path.isfile(cached_token_path) and not os.access(cached_token_path, os.W_OK):
+ if os.path.isfile(cached_token_path) and not os.access(
+ cached_token_path, os.W_OK
+ ):
# We don't have write access to the file with a cached token
- message = ('Unable to write token to "%s" (user %s doesn\'t have write '
- 'access to this file). Subsequent requests won\'t use a '
- 'cached token meaning they may be slower.' % (cached_token_path,
- os.getlogin()))
+ message = (
+ 'Unable to write token to "%s" (user %s doesn\'t have write '
+ "access to this file). Subsequent requests won't use a "
+ "cached token meaning they may be slower."
+ % (cached_token_path, os.getlogin())
+ )
self.LOG.warn(message)
return None
@@ -333,8 +373,8 @@ def _cache_auth_token(self, token_obj):
expire_timestamp = calendar.timegm(expire_timestamp.timetuple())
data = {}
- data['token'] = token
- data['expire_timestamp'] = expire_timestamp
+ data["token"] = token
+ data["expire_timestamp"] = expire_timestamp
data = json.dumps(data)
        # Note: We explicitly use fdopen instead of open + chmod to avoid a security issue.
@@ -342,7 +382,7 @@ def _cache_auth_token(self, token_obj):
# open and chmod) when file can potentially be read by other users if the default
# permissions used during create allow that.
fd = os.open(cached_token_path, os.O_WRONLY | os.O_CREAT, 0o660)
- with os.fdopen(fd, 'w') as fp:
+ with os.fdopen(fd, "w") as fp:
fp.write(data)
os.chmod(cached_token_path, 0o660)
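
The os.open + os.fdopen combination above avoids the race where a plain open() creates the file with default permissions and a later chmod() tightens them only after other users may already have read it: here the restrictive mode is applied atomically at creation. A reusable sketch of the pattern (O_TRUNC added so rewrites don't leave stale bytes):

    import os

    def write_private_file(path, data, mode=0o600):
        # Permissions are set when the file is created, so there is no
        # window in which the file is world-readable.
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
        with os.fdopen(fd, "w") as fp:
            fp.write(data)
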
@@ -350,8 +390,12 @@ def _cache_auth_token(self, token_obj):
return True
def _authenticate_and_retrieve_auth_token(self, client, username, password):
- manager = models.ResourceManager(models.Token, client.endpoints['auth'],
- cacert=client.cacert, debug=client.debug)
+ manager = models.ResourceManager(
+ models.Token,
+ client.endpoints["auth"],
+ cacert=client.cacert,
+ debug=client.debug,
+ )
instance = models.Token()
instance = manager.create(instance, auth=(username, password))
return instance
@@ -360,7 +404,7 @@ def _get_cached_token_path_for_user(self, username):
"""
Retrieve cached token path for the provided username.
"""
- file_name = 'token-%s' % (username)
+ file_name = "token-%s" % (username)
result = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, file_name))
return result
@@ -368,10 +412,10 @@ def _print_config(self, args):
config = self._parse_config_file(args=args, validate_config_permissions=False)
for section, options in six.iteritems(config):
- print('[%s]' % (section))
+ print("[%s]" % (section))
for name, value in six.iteritems(options):
- print('%s = %s' % (name, value))
+ print("%s = %s" % (name, value))
def _print_debug_info(self, args):
# Print client settings
@@ -388,19 +432,19 @@ def _print_client_settings(self, args):
config_file_path = self._get_config_file_path(args=args)
- print('CLI settings:')
- print('----------------')
- print('Config file path: %s' % (config_file_path))
- print('Client settings:')
- print('----------------')
- print('ST2_BASE_URL: %s' % (client.endpoints['base']))
- print('ST2_AUTH_URL: %s' % (client.endpoints['auth']))
- print('ST2_API_URL: %s' % (client.endpoints['api']))
- print('ST2_STREAM_URL: %s' % (client.endpoints['stream']))
- print('ST2_AUTH_TOKEN: %s' % (os.environ.get('ST2_AUTH_TOKEN')))
- print('')
- print('Proxy settings:')
- print('---------------')
- print('HTTP_PROXY: %s' % (os.environ.get('HTTP_PROXY', '')))
- print('HTTPS_PROXY: %s' % (os.environ.get('HTTPS_PROXY', '')))
- print('')
+ print("CLI settings:")
+ print("----------------")
+ print("Config file path: %s" % (config_file_path))
+ print("Client settings:")
+ print("----------------")
+ print("ST2_BASE_URL: %s" % (client.endpoints["base"]))
+ print("ST2_AUTH_URL: %s" % (client.endpoints["auth"]))
+ print("ST2_API_URL: %s" % (client.endpoints["api"]))
+ print("ST2_STREAM_URL: %s" % (client.endpoints["stream"]))
+ print("ST2_AUTH_TOKEN: %s" % (os.environ.get("ST2_AUTH_TOKEN")))
+ print("")
+ print("Proxy settings:")
+ print("---------------")
+ print("HTTP_PROXY: %s" % (os.environ.get("HTTP_PROXY", "")))
+ print("HTTPS_PROXY: %s" % (os.environ.get("HTTPS_PROXY", "")))
+ print("")
diff --git a/st2client/st2client/client.py b/st2client/st2client/client.py
index 6bda37942b..9772c825b7 100644
--- a/st2client/st2client/client.py
+++ b/st2client/st2client/client.py
@@ -47,144 +47,224 @@
DEFAULT_AUTH_PORT = 9100
DEFAULT_STREAM_PORT = 9102
-DEFAULT_BASE_URL = 'http://127.0.0.1'
-DEFAULT_API_VERSION = 'v1'
+DEFAULT_BASE_URL = "http://127.0.0.1"
+DEFAULT_API_VERSION = "v1"
class Client(object):
- def __init__(self, base_url=None, auth_url=None, api_url=None, stream_url=None,
- api_version=None, cacert=None, debug=False, token=None, api_key=None):
+ def __init__(
+ self,
+ base_url=None,
+ auth_url=None,
+ api_url=None,
+ stream_url=None,
+ api_version=None,
+ cacert=None,
+ debug=False,
+ token=None,
+ api_key=None,
+ ):
        # Get CLI options. If not given, then try to get them from the environment.
self.endpoints = dict()
# Populate the endpoints
if base_url:
- self.endpoints['base'] = base_url
+ self.endpoints["base"] = base_url
else:
- self.endpoints['base'] = os.environ.get('ST2_BASE_URL', DEFAULT_BASE_URL)
+ self.endpoints["base"] = os.environ.get("ST2_BASE_URL", DEFAULT_BASE_URL)
- api_version = api_version or os.environ.get('ST2_API_VERSION', DEFAULT_API_VERSION)
+ api_version = api_version or os.environ.get(
+ "ST2_API_VERSION", DEFAULT_API_VERSION
+ )
- self.endpoints['exp'] = '%s:%s/%s' % (self.endpoints['base'], DEFAULT_API_PORT, 'exp')
+ self.endpoints["exp"] = "%s:%s/%s" % (
+ self.endpoints["base"],
+ DEFAULT_API_PORT,
+ "exp",
+ )
if api_url:
- self.endpoints['api'] = api_url
+ self.endpoints["api"] = api_url
else:
- self.endpoints['api'] = os.environ.get(
- 'ST2_API_URL', '%s:%s/%s' % (self.endpoints['base'], DEFAULT_API_PORT, api_version))
+ self.endpoints["api"] = os.environ.get(
+ "ST2_API_URL",
+ "%s:%s/%s" % (self.endpoints["base"], DEFAULT_API_PORT, api_version),
+ )
if auth_url:
- self.endpoints['auth'] = auth_url
+ self.endpoints["auth"] = auth_url
else:
- self.endpoints['auth'] = os.environ.get(
- 'ST2_AUTH_URL', '%s:%s' % (self.endpoints['base'], DEFAULT_AUTH_PORT))
+ self.endpoints["auth"] = os.environ.get(
+ "ST2_AUTH_URL", "%s:%s" % (self.endpoints["base"], DEFAULT_AUTH_PORT)
+ )
if stream_url:
- self.endpoints['stream'] = stream_url
+ self.endpoints["stream"] = stream_url
else:
- self.endpoints['stream'] = os.environ.get(
- 'ST2_STREAM_URL',
- '%s:%s/%s' % (
- self.endpoints['base'],
- DEFAULT_STREAM_PORT,
- api_version
- )
+ self.endpoints["stream"] = os.environ.get(
+ "ST2_STREAM_URL",
+ "%s:%s/%s" % (self.endpoints["base"], DEFAULT_STREAM_PORT, api_version),
)
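
Each endpoint above is resolved with the same three-step fallback: explicit constructor argument, then an ST2_*_URL environment variable, then a default derived from the base URL plus a well-known port. A compact sketch of that chain (the port value in the comment is only an example):

    import os

    def resolve_endpoint(explicit, env_var, base, port, suffix=""):
        if explicit:
            return explicit
        default = "%s:%s" % (base, port)
        if suffix:
            default = "%s/%s" % (default, suffix)
        return os.environ.get(env_var, default)

    # resolve_endpoint(None, "ST2_API_URL", "http://127.0.0.1", 9101, "v1")
    # -> "http://127.0.0.1:9101/v1" unless ST2_API_URL is set.
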
if cacert is not None:
self.cacert = cacert
else:
- self.cacert = os.environ.get('ST2_CACERT', None)
+ self.cacert = os.environ.get("ST2_CACERT", None)
# Note: boolean is also a valid value for "cacert"
is_cacert_string = isinstance(self.cacert, six.string_types)
- if (self.cacert and is_cacert_string and not os.path.isfile(self.cacert)):
+ if self.cacert and is_cacert_string and not os.path.isfile(self.cacert):
raise ValueError('CA cert file "%s" does not exist.' % (self.cacert))
self.debug = debug
        # Note: This is a nasty hack for now, but we need to get rid of the decorator abuse
if token:
- os.environ['ST2_AUTH_TOKEN'] = token
+ os.environ["ST2_AUTH_TOKEN"] = token
self.token = token
if api_key:
- os.environ['ST2_API_KEY'] = api_key
+ os.environ["ST2_API_KEY"] = api_key
self.api_key = api_key
# Instantiate resource managers and assign appropriate API endpoint.
self.managers = dict()
- self.managers['Token'] = ResourceManager(
- models.Token, self.endpoints['auth'], cacert=self.cacert, debug=self.debug)
- self.managers['RunnerType'] = ResourceManager(
- models.RunnerType, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Action'] = ActionResourceManager(
- models.Action, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['ActionAlias'] = ActionAliasResourceManager(
- models.ActionAlias, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['ActionAliasExecution'] = ActionAliasExecutionManager(
- models.ActionAliasExecution, self.endpoints['api'],
- cacert=self.cacert, debug=self.debug)
- self.managers['ApiKey'] = ResourceManager(
- models.ApiKey, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Config'] = ConfigManager(
- models.Config, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['ConfigSchema'] = ResourceManager(
- models.ConfigSchema, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Execution'] = ExecutionResourceManager(
- models.Execution, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
+ self.managers["Token"] = ResourceManager(
+ models.Token, self.endpoints["auth"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["RunnerType"] = ResourceManager(
+ models.RunnerType,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["Action"] = ActionResourceManager(
+ models.Action, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["ActionAlias"] = ActionAliasResourceManager(
+ models.ActionAlias,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["ActionAliasExecution"] = ActionAliasExecutionManager(
+ models.ActionAliasExecution,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["ApiKey"] = ResourceManager(
+ models.ApiKey, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["Config"] = ConfigManager(
+ models.Config, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["ConfigSchema"] = ResourceManager(
+ models.ConfigSchema,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["Execution"] = ExecutionResourceManager(
+ models.Execution,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
# NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for
# backward compatibility reasons until v3.2.0
- self.managers['LiveAction'] = self.managers['Execution']
- self.managers['Inquiry'] = InquiryResourceManager(
- models.Inquiry, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Pack'] = PackResourceManager(
- models.Pack, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Policy'] = ResourceManager(
- models.Policy, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['PolicyType'] = ResourceManager(
- models.PolicyType, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Rule'] = ResourceManager(
- models.Rule, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Sensor'] = ResourceManager(
- models.Sensor, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['TriggerType'] = ResourceManager(
- models.TriggerType, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Trigger'] = ResourceManager(
- models.Trigger, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['TriggerInstance'] = TriggerInstanceResourceManager(
- models.TriggerInstance, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['KeyValuePair'] = ResourceManager(
- models.KeyValuePair, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Webhook'] = WebhookManager(
- models.Webhook, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Timer'] = ResourceManager(
- models.Timer, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Trace'] = ResourceManager(
- models.Trace, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['RuleEnforcement'] = ResourceManager(
- models.RuleEnforcement, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['Stream'] = StreamManager(
- self.endpoints['stream'], cacert=self.cacert, debug=self.debug)
- self.managers['Workflow'] = WorkflowManager(
- self.endpoints['api'], cacert=self.cacert, debug=self.debug)
+ self.managers["LiveAction"] = self.managers["Execution"]
+ self.managers["Inquiry"] = InquiryResourceManager(
+ models.Inquiry, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["Pack"] = PackResourceManager(
+ models.Pack, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["Policy"] = ResourceManager(
+ models.Policy, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["PolicyType"] = ResourceManager(
+ models.PolicyType,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["Rule"] = ResourceManager(
+ models.Rule, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["Sensor"] = ResourceManager(
+ models.Sensor, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["TriggerType"] = ResourceManager(
+ models.TriggerType,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["Trigger"] = ResourceManager(
+ models.Trigger, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["TriggerInstance"] = TriggerInstanceResourceManager(
+ models.TriggerInstance,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["KeyValuePair"] = ResourceManager(
+ models.KeyValuePair,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["Webhook"] = WebhookManager(
+ models.Webhook, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["Timer"] = ResourceManager(
+ models.Timer, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["Trace"] = ResourceManager(
+ models.Trace, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["RuleEnforcement"] = ResourceManager(
+ models.RuleEnforcement,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+ self.managers["Stream"] = StreamManager(
+ self.endpoints["stream"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["Workflow"] = WorkflowManager(
+ self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
# Service Registry
- self.managers['ServiceRegistryGroups'] = ServiceRegistryGroupsManager(
- models.ServiceRegistryGroup, self.endpoints['api'], cacert=self.cacert,
- debug=self.debug)
-
- self.managers['ServiceRegistryMembers'] = ServiceRegistryMembersManager(
- models.ServiceRegistryMember, self.endpoints['api'], cacert=self.cacert,
- debug=self.debug)
+ self.managers["ServiceRegistryGroups"] = ServiceRegistryGroupsManager(
+ models.ServiceRegistryGroup,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
+
+ self.managers["ServiceRegistryMembers"] = ServiceRegistryMembersManager(
+ models.ServiceRegistryMember,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
# RBAC
- self.managers['Role'] = ResourceManager(
- models.Role, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
- self.managers['UserRoleAssignment'] = ResourceManager(
- models.UserRoleAssignment, self.endpoints['api'], cacert=self.cacert, debug=self.debug)
+ self.managers["Role"] = ResourceManager(
+ models.Role, self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
+ self.managers["UserRoleAssignment"] = ResourceManager(
+ models.UserRoleAssignment,
+ self.endpoints["api"],
+ cacert=self.cacert,
+ debug=self.debug,
+ )
@add_auth_token_to_kwargs_from_env
def get_user_info(self, **kwargs):
@@ -193,9 +273,10 @@ def get_user_info(self, **kwargs):
:rtype: ``dict``
"""
- url = '/user'
- client = httpclient.HTTPClient(root=self.endpoints['api'], cacert=self.cacert,
- debug=self.debug)
+ url = "/user"
+ client = httpclient.HTTPClient(
+ root=self.endpoints["api"], cacert=self.cacert, debug=self.debug
+ )
response = client.get(url=url, **kwargs)
if response.status_code != 200:
@@ -205,80 +286,85 @@ def get_user_info(self, **kwargs):
@property
def actions(self):
- return self.managers['Action']
+ return self.managers["Action"]
@property
def apikeys(self):
- return self.managers['ApiKey']
+ return self.managers["ApiKey"]
@property
def keys(self):
- return self.managers['KeyValuePair']
+ return self.managers["KeyValuePair"]
@property
def executions(self):
- return self.managers['Execution']
+ return self.managers["Execution"]
# NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for
# backward compatibility reasons until v3.2.0
@property
def liveactions(self):
- warnings.warn(('st2client.liveactions has been renamed to st2client.executions, please '
- 'update your code'), DeprecationWarning)
+ warnings.warn(
+ (
+ "st2client.liveactions has been renamed to st2client.executions, please "
+ "update your code"
+ ),
+ DeprecationWarning,
+ )
return self.executions
@property
def inquiries(self):
- return self.managers['Inquiry']
+ return self.managers["Inquiry"]
@property
def packs(self):
- return self.managers['Pack']
+ return self.managers["Pack"]
@property
def policies(self):
- return self.managers['Policy']
+ return self.managers["Policy"]
@property
def policytypes(self):
- return self.managers['PolicyType']
+ return self.managers["PolicyType"]
@property
def rules(self):
- return self.managers['Rule']
+ return self.managers["Rule"]
@property
def runners(self):
- return self.managers['RunnerType']
+ return self.managers["RunnerType"]
@property
def sensors(self):
- return self.managers['Sensor']
+ return self.managers["Sensor"]
@property
def tokens(self):
- return self.managers['Token']
+ return self.managers["Token"]
@property
def triggertypes(self):
- return self.managers['TriggerType']
+ return self.managers["TriggerType"]
@property
def triggerinstances(self):
- return self.managers['TriggerInstance']
+ return self.managers["TriggerInstance"]
@property
def trace(self):
- return self.managers['Trace']
+ return self.managers["Trace"]
@property
def ruleenforcements(self):
- return self.managers['RuleEnforcement']
+ return self.managers["RuleEnforcement"]
@property
def webhooks(self):
- return self.managers['Webhook']
+ return self.managers["Webhook"]
@property
def workflows(self):
- return self.managers['Workflow']
+ return self.managers["Workflow"]
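
Since every property above is a thin alias into self.managers, both access styles reach the same manager instance. A usage sketch, assuming st2client is installed and an API is reachable at the placeholder endpoint:

    from st2client.client import Client

    client = Client(base_url="http://127.0.0.1")  # placeholder endpoint
    # Property access and dict access are equivalent:
    assert client.actions is client.managers["Action"]
    # Requires a live API to succeed:
    action = client.actions.get_by_ref_or_id("core.local")
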
diff --git a/st2client/st2client/commands/__init__.py b/st2client/st2client/commands/__init__.py
index a9b9cee86b..995d3fd9d3 100644
--- a/st2client/st2client/commands/__init__.py
+++ b/st2client/st2client/commands/__init__.py
@@ -35,9 +35,9 @@ def __init__(self, name, description, app, subparsers, parent_parser=None):
self.description = description
self.app = app
self.parent_parser = parent_parser
- self.parser = subparsers.add_parser(self.name,
- description=self.description,
- help=self.description)
+ self.parser = subparsers.add_parser(
+ self.name, description=self.description, help=self.description
+ )
self.commands = dict()
@@ -45,16 +45,19 @@ def __init__(self, name, description, app, subparsers, parent_parser=None):
class Command(object):
"""Represents a commandlet in the command tree."""
- def __init__(self, name, description, app, subparsers,
- parent_parser=None, add_help=True):
+ def __init__(
+ self, name, description, app, subparsers, parent_parser=None, add_help=True
+ ):
self.name = name
self.description = description
self.app = app
self.parent_parser = parent_parser
- self.parser = subparsers.add_parser(self.name,
- description=self.description,
- help=self.description,
- add_help=add_help)
+ self.parser = subparsers.add_parser(
+ self.name,
+ description=self.description,
+ help=self.description,
+ add_help=add_help,
+ )
self.parser.set_defaults(func=self.run_and_print)
@abc.abstractmethod
@@ -74,8 +77,8 @@ def run_and_print(self, args, **kwargs):
raise NotImplementedError
def format_output(self, subject, formatter, *args, **kwargs):
- json = kwargs.get('json', False)
- yaml = kwargs.get('yaml', False)
+ json = kwargs.get("json", False)
+ yaml = kwargs.get("yaml", False)
if json:
func = doc.JsonFormatter.format
@@ -90,4 +93,4 @@ def print_output(self, subject, formatter, *args, **kwargs):
output = self.format_output(subject, formatter, *args, **kwargs)
print(output)
else:
- print('No matching items found')
+ print("No matching items found")
diff --git a/st2client/st2client/commands/action.py b/st2client/st2client/commands/action.py
index 7a41d9e2eb..dcf76c3a7d 100644
--- a/st2client/st2client/commands/action.py
+++ b/st2client/st2client/commands/action.py
@@ -44,60 +44,54 @@
LOG = logging.getLogger(__name__)
-LIVEACTION_STATUS_REQUESTED = 'requested'
-LIVEACTION_STATUS_SCHEDULED = 'scheduled'
-LIVEACTION_STATUS_DELAYED = 'delayed'
-LIVEACTION_STATUS_RUNNING = 'running'
-LIVEACTION_STATUS_SUCCEEDED = 'succeeded'
-LIVEACTION_STATUS_FAILED = 'failed'
-LIVEACTION_STATUS_TIMED_OUT = 'timeout'
-LIVEACTION_STATUS_ABANDONED = 'abandoned'
-LIVEACTION_STATUS_CANCELING = 'canceling'
-LIVEACTION_STATUS_CANCELED = 'canceled'
-LIVEACTION_STATUS_PAUSING = 'pausing'
-LIVEACTION_STATUS_PAUSED = 'paused'
-LIVEACTION_STATUS_RESUMING = 'resuming'
+LIVEACTION_STATUS_REQUESTED = "requested"
+LIVEACTION_STATUS_SCHEDULED = "scheduled"
+LIVEACTION_STATUS_DELAYED = "delayed"
+LIVEACTION_STATUS_RUNNING = "running"
+LIVEACTION_STATUS_SUCCEEDED = "succeeded"
+LIVEACTION_STATUS_FAILED = "failed"
+LIVEACTION_STATUS_TIMED_OUT = "timeout"
+LIVEACTION_STATUS_ABANDONED = "abandoned"
+LIVEACTION_STATUS_CANCELING = "canceling"
+LIVEACTION_STATUS_CANCELED = "canceled"
+LIVEACTION_STATUS_PAUSING = "pausing"
+LIVEACTION_STATUS_PAUSED = "paused"
+LIVEACTION_STATUS_RESUMING = "resuming"
LIVEACTION_COMPLETED_STATES = [
LIVEACTION_STATUS_SUCCEEDED,
LIVEACTION_STATUS_FAILED,
LIVEACTION_STATUS_TIMED_OUT,
LIVEACTION_STATUS_CANCELED,
- LIVEACTION_STATUS_ABANDONED
+ LIVEACTION_STATUS_ABANDONED,
]
# Which parameters should be masked when displaying action execution output
-PARAMETERS_TO_MASK = [
- 'password',
- 'private_key'
-]
+PARAMETERS_TO_MASK = ["password", "private_key"]
# A list of environment variables which are never inherited when using run
# --inherit-env flag
ENV_VARS_BLACKLIST = [
- 'pwd',
- 'mail',
- 'username',
- 'user',
- 'path',
- 'home',
- 'ps1',
- 'shell',
- 'pythonpath',
- 'ssh_tty',
- 'ssh_connection',
- 'lang',
- 'ls_colors',
- 'logname',
- 'oldpwd',
- 'term',
- 'xdg_session_id'
+ "pwd",
+ "mail",
+ "username",
+ "user",
+ "path",
+ "home",
+ "ps1",
+ "shell",
+ "pythonpath",
+ "ssh_tty",
+ "ssh_connection",
+ "lang",
+ "ls_colors",
+ "logname",
+ "oldpwd",
+ "term",
+ "xdg_session_id",
]
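
This blacklist backs the run --inherit-env flag: shell-specific and potentially sensitive variables are never forwarded to the action. The filtering helper itself (_get_inherited_env_vars) is not shown in this diff; a plausible sketch, given that the blacklist entries are stored lowercase:

    import os

    def inherited_env_vars(blacklist):
        # Compare case-insensitively since the blacklist is lowercase.
        return {
            name: value
            for name, value in os.environ.items()
            if name.lower() not in blacklist
        }

    # safe_env = inherited_env_vars(ENV_VARS_BLACKLIST)
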
-WORKFLOW_RUNNER_TYPES = [
- 'action-chain',
- 'orquesta'
-]
+WORKFLOW_RUNNER_TYPES = ["action-chain", "orquesta"]
def format_parameters(value):
@@ -108,15 +102,15 @@ def format_parameters(value):
for param_name, _ in value.items():
if param_name in PARAMETERS_TO_MASK:
- value[param_name] = '********'
+ value[param_name] = "********"
return value
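
In effect, an execution's parameters are scrubbed in place before display, with only the keys listed in PARAMETERS_TO_MASK replaced:

    params = {"host": "db01", "password": "hunter2"}
    assert format_parameters(params) == {"host": "db01", "password": "********"}
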
# String for indenting etc.
-WF_PREFIX = '+ '
-NON_WF_PREFIX = ' '
-INDENT_CHAR = ' '
+WF_PREFIX = "+ "
+NON_WF_PREFIX = " "
+INDENT_CHAR = " "
def format_wf_instances(instances):
@@ -127,7 +121,7 @@ def format_wf_instances(instances):
    # only add extra chars if there are workflows.
has_wf = False
for instance in instances:
- if not getattr(instance, 'children', None):
+ if not getattr(instance, "children", None):
continue
else:
has_wf = True
@@ -136,7 +130,7 @@ def format_wf_instances(instances):
return instances
# Prepend wf and non_wf prefixes.
for instance in instances:
- if getattr(instance, 'children', None):
+ if getattr(instance, "children", None):
instance.id = WF_PREFIX + instance.id
else:
instance.id = NON_WF_PREFIX + instance.id
@@ -158,59 +152,75 @@ def format_execution_status(instance):
executions which are in running state and execution total run time for all the executions
which have finished.
"""
- status = getattr(instance, 'status', None)
- start_timestamp = getattr(instance, 'start_timestamp', None)
- end_timestamp = getattr(instance, 'end_timestamp', None)
+ status = getattr(instance, "status", None)
+ start_timestamp = getattr(instance, "start_timestamp", None)
+ end_timestamp = getattr(instance, "end_timestamp", None)
if status == LIVEACTION_STATUS_RUNNING and start_timestamp:
start_timestamp = instance.start_timestamp
start_timestamp = parse_isotime(start_timestamp)
start_timestamp = calendar.timegm(start_timestamp.timetuple())
now = int(time.time())
- elapsed_seconds = (now - start_timestamp)
- instance.status = '%s (%ss elapsed)' % (instance.status, elapsed_seconds)
+ elapsed_seconds = now - start_timestamp
+ instance.status = "%s (%ss elapsed)" % (instance.status, elapsed_seconds)
elif status in LIVEACTION_COMPLETED_STATES and start_timestamp and end_timestamp:
start_timestamp = parse_isotime(start_timestamp)
start_timestamp = calendar.timegm(start_timestamp.timetuple())
end_timestamp = parse_isotime(end_timestamp)
end_timestamp = calendar.timegm(end_timestamp.timetuple())
- elapsed_seconds = (end_timestamp - start_timestamp)
- instance.status = '%s (%ss elapsed)' % (instance.status, elapsed_seconds)
+ elapsed_seconds = end_timestamp - start_timestamp
+ instance.status = "%s (%ss elapsed)" % (instance.status, elapsed_seconds)
return instance
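
Both branches come down to the same arithmetic: parse the ISO-8601 timestamps, convert to epoch seconds, and subtract (parse_isotime above is the client's own helper). A standalone illustration using only the standard library, assuming UTC "Z"-suffixed timestamps:

    import calendar
    from datetime import datetime

    def elapsed_seconds(start_iso, end_iso):
        fmt = "%Y-%m-%dT%H:%M:%S.%fZ"
        start = calendar.timegm(datetime.strptime(start_iso, fmt).timetuple())
        end = calendar.timegm(datetime.strptime(end_iso, fmt).timetuple())
        return end - start

    assert elapsed_seconds(
        "2021-03-02T10:00:00.000000Z", "2021-03-02T10:00:42.000000Z"
    ) == 42
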
class ActionBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(ActionBranch, self).__init__(
- models.Action, description, app, subparsers,
+ models.Action,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
commands={
- 'list': ActionListCommand,
- 'get': ActionGetCommand,
- 'update': ActionUpdateCommand,
- 'delete': ActionDeleteCommand
- })
+ "list": ActionListCommand,
+ "get": ActionGetCommand,
+ "update": ActionUpdateCommand,
+ "delete": ActionDeleteCommand,
+ },
+ )
# Registers extended commands
- self.commands['enable'] = ActionEnableCommand(self.resource, self.app, self.subparsers)
- self.commands['disable'] = ActionDisableCommand(self.resource, self.app, self.subparsers)
- self.commands['execute'] = ActionRunCommand(
- self.resource, self.app, self.subparsers,
- add_help=False)
+ self.commands["enable"] = ActionEnableCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["disable"] = ActionDisableCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["execute"] = ActionRunCommand(
+ self.resource, self.app, self.subparsers, add_help=False
+ )
class ActionListCommand(resource.ContentPackResourceListCommand):
- display_attributes = ['ref', 'pack', 'description']
+ display_attributes = ["ref", "pack", "description"]
class ActionGetCommand(resource.ContentPackResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'description',
- 'enabled', 'entry_point', 'runner_type',
- 'parameters']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "uid",
+ "ref",
+ "pack",
+ "name",
+ "description",
+ "enabled",
+ "entry_point",
+ "runner_type",
+ "parameters",
+ ]
class ActionUpdateCommand(resource.ContentPackResourceUpdateCommand):
@@ -218,17 +228,33 @@ class ActionUpdateCommand(resource.ContentPackResourceUpdateCommand):
class ActionEnableCommand(resource.ContentPackResourceEnableCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'description',
- 'enabled', 'entry_point', 'runner_type',
- 'parameters']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "description",
+ "enabled",
+ "entry_point",
+ "runner_type",
+ "parameters",
+ ]
class ActionDisableCommand(resource.ContentPackResourceDisableCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'description',
- 'enabled', 'entry_point', 'runner_type',
- 'parameters']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "description",
+ "enabled",
+ "entry_point",
+ "runner_type",
+ "parameters",
+ ]
class ActionDeleteCommand(resource.ContentPackResourceDeleteCommand):
@@ -239,15 +265,32 @@ class ActionRunCommandMixin(object):
"""
Mixin class which contains utility functions related to action execution.
"""
- display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status',
- 'start_timestamp', 'end_timestamp', 'result']
- attribute_display_order = ['id', 'action.ref', 'context.user', 'parameters', 'status',
- 'start_timestamp', 'end_timestamp', 'result']
+
+ display_attributes = [
+ "id",
+ "action.ref",
+ "context.user",
+ "parameters",
+ "status",
+ "start_timestamp",
+ "end_timestamp",
+ "result",
+ ]
+ attribute_display_order = [
+ "id",
+ "action.ref",
+ "context.user",
+ "parameters",
+ "status",
+ "start_timestamp",
+ "end_timestamp",
+ "result",
+ ]
attribute_transform_functions = {
- 'start_timestamp': format_isodate_for_user_timezone,
- 'end_timestamp': format_isodate_for_user_timezone,
- 'parameters': format_parameters,
- 'status': format_status
+ "start_timestamp": format_isodate_for_user_timezone,
+ "end_timestamp": format_isodate_for_user_timezone,
+ "parameters": format_parameters,
+ "status": format_status,
}
poll_interval = 2 # how often to poll for execution completion when using sync mode
@@ -262,14 +305,19 @@ def run_and_print(self, args, **kwargs):
execution = self.run(args, **kwargs)
if args.action_async:
- self.print_output('To get the results, execute:\n st2 execution get %s' %
- (execution.id), six.text_type)
- self.print_output('\nTo view output in real-time, execute:\n st2 execution '
- 'tail %s' % (execution.id), six.text_type)
+ self.print_output(
+ "To get the results, execute:\n st2 execution get %s" % (execution.id),
+ six.text_type,
+ )
+ self.print_output(
+ "\nTo view output in real-time, execute:\n st2 execution "
+ "tail %s" % (execution.id),
+ six.text_type,
+ )
else:
self._print_execution_details(execution=execution, args=args, **kwargs)
- if execution.status == 'failed':
+ if execution.status == "failed":
# Exit with non zero if the action has failed
sys.exit(1)
@@ -278,52 +326,99 @@ def _add_common_options(self):
# Display options
task_list_arg_grp = root_arg_grp.add_argument_group()
- task_list_arg_grp.add_argument('--with-schema',
- default=False, action='store_true',
- help=('Show schema_ouput suggestion with action.'))
-
- task_list_arg_grp.add_argument('--raw', action='store_true',
- help='Raw output, don\'t show sub-tasks for workflows.')
- task_list_arg_grp.add_argument('--show-tasks', action='store_true',
- help='Whether to show sub-tasks of an execution.')
- task_list_arg_grp.add_argument('--depth', type=int, default=-1,
- help='Depth to which to show sub-tasks. \
- By default all are shown.')
- task_list_arg_grp.add_argument('-w', '--width', nargs='+', type=int, default=None,
- help='Set the width of columns in output.')
+ task_list_arg_grp.add_argument(
+ "--with-schema",
+ default=False,
+ action="store_true",
+            help=("Show schema_output suggestion with action."),
+ )
+
+ task_list_arg_grp.add_argument(
+ "--raw",
+ action="store_true",
+ help="Raw output, don't show sub-tasks for workflows.",
+ )
+ task_list_arg_grp.add_argument(
+ "--show-tasks",
+ action="store_true",
+ help="Whether to show sub-tasks of an execution.",
+ )
+ task_list_arg_grp.add_argument(
+ "--depth",
+ type=int,
+ default=-1,
+ help="Depth to which to show sub-tasks. \
+ By default all are shown.",
+ )
+ task_list_arg_grp.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help="Set the width of columns in output.",
+ )
execution_details_arg_grp = root_arg_grp.add_mutually_exclusive_group()
detail_arg_grp = execution_details_arg_grp.add_mutually_exclusive_group()
- detail_arg_grp.add_argument('--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" or unspecified will '
- 'return all attributes.'))
- detail_arg_grp.add_argument('-d', '--detail', action='store_true',
- help='Display full detail of the execution in table format.')
+ detail_arg_grp.add_argument(
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" or unspecified will '
+ "return all attributes."
+ ),
+ )
+ detail_arg_grp.add_argument(
+ "-d",
+ "--detail",
+ action="store_true",
+ help="Display full detail of the execution in table format.",
+ )
result_arg_grp = execution_details_arg_grp.add_mutually_exclusive_group()
- result_arg_grp.add_argument('-k', '--key',
- help=('If result is type of JSON, then print specific '
- 'key-value pair; dot notation for nested JSON is '
- 'supported.'))
- result_arg_grp.add_argument('--delay', type=int, default=None,
- help=('How long (in milliseconds) to delay the '
- 'execution before scheduling.'))
+ result_arg_grp.add_argument(
+ "-k",
+ "--key",
+ help=(
+ "If result is type of JSON, then print specific "
+ "key-value pair; dot notation for nested JSON is "
+ "supported."
+ ),
+ )
+ result_arg_grp.add_argument(
+ "--delay",
+ type=int,
+ default=None,
+ help=(
+ "How long (in milliseconds) to delay the "
+ "execution before scheduling."
+ ),
+ )
# Other options
- detail_arg_grp.add_argument('--tail', action='store_true',
- help='Automatically start tailing new execution.')
+ detail_arg_grp.add_argument(
+ "--tail",
+ action="store_true",
+ help="Automatically start tailing new execution.",
+ )
# Flag to opt-in to functionality introduced in PR #3670. More robust parsing
# of complex datatypes is planned for 2.6, so this flag will be deprecated soon
- detail_arg_grp.add_argument('--auto-dict', action='store_true', dest='auto_dict',
- default=False, help='Automatically convert list items to '
- 'dictionaries when colons are detected. '
- '(NOTE - this parameter and its functionality will be '
- 'deprecated in the next release in favor of a more '
- 'robust conversion method)')
+ detail_arg_grp.add_argument(
+ "--auto-dict",
+ action="store_true",
+ dest="auto_dict",
+ default=False,
+ help="Automatically convert list items to "
+ "dictionaries when colons are detected. "
+ "(NOTE - this parameter and its functionality will be "
+ "deprecated in the next release in favor of a more "
+ "robust conversion method)",
+ )
return root_arg_grp
@@ -334,20 +429,24 @@ def _print_execution_details(self, execution, args, **kwargs):
This method takes into account if an executed action was workflow or not
and formats the output accordingly.
"""
- runner_type = execution.action.get('runner_type', 'unknown')
+ runner_type = execution.action.get("runner_type", "unknown")
is_workflow_action = runner_type in WORKFLOW_RUNNER_TYPES
- show_tasks = getattr(args, 'show_tasks', False)
- raw = getattr(args, 'raw', False)
- detail = getattr(args, 'detail', False)
- key = getattr(args, 'key', None)
- attr = getattr(args, 'attr', [])
+ show_tasks = getattr(args, "show_tasks", False)
+ raw = getattr(args, "raw", False)
+ detail = getattr(args, "detail", False)
+ key = getattr(args, "key", None)
+ attr = getattr(args, "attr", [])
if show_tasks and not is_workflow_action:
- raise ValueError('--show-tasks option can only be used with workflow actions')
+ raise ValueError(
+ "--show-tasks option can only be used with workflow actions"
+ )
if not raw and not detail and (show_tasks or is_workflow_action):
- self._run_and_print_child_task_list(execution=execution, args=args, **kwargs)
+ self._run_and_print_child_task_list(
+ execution=execution, args=args, **kwargs
+ )
else:
instance = execution
@@ -357,47 +456,61 @@ def _print_execution_details(self, execution, args, **kwargs):
formatter = execution_formatter.ExecutionResult
if detail:
- options = {'attributes': copy.copy(self.display_attributes)}
+ options = {"attributes": copy.copy(self.display_attributes)}
elif key:
- options = {'attributes': ['result.%s' % (key)], 'key': key}
+ options = {"attributes": ["result.%s" % (key)], "key": key}
else:
- options = {'attributes': attr}
-
- options['json'] = args.json
- options['yaml'] = args.yaml
- options['with_schema'] = args.with_schema
- options['attribute_transform_functions'] = self.attribute_transform_functions
+ options = {"attributes": attr}
+
+ options["json"] = args.json
+ options["yaml"] = args.yaml
+ options["with_schema"] = args.with_schema
+ options[
+ "attribute_transform_functions"
+ ] = self.attribute_transform_functions
self.print_output(instance, formatter, **options)
def _run_and_print_child_task_list(self, execution, args, **kwargs):
- action_exec_mgr = self.app.client.managers['Execution']
+ action_exec_mgr = self.app.client.managers["Execution"]
instance = execution
- options = {'attributes': ['id', 'action.ref', 'parameters', 'status', 'start_timestamp',
- 'end_timestamp']}
- options['json'] = args.json
- options['attribute_transform_functions'] = self.attribute_transform_functions
+ options = {
+ "attributes": [
+ "id",
+ "action.ref",
+ "parameters",
+ "status",
+ "start_timestamp",
+ "end_timestamp",
+ ]
+ }
+ options["json"] = args.json
+ options["attribute_transform_functions"] = self.attribute_transform_functions
formatter = execution_formatter.ExecutionResult
- kwargs['depth'] = args.depth
- child_instances = action_exec_mgr.get_property(execution.id, 'children', **kwargs)
+ kwargs["depth"] = args.depth
+ child_instances = action_exec_mgr.get_property(
+ execution.id, "children", **kwargs
+ )
child_instances = self._format_child_instances(child_instances, execution.id)
child_instances = format_execution_statuses(child_instances)
if not child_instances:
# No child error, there might be a global error, include result in the output
- options['attributes'].append('result')
+ options["attributes"].append("result")
- status_index = options['attributes'].index('status')
+ status_index = options["attributes"].index("status")
- if hasattr(instance, 'result') and isinstance(instance.result, dict):
- tasks = instance.result.get('tasks', [])
+ if hasattr(instance, "result") and isinstance(instance.result, dict):
+ tasks = instance.result.get("tasks", [])
else:
tasks = []
# On failure we also want to include error message and traceback at the top level
- if instance.status == 'failed':
- top_level_error, top_level_traceback = self._get_top_level_error(live_action=instance)
+ if instance.status == "failed":
+ top_level_error, top_level_traceback = self._get_top_level_error(
+ live_action=instance
+ )
if len(tasks) >= 1:
task_error, task_traceback = self._get_task_error(task=tasks[-1])
@@ -408,18 +521,18 @@ def _run_and_print_child_task_list(self, execution, args, **kwargs):
# Top-level error
instance.error = top_level_error
instance.traceback = top_level_traceback
- instance.result = 'See error and traceback.'
- options['attributes'].insert(status_index + 1, 'error')
- options['attributes'].insert(status_index + 2, 'traceback')
+ instance.result = "See error and traceback."
+ options["attributes"].insert(status_index + 1, "error")
+ options["attributes"].insert(status_index + 2, "traceback")
elif task_error:
# Task error
instance.error = task_error
instance.traceback = task_traceback
- instance.result = 'See error and traceback.'
- instance.failed_on = tasks[-1].get('name', 'unknown')
- options['attributes'].insert(status_index + 1, 'error')
- options['attributes'].insert(status_index + 2, 'traceback')
- options['attributes'].insert(status_index + 3, 'failed_on')
+ instance.result = "See error and traceback."
+ instance.failed_on = tasks[-1].get("name", "unknown")
+ options["attributes"].insert(status_index + 1, "error")
+ options["attributes"].insert(status_index + 2, "traceback")
+ options["attributes"].insert(status_index + 3, "failed_on")
        # Include the result on the top-level object so the user doesn't need to
        # issue another command to see the result
@@ -427,57 +540,63 @@ def _run_and_print_child_task_list(self, execution, args, **kwargs):
task_result = self._get_task_result(task=tasks[-1])
if task_result:
- instance.result_task = tasks[-1].get('name', 'unknown')
- options['attributes'].insert(status_index + 1, 'result_task')
- options['attributes'].insert(status_index + 2, 'result')
+ instance.result_task = tasks[-1].get("name", "unknown")
+ options["attributes"].insert(status_index + 1, "result_task")
+ options["attributes"].insert(status_index + 2, "result")
instance.result = task_result
# Otherwise include the result of the workflow execution.
else:
- if 'result' not in options['attributes']:
- options['attributes'].append('result')
+ if "result" not in options["attributes"]:
+ options["attributes"].append("result")
# print root task
self.print_output(instance, formatter, **options)
# print child tasks
if child_instances:
- self.print_output(child_instances, table.MultiColumnTable,
- attributes=['id', 'status', 'task', 'action', 'start_timestamp'],
- widths=args.width, json=args.json,
- yaml=args.yaml,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ child_instances,
+ table.MultiColumnTable,
+ attributes=["id", "status", "task", "action", "start_timestamp"],
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
def _get_execution_result(self, execution, action_exec_mgr, args, **kwargs):
pending_statuses = [
LIVEACTION_STATUS_REQUESTED,
LIVEACTION_STATUS_SCHEDULED,
LIVEACTION_STATUS_RUNNING,
- LIVEACTION_STATUS_CANCELING
+ LIVEACTION_STATUS_CANCELING,
]
if args.tail:
# Start tailing new execution
print('Tailing execution "%s"' % (str(execution.id)))
- execution_manager = self.app.client.managers['Execution']
- stream_manager = self.app.client.managers['Stream']
- ActionExecutionTailCommand.tail_execution(execution=execution,
- execution_manager=execution_manager,
- stream_manager=stream_manager,
- **kwargs)
+ execution_manager = self.app.client.managers["Execution"]
+ stream_manager = self.app.client.managers["Stream"]
+ ActionExecutionTailCommand.tail_execution(
+ execution=execution,
+ execution_manager=execution_manager,
+ stream_manager=stream_manager,
+ **kwargs,
+ )
execution = action_exec_mgr.get_by_id(execution.id, **kwargs)
- print('')
+ print("")
return execution
if not args.action_async:
while execution.status in pending_statuses:
time.sleep(self.poll_interval)
if not args.json and not args.yaml:
- sys.stdout.write('.')
+ sys.stdout.write(".")
sys.stdout.flush()
execution = action_exec_mgr.get_by_id(execution.id, **kwargs)
- sys.stdout.write('\n')
+ sys.stdout.write("\n")
if execution.status == LIVEACTION_STATUS_CANCELED:
return execution
@@ -491,8 +610,8 @@ def _get_top_level_error(self, live_action):
:return: (error, traceback)
"""
if isinstance(live_action.result, dict):
- error = live_action.result.get('error', None)
- traceback = live_action.result.get('traceback', None)
+ error = live_action.result.get("error", None)
+ traceback = live_action.result.get("traceback", None)
else:
error = "See result"
traceback = "See result"
@@ -508,12 +627,12 @@ def _get_task_error(self, task):
if not task:
return None, None
- result = task['result']
+ result = task["result"]
if isinstance(result, dict):
- stderr = result.get('stderr', None)
- error = result.get('error', None)
- traceback = result.get('traceback', None)
+ stderr = result.get("stderr", None)
+ error = result.get("error", None)
+ traceback = result.get("traceback", None)
error = error if error else stderr
else:
stderr = None
@@ -526,7 +645,7 @@ def _get_task_result(self, task):
if not task:
return None
- return task['result']
+ return task["result"]
def _get_action_parameters_from_args(self, action, runner, args):
"""
@@ -553,22 +672,22 @@ def read_file(file_path):
if not os.path.isfile(file_path):
raise ValueError('"%s" is not a file' % (file_path))
- with open(file_path, 'rb') as fp:
+ with open(file_path, "rb") as fp:
content = fp.read()
return content.decode("utf-8")
def transform_object(value):
# Also support simple key1=val1,key2=val2 syntax
- if value.startswith('{'):
+ if value.startswith("{"):
# Assume it's JSON
result = value = json.loads(value)
else:
- pairs = value.split(',')
+ pairs = value.split(",")
result = {}
for pair in pairs:
- split = pair.split('=', 1)
+ split = pair.split("=", 1)
if len(split) != 2:
continue
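
The two parse paths mean an object-typed parameter can be passed either as JSON or as a comma-separated key=value shorthand. With the nested helper lifted out of its enclosing method:

    assert transform_object('{"a": 1}') == {"a": 1}
    assert transform_object("a=1,b=2") == {"a": "1", "b": "2"}
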
@@ -605,18 +724,22 @@ def transform_array(value, action_params=None, auto_dict=False):
try:
result = json.loads(value)
except ValueError:
- result = [v.strip() for v in value.split(',')]
+ result = [v.strip() for v in value.split(",")]
        # When every value in this array represents a dict, this converts
        # 'result' to a dict-type value.
- if all([isinstance(x, str) and ':' in x for x in result]) and auto_dict:
+ if all([isinstance(x, str) and ":" in x for x in result]) and auto_dict:
result_dict = {}
- for (k, v) in [x.split(':') for x in result]:
+ for (k, v) in [x.split(":") for x in result]:
                # To parse values using the 'transformer' according to the type
                # specified in the action metadata, call the 'normalize' method recursively.
- if 'properties' in action_params and k in action_params['properties']:
- result_dict[k] = normalize(k, v, action_params['properties'],
- auto_dict=auto_dict)
+ if (
+ "properties" in action_params
+ and k in action_params["properties"]
+ ):
+ result_dict[k] = normalize(
+ k, v, action_params["properties"], auto_dict=auto_dict
+ )
else:
result_dict[k] = v
return [result_dict]
@@ -624,12 +747,12 @@ def transform_array(value, action_params=None, auto_dict=False):
return result
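
So an array parameter is parsed as JSON when possible, falls back to comma splitting, and with --auto-dict folds colon-separated items into a single dict. With the helper lifted to module scope:

    assert transform_array("[1, 2, 3]") == [1, 2, 3]      # valid JSON
    assert transform_array("a, b") == ["a", "b"]          # comma fallback
    assert transform_array("a:1,b:2", action_params={}, auto_dict=True) == [
        {"a": "1", "b": "2"}
    ]
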
transformer = {
- 'array': transform_array,
- 'boolean': (lambda x: ast.literal_eval(x.capitalize())),
- 'integer': int,
- 'number': float,
- 'object': transform_object,
- 'string': str
+ "array": transform_array,
+ "boolean": (lambda x: ast.literal_eval(x.capitalize())),
+ "integer": int,
+ "number": float,
+ "object": transform_object,
+ "string": str,
}
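
Casting a CLI string to the declared parameter type is then a plain dictionary dispatch on the JSON-schema type name, e.g. (again assuming the dict is lifted to importable scope):

    assert transformer["integer"]("42") == 42
    assert transformer["number"]("2.5") == 2.5
    # "true" is capitalized to "True" and fed to ast.literal_eval:
    assert transformer["boolean"]("true") is True
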
def get_param_type(key, action_params=None):
@@ -642,13 +765,13 @@ def get_param_type(key, action_params=None):
param = action_params[key]
if param:
- return param['type']
+ return param["type"]
return None
def normalize(name, value, action_params=None, auto_dict=False):
- """ The desired type is contained in the action meta-data, so we can look that up
- and call the desired "caster" function listed in the "transformer" dict
+ """The desired type is contained in the action meta-data, so we can look that up
+ and call the desired "caster" function listed in the "transformer" dict
"""
action_params = action_params or action.parameters
@@ -663,8 +786,10 @@ def normalize(name, value, action_params=None, auto_dict=False):
# (items: type: int for example) and this information is available here so we could
# also leverage that to cast each array item to the correct type.
param_type = get_param_type(name, action_params)
- if param_type == 'array' and name in action_params:
- return transformer[param_type](value, action_params[name], auto_dict=auto_dict)
+ if param_type == "array" and name in action_params:
+ return transformer[param_type](
+ value, action_params[name], auto_dict=auto_dict
+ )
elif param_type:
return transformer[param_type](value)
@@ -677,11 +802,11 @@ def normalize(name, value, action_params=None, auto_dict=False):
for idx in range(len(args.parameters)):
arg = args.parameters[idx]
- if '=' in arg:
- k, v = arg.split('=', 1)
+ if "=" in arg:
+ k, v = arg.split("=", 1)
# Attribute for files are prefixed with "@"
- if k.startswith('@'):
+ if k.startswith("@"):
k = k[1:]
is_file = True
else:
@@ -695,15 +820,15 @@ def normalize(name, value, action_params=None, auto_dict=False):
file_name = os.path.basename(file_path)
content = read_file(file_path=file_path)
- if action_ref_or_id == 'core.http':
+ if action_ref_or_id == "core.http":
# Special case for http runner
- result['_file_name'] = file_name
- result['file_content'] = content
+ result["_file_name"] = file_name
+ result["file_content"] = content
else:
result[k] = content
else:
                # Only array-type parameters may be declared multiple times.
- if get_param_type(k) == 'array' and k in result:
+ if get_param_type(k) == "array" and k in result:
result[k] += normalize(k, v, auto_dict=args.auto_dict)
else:
result[k] = normalize(k, v, auto_dict=args.auto_dict)
@@ -711,42 +836,44 @@ def normalize(name, value, action_params=None, auto_dict=False):
except Exception as e:
# TODO: Move transformers in a separate module and handle
# exceptions there
- if 'malformed string' in six.text_type(e):
- message = ('Invalid value for boolean parameter. '
- 'Valid values are: true, false')
+ if "malformed string" in six.text_type(e):
+ message = (
+ "Invalid value for boolean parameter. "
+ "Valid values are: true, false"
+ )
raise ValueError(message)
else:
raise e
else:
- result['cmd'] = ' '.join(args.parameters[idx:])
+ result["cmd"] = " ".join(args.parameters[idx:])
break
# Special case for http runner
- if 'file_content' in result:
- if 'method' not in result:
+ if "file_content" in result:
+ if "method" not in result:
# Default to POST if a method is not provided
- result['method'] = 'POST'
+ result["method"] = "POST"
- if 'file_name' not in result:
+ if "file_name" not in result:
# File name not provided, use default file name
- result['file_name'] = result['_file_name']
+ result["file_name"] = result["_file_name"]
- del result['_file_name']
+ del result["_file_name"]
if args.inherit_env:
- result['env'] = self._get_inherited_env_vars()
+ result["env"] = self._get_inherited_env_vars()
return result
@add_auth_token_to_kwargs_from_cli
def _print_help(self, args, **kwargs):
# Print appropriate help message if the help option is given.
- action_mgr = self.app.client.managers['Action']
- action_exec_mgr = self.app.client.managers['Execution']
+ action_mgr = self.app.client.managers["Action"]
+ action_exec_mgr = self.app.client.managers["Execution"]
if args.help:
- action_ref_or_id = getattr(args, 'ref_or_id', None)
- action_exec_id = getattr(args, 'id', None)
+ action_ref_or_id = getattr(args, "ref_or_id", None)
+ action_exec_id = getattr(args, "id", None)
if action_exec_id and not action_ref_or_id:
action_exec = action_exec_mgr.get_by_id(action_exec_id, **kwargs)
@@ -756,34 +883,47 @@ def _print_help(self, args, **kwargs):
try:
action = action_mgr.get_by_ref_or_id(args.ref_or_id, **kwargs)
if not action:
- raise resource.ResourceNotFoundError('Action %s not found' % args.ref_or_id)
- runner_mgr = self.app.client.managers['RunnerType']
+ raise resource.ResourceNotFoundError(
+ "Action %s not found" % args.ref_or_id
+ )
+ runner_mgr = self.app.client.managers["RunnerType"]
runner = runner_mgr.get_by_name(action.runner_type, **kwargs)
- parameters, required, optional, _ = self._get_params_types(runner,
- action)
- print('')
+ parameters, required, optional, _ = self._get_params_types(
+ runner, action
+ )
+ print("")
print(textwrap.fill(action.description))
- print('')
+ print("")
if required:
- required = self._sort_parameters(parameters=parameters,
- names=required)
-
- print('Required Parameters:')
- [self._print_param(name, parameters.get(name))
- for name in required]
+ required = self._sort_parameters(
+ parameters=parameters, names=required
+ )
+
+ print("Required Parameters:")
+ [
+ self._print_param(name, parameters.get(name))
+ for name in required
+ ]
if optional:
- optional = self._sort_parameters(parameters=parameters,
- names=optional)
-
- print('Optional Parameters:')
- [self._print_param(name, parameters.get(name))
- for name in optional]
+ optional = self._sort_parameters(
+ parameters=parameters, names=optional
+ )
+
+ print("Optional Parameters:")
+ [
+ self._print_param(name, parameters.get(name))
+ for name in optional
+ ]
except resource.ResourceNotFoundError:
- print(('Action "%s" is not found. ' % args.ref_or_id) +
- 'Use "st2 action list" to see the list of available actions.')
+ print(
+ ('Action "%s" is not found. ' % args.ref_or_id)
+ + 'Use "st2 action list" to see the list of available actions.'
+ )
except Exception as e:
- print('ERROR: Unable to print help for action "%s". %s' %
- (args.ref_or_id, e))
+ print(
+ 'ERROR: Unable to print help for action "%s". %s'
+ % (args.ref_or_id, e)
+ )
else:
self.parser.print_help()
return True
@@ -795,20 +935,20 @@ def _print_param(name, schema):
raise ValueError('Missing schema for parameter "%s"' % (name))
wrapper = textwrap.TextWrapper(width=78)
- wrapper.initial_indent = ' ' * 4
+ wrapper.initial_indent = " " * 4
wrapper.subsequent_indent = wrapper.initial_indent
print(wrapper.fill(name))
- wrapper.initial_indent = ' ' * 8
+ wrapper.initial_indent = " " * 8
wrapper.subsequent_indent = wrapper.initial_indent
- if 'description' in schema and schema['description']:
- print(wrapper.fill(schema['description']))
- if 'type' in schema and schema['type']:
- print(wrapper.fill('Type: %s' % schema['type']))
- if 'enum' in schema and schema['enum']:
- print(wrapper.fill('Enum: %s' % ', '.join(schema['enum'])))
- if 'default' in schema and schema['default'] is not None:
- print(wrapper.fill('Default: %s' % schema['default']))
- print('')
+ if "description" in schema and schema["description"]:
+ print(wrapper.fill(schema["description"]))
+ if "type" in schema and schema["type"]:
+ print(wrapper.fill("Type: %s" % schema["type"]))
+ if "enum" in schema and schema["enum"]:
+ print(wrapper.fill("Enum: %s" % ", ".join(schema["enum"])))
+ if "default" in schema and schema["default"] is not None:
+ print(wrapper.fill("Default: %s" % schema["default"]))
+ print("")
@staticmethod
def _get_params_types(runner, action):
@@ -816,19 +956,18 @@ def _get_params_types(runner, action):
action_params = action.parameters
parameters = copy.copy(runner_params)
parameters.update(copy.copy(action_params))
- required = set([k for k, v in six.iteritems(parameters) if v.get('required')])
+ required = set([k for k, v in six.iteritems(parameters) if v.get("required")])
def is_immutable(runner_param_meta, action_param_meta):
# If runner sets a param as immutable, action cannot override that.
- if runner_param_meta.get('immutable', False):
+ if runner_param_meta.get("immutable", False):
return True
else:
- return action_param_meta.get('immutable', False)
+ return action_param_meta.get("immutable", False)
immutable = set()
for param in parameters.keys():
- if is_immutable(runner_params.get(param, {}),
- action_params.get(param, {})):
+ if is_immutable(runner_params.get(param, {}), action_params.get(param, {})):
immutable.add(param)
required = required - immutable
@@ -837,12 +976,12 @@ def is_immutable(runner_param_meta, action_param_meta):
return parameters, required, optional, immutable
def _format_child_instances(self, children, parent_id):
- '''
+ """
The goal of this method is to add an indent at every level. This way the
WF is represented as a tree structure while in a list. For the right visual
representation the list must be a DF traversal, else the indents will end up
looking strange.
- '''
+ """
# apply basic WF formatting first.
children = format_wf_instances(children)
# setup a depth lookup table
@@ -856,7 +995,9 @@ def _format_child_instances(self, children, parent_id):
parent = None
for instance in children:
if WF_PREFIX in instance.id:
- instance_id = instance.id[instance.id.index(WF_PREFIX) + len(WF_PREFIX):]
+ instance_id = instance.id[
+ instance.id.index(WF_PREFIX) + len(WF_PREFIX) :
+ ]
else:
instance_id = instance.id
if instance_id == child.parent:
@@ -871,26 +1012,28 @@ def _format_child_instances(self, children, parent_id):
return result
def _format_for_common_representation(self, task):
- '''
+ """
Formats a task for common representation for action-chain.
- '''
+ """
# This really needs to be better handled on the back-end but that would be a bigger
# change, so we handle it in the CLI.
- context = getattr(task, 'context', None)
- if context and 'chain' in context:
- task_name_key = 'context.chain.name'
- elif context and 'orquesta' in context:
- task_name_key = 'context.orquesta.task_name'
+ context = getattr(task, "context", None)
+ if context and "chain" in context:
+ task_name_key = "context.chain.name"
+ elif context and "orquesta" in context:
+ task_name_key = "context.orquesta.task_name"
# Use Execution as the object so that the formatter lookup does not change.
# AKA HACK!
- return models.action.Execution(**{
- 'id': task.id,
- 'status': task.status,
- 'task': jsutil.get_value(vars(task), task_name_key),
- 'action': task.action.get('ref', None),
- 'start_timestamp': task.start_timestamp,
- 'end_timestamp': getattr(task, 'end_timestamp', None)
- })
+ return models.action.Execution(
+ **{
+ "id": task.id,
+ "status": task.status,
+ "task": jsutil.get_value(vars(task), task_name_key),
+ "action": task.action.get("ref", None),
+ "start_timestamp": task.start_timestamp,
+ "end_timestamp": getattr(task, "end_timestamp", None),
+ }
+ )
def _sort_parameters(self, parameters, names):
"""
@@ -899,10 +1042,12 @@ def _sort_parameters(self, parameters, names):
:type parameters: ``list``
:type names: ``list`` or ``set``
"""
- sorted_parameters = sorted(names, key=lambda name:
- self._get_parameter_sort_value(
- parameters=parameters,
- name=name))
+ sorted_parameters = sorted(
+ names,
+ key=lambda name: self._get_parameter_sort_value(
+ parameters=parameters, name=name
+ ),
+ )
return sorted_parameters
@@ -919,7 +1064,7 @@ def _get_parameter_sort_value(self, parameters, name):
if not parameter:
return None
- sort_value = parameter.get('position', name)
+ sort_value = parameter.get("position", name)
return sort_value
def _get_inherited_env_vars(self):
@@ -938,44 +1083,76 @@ class ActionRunCommand(ActionRunCommandMixin, resource.ResourceCommand):
def __init__(self, resource, *args, **kwargs):
super(ActionRunCommand, self).__init__(
- resource, kwargs.pop('name', 'execute'),
- 'Invoke an action manually.',
- *args, **kwargs)
-
- self.parser.add_argument('ref_or_id', nargs='?',
- metavar='ref-or-id',
- help='Action reference (pack.action_name) ' +
- 'or ID of the action.')
- self.parser.add_argument('parameters', nargs='*',
- help='List of keyword args, positional args, '
- 'and optional args for the action.')
-
- self.parser.add_argument('-h', '--help',
- action='store_true', dest='help',
- help='Print usage for the given action.')
+ resource,
+ kwargs.pop("name", "execute"),
+ "Invoke an action manually.",
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "ref_or_id",
+ nargs="?",
+ metavar="ref-or-id",
+ help="Action reference (pack.action_name) " + "or ID of the action.",
+ )
+ self.parser.add_argument(
+ "parameters",
+ nargs="*",
+ help="List of keyword args, positional args, "
+ "and optional args for the action.",
+ )
+
+ self.parser.add_argument(
+ "-h",
+ "--help",
+ action="store_true",
+ dest="help",
+ help="Print usage for the given action.",
+ )
self._add_common_options()
- if self.name in ['run', 'execute']:
- self.parser.add_argument('--trace-tag', '--trace_tag',
- help='A trace tag string to track execution later.',
- dest='trace_tag', required=False)
- self.parser.add_argument('--trace-id',
- help='Existing trace id for this execution.',
- dest='trace_id', required=False)
- self.parser.add_argument('-a', '--async',
- action='store_true', dest='action_async',
- help='Do not wait for action to finish.')
- self.parser.add_argument('-e', '--inherit-env',
- action='store_true', dest='inherit_env',
- help='Pass all the environment variables '
- 'which are accessible to the CLI as "env" '
- 'parameter to the action. Note: Only works '
- 'with python, local and remote runners.')
- self.parser.add_argument('-u', '--user', type=str, default=None,
- help='User under which to run the action (admins only).')
-
- if self.name == 'run':
+ if self.name in ["run", "execute"]:
+ self.parser.add_argument(
+ "--trace-tag",
+ "--trace_tag",
+ help="A trace tag string to track execution later.",
+ dest="trace_tag",
+ required=False,
+ )
+ self.parser.add_argument(
+ "--trace-id",
+ help="Existing trace id for this execution.",
+ dest="trace_id",
+ required=False,
+ )
+ self.parser.add_argument(
+ "-a",
+ "--async",
+ action="store_true",
+ dest="action_async",
+ help="Do not wait for action to finish.",
+ )
+ self.parser.add_argument(
+ "-e",
+ "--inherit-env",
+ action="store_true",
+ dest="inherit_env",
+ help="Pass all the environment variables "
+ 'which are accessible to the CLI as "env" '
+ "parameter to the action. Note: Only works "
+ "with python, local and remote runners.",
+ )
+ self.parser.add_argument(
+ "-u",
+ "--user",
+ type=str,
+ default=None,
+ help="User under which to run the action (admins only).",
+ )
+
+ if self.name == "run":
self.parser.set_defaults(action_async=False)
else:
self.parser.set_defaults(action_async=True)
@@ -983,22 +1160,27 @@ def __init__(self, resource, *args, **kwargs):
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
if not args.ref_or_id:
- self.parser.error('Missing action reference or id')
+ self.parser.error("Missing action reference or id")
action = self.get_resource(args.ref_or_id, **kwargs)
if not action:
- raise resource.ResourceNotFoundError('Action "%s" cannot be found.'
- % (args.ref_or_id))
+ raise resource.ResourceNotFoundError(
+ 'Action "%s" cannot be found.' % (args.ref_or_id)
+ )
- runner_mgr = self.app.client.managers['RunnerType']
+ runner_mgr = self.app.client.managers["RunnerType"]
runner = runner_mgr.get_by_name(action.runner_type, **kwargs)
if not runner:
- raise resource.ResourceNotFoundError('Runner type "%s" for action "%s" cannot be \
- found.' % (action.runner_type, action.name))
+ raise resource.ResourceNotFoundError(
+ 'Runner type "%s" for action "%s" cannot be found.'
+ % (action.runner_type, action.name)
+ )
- action_ref = '.'.join([action.pack, action.name])
- action_parameters = self._get_action_parameters_from_args(action=action, runner=runner,
- args=args)
+ action_ref = ".".join([action.pack, action.name])
+ action_parameters = self._get_action_parameters_from_args(
+ action=action, runner=runner, args=args
+ )
execution = models.Execution()
execution.action = action_ref
@@ -1009,56 +1191,79 @@ def run(self, args, **kwargs):
execution.delay = args.delay
if not args.trace_id and args.trace_tag:
- execution.context = {'trace_context': {'trace_tag': args.trace_tag}}
+ execution.context = {"trace_context": {"trace_tag": args.trace_tag}}
if args.trace_id:
- execution.context = {'trace_context': {'id_': args.trace_id}}
+ execution.context = {"trace_context": {"id_": args.trace_id}}
- action_exec_mgr = self.app.client.managers['Execution']
+ action_exec_mgr = self.app.client.managers["Execution"]
execution = action_exec_mgr.create(execution, **kwargs)
- execution = self._get_execution_result(execution=execution,
- action_exec_mgr=action_exec_mgr,
- args=args, **kwargs)
+ execution = self._get_execution_result(
+ execution=execution, action_exec_mgr=action_exec_mgr, args=args, **kwargs
+ )
return execution
class ActionExecutionBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(ActionExecutionBranch, self).__init__(
- models.Execution, description, app, subparsers,
- parent_parser=parent_parser, read_only=True,
- commands={'list': ActionExecutionListCommand,
- 'get': ActionExecutionGetCommand})
+ models.Execution,
+ description,
+ app,
+ subparsers,
+ parent_parser=parent_parser,
+ read_only=True,
+ commands={
+ "list": ActionExecutionListCommand,
+ "get": ActionExecutionGetCommand,
+ },
+ )
# Register extended commands
- self.commands['re-run'] = ActionExecutionReRunCommand(
- self.resource, self.app, self.subparsers, add_help=False)
- self.commands['cancel'] = ActionExecutionCancelCommand(
- self.resource, self.app, self.subparsers, add_help=True)
- self.commands['pause'] = ActionExecutionPauseCommand(
- self.resource, self.app, self.subparsers, add_help=True)
- self.commands['resume'] = ActionExecutionResumeCommand(
- self.resource, self.app, self.subparsers, add_help=True)
- self.commands['tail'] = ActionExecutionTailCommand(self.resource, self.app,
- self.subparsers,
- add_help=True)
-
-
-POSSIBLE_ACTION_STATUS_VALUES = ('succeeded', 'running', 'scheduled', 'paused', 'failed',
- 'canceling', 'canceled')
+ self.commands["re-run"] = ActionExecutionReRunCommand(
+ self.resource, self.app, self.subparsers, add_help=False
+ )
+ self.commands["cancel"] = ActionExecutionCancelCommand(
+ self.resource, self.app, self.subparsers, add_help=True
+ )
+ self.commands["pause"] = ActionExecutionPauseCommand(
+ self.resource, self.app, self.subparsers, add_help=True
+ )
+ self.commands["resume"] = ActionExecutionResumeCommand(
+ self.resource, self.app, self.subparsers, add_help=True
+ )
+ self.commands["tail"] = ActionExecutionTailCommand(
+ self.resource, self.app, self.subparsers, add_help=True
+ )
+
+
+POSSIBLE_ACTION_STATUS_VALUES = (
+ "succeeded",
+ "running",
+ "scheduled",
+ "paused",
+ "failed",
+ "canceling",
+ "canceled",
+)
class ActionExecutionListCommand(ResourceViewCommand):
- display_attributes = ['id', 'action.ref', 'context.user', 'status', 'start_timestamp',
- 'end_timestamp']
+ display_attributes = [
+ "id",
+ "action.ref",
+ "context.user",
+ "status",
+ "start_timestamp",
+ "end_timestamp",
+ ]
attribute_transform_functions = {
- 'start_timestamp': format_isodate_for_user_timezone,
- 'end_timestamp': format_isodate_for_user_timezone,
- 'parameters': format_parameters,
- 'status': format_status
+ "start_timestamp": format_isodate_for_user_timezone,
+ "end_timestamp": format_isodate_for_user_timezone,
+ "parameters": format_parameters,
+ "status": format_status,
}
def __init__(self, resource, *args, **kwargs):
@@ -1066,83 +1271,133 @@ def __init__(self, resource, *args, **kwargs):
self.default_limit = 50
super(ActionExecutionListCommand, self).__init__(
- resource, 'list', 'Get the list of the %s most recent %s.' %
- (self.default_limit, resource.get_plural_display_name().lower()),
- *args, **kwargs)
+ resource,
+ "list",
+ "Get the list of the %s most recent %s."
+ % (self.default_limit, resource.get_plural_display_name().lower()),
+ *args,
+ **kwargs,
+ )
self.resource_name = resource.get_plural_display_name().lower()
self.group = self.parser.add_argument_group()
- self.parser.add_argument('-n', '--last', type=int, dest='last',
- default=self.default_limit,
- help=('List N most recent %s. Use -n -1 to fetch the full result \
- set.' % self.resource_name))
- self.parser.add_argument('-s', '--sort', type=str, dest='sort_order',
- default='descending',
- help=('Sort %s by start timestamp, '
- 'asc|ascending (earliest first) '
- 'or desc|descending (latest first)' % self.resource_name))
+ self.parser.add_argument(
+ "-n",
+ "--last",
+ type=int,
+ dest="last",
+ default=self.default_limit,
+ help=(
+ "List N most recent %s. Use -n -1 to fetch the full result \
+ set."
+ % self.resource_name
+ ),
+ )
+ self.parser.add_argument(
+ "-s",
+ "--sort",
+ type=str,
+ dest="sort_order",
+ default="descending",
+ help=(
+ "Sort %s by start timestamp, "
+ "asc|ascending (earliest first) "
+ "or desc|descending (latest first)" % self.resource_name
+ ),
+ )
# Filter options
- self.group.add_argument('--action', help='Action reference to filter the list.')
- self.group.add_argument('--status', help=('Only return executions with the provided \
- status. Possible values are \'%s\', \'%s\', \
- \'%s\', \'%s\', \'%s\', \'%s\' or \'%s\''
- '.' % POSSIBLE_ACTION_STATUS_VALUES))
- self.group.add_argument('--user',
- help='Only return executions created by the provided user.')
- self.group.add_argument('--trigger_instance',
- help='Trigger instance id to filter the list.')
- self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt',
- default=None,
- help=('Only return executions with timestamp '
- 'greater than the one provided. '
- 'Use time in the format "2000-01-01T12:00:00.000Z".'))
- self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt',
- default=None,
- help=('Only return executions with timestamp '
- 'lower than the one provided. '
- 'Use time in the format "2000-01-01T12:00:00.000Z".'))
- self.parser.add_argument('-l', '--showall', action='store_true',
- help='')
+ self.group.add_argument("--action", help="Action reference to filter the list.")
+ self.group.add_argument(
+ "--status",
+ help=(
+ "Only return executions with the provided \
+ status. Possible values are '%s', '%s', \
+ '%s', '%s', '%s', '%s' or '%s'"
+ "." % POSSIBLE_ACTION_STATUS_VALUES
+ ),
+ )
+ self.group.add_argument(
+ "--user", help="Only return executions created by the provided user."
+ )
+ self.group.add_argument(
+ "--trigger_instance", help="Trigger instance id to filter the list."
+ )
+ self.parser.add_argument(
+ "-tg",
+ "--timestamp-gt",
+ type=str,
+ dest="timestamp_gt",
+ default=None,
+ help=(
+ "Only return executions with timestamp "
+ "greater than the one provided. "
+ 'Use time in the format "2000-01-01T12:00:00.000Z".'
+ ),
+ )
+ self.parser.add_argument(
+ "-tl",
+ "--timestamp-lt",
+ type=str,
+ dest="timestamp_lt",
+ default=None,
+ help=(
+ "Only return executions with timestamp "
+ "lower than the one provided. "
+ 'Use time in the format "2000-01-01T12:00:00.000Z".'
+ ),
+ )
+ self.parser.add_argument("-l", "--showall", action="store_true", help="")
# Display options
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# Filtering options
if args.action:
- kwargs['action'] = args.action
+ kwargs["action"] = args.action
if args.status:
- kwargs['status'] = args.status
+ kwargs["status"] = args.status
if args.user:
- kwargs['user'] = args.user
+ kwargs["user"] = args.user
if args.trigger_instance:
- kwargs['trigger_instance'] = args.trigger_instance
+ kwargs["trigger_instance"] = args.trigger_instance
if not args.showall:
# "null" is the magic string that translates to "does not exist".
- kwargs['parent'] = 'null'
+ kwargs["parent"] = "null"
if args.timestamp_gt:
- kwargs['timestamp_gt'] = args.timestamp_gt
+ kwargs["timestamp_gt"] = args.timestamp_gt
if args.timestamp_lt:
- kwargs['timestamp_lt'] = args.timestamp_lt
+ kwargs["timestamp_lt"] = args.timestamp_lt
if args.sort_order:
- if args.sort_order in ['asc', 'ascending']:
- kwargs['sort_asc'] = True
- elif args.sort_order in ['desc', 'descending']:
- kwargs['sort_desc'] = True
+ if args.sort_order in ["asc", "ascending"]:
+ kwargs["sort_asc"] = True
+ elif args.sort_order in ["desc", "descending"]:
+ kwargs["sort_desc"] = True
# We only retrieve attributes which are needed to speed things up
include_attributes = self._get_include_attributes(args=args)
if include_attributes:
- kwargs['include_attributes'] = ','.join(include_attributes)
+ kwargs["include_attributes"] = ",".join(include_attributes)
return self.manager.query_with_count(limit=args.last, **kwargs)
@@ -1152,49 +1407,73 @@ def run_and_print(self, args, **kwargs):
instances = format_wf_instances(result)
if args.json or args.yaml:
- self.print_output(reversed(instances), table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json,
- yaml=args.yaml,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ reversed(instances),
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
else:
# Include elapsed time for running executions
instances = format_execution_statuses(instances)
- self.print_output(reversed(instances), table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ reversed(instances),
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
if args.last and count and count > args.last:
table.SingleRowTable.note_box(self.resource_name, args.last)
class ActionExecutionGetCommand(ActionRunCommandMixin, ResourceViewCommand):
- display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status',
- 'start_timestamp', 'end_timestamp', 'result']
- include_attributes = ['action.ref', 'action.runner_type', 'start_timestamp',
- 'end_timestamp']
+ display_attributes = [
+ "id",
+ "action.ref",
+ "context.user",
+ "parameters",
+ "status",
+ "start_timestamp",
+ "end_timestamp",
+ "result",
+ ]
+ include_attributes = [
+ "action.ref",
+ "action.runner_type",
+ "start_timestamp",
+ "end_timestamp",
+ ]
def __init__(self, resource, *args, **kwargs):
super(ActionExecutionGetCommand, self).__init__(
- resource, 'get',
- 'Get individual %s.' % resource.get_display_name().lower(),
- *args, **kwargs)
+ resource,
+ "get",
+ "Get individual %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
- self.parser.add_argument('id',
- help=('ID of the %s.' %
- resource.get_display_name().lower()))
+ self.parser.add_argument(
+ "id", help=("ID of the %s." % resource.get_display_name().lower())
+ )
self._add_common_options()
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# We only retrieve attributes which are needed to speed things up
- include_attributes = self._get_include_attributes(args=args,
- extra_attributes=self.include_attributes)
+ include_attributes = self._get_include_attributes(
+ args=args, extra_attributes=self.include_attributes
+ )
if include_attributes:
- include_attributes = ','.join(include_attributes)
- kwargs['params'] = {'include_attributes': include_attributes}
+ include_attributes = ",".join(include_attributes)
+ kwargs["params"] = {"include_attributes": include_attributes}
execution = self.get_resource_by_id(id=args.id, **kwargs)
return execution
@@ -1209,22 +1488,25 @@ def run_and_print(self, args, **kwargs):
execution = format_execution_status(execution)
except resource.ResourceNotFoundError:
self.print_not_found(args.id)
- raise ResourceNotFoundError('Execution with id %s not found.' % (args.id))
+ raise ResourceNotFoundError("Execution with id %s not found." % (args.id))
return self._print_execution_details(execution=execution, args=args, **kwargs)
class ActionExecutionCancelCommand(resource.ResourceCommand):
-
def __init__(self, resource, *args, **kwargs):
super(ActionExecutionCancelCommand, self).__init__(
- resource, 'cancel', 'Cancel %s.' %
- resource.get_plural_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('ids',
- nargs='+',
- help=('IDs of the %ss to cancel.' %
- resource.get_display_name().lower()))
+ resource,
+ "cancel",
+ "Cancel %s." % resource.get_plural_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "ids",
+ nargs="+",
+ help=("IDs of the %ss to cancel." % resource.get_display_name().lower()),
+ )
def run(self, args, **kwargs):
responses = []
@@ -1242,16 +1524,23 @@ def run_and_print(self, args, **kwargs):
self._print_result(execution_id=execution_id, response=response)
def _print_result(self, execution_id, response):
- if response and 'faultstring' in response:
- message = response.get('faultstring', 'Cancellation requested for %s with id %s.' %
- (self.resource.get_display_name().lower(), execution_id))
+ if response and "faultstring" in response:
+ message = response.get(
+ "faultstring",
+ "Cancellation requested for %s with id %s."
+ % (self.resource.get_display_name().lower(), execution_id),
+ )
elif response:
- message = '%s with id %s canceled.' % (self.resource.get_display_name().lower(),
- execution_id)
+ message = "%s with id %s canceled." % (
+ self.resource.get_display_name().lower(),
+ execution_id,
+ )
else:
- message = 'Cannot cancel %s with id %s.' % (self.resource.get_display_name().lower(),
- execution_id)
+ message = "Cannot cancel %s with id %s." % (
+ self.resource.get_display_name().lower(),
+ execution_id,
+ )
print(message)
@@ -1259,35 +1548,58 @@ class ActionExecutionReRunCommand(ActionRunCommandMixin, resource.ResourceComman
def __init__(self, resource, *args, **kwargs):
super(ActionExecutionReRunCommand, self).__init__(
- resource, kwargs.pop('name', 're-run'),
- 'Re-run a particular action.',
- *args, **kwargs)
-
- self.parser.add_argument('id', nargs='?',
- metavar='id',
- help='ID of action execution to re-run ')
- self.parser.add_argument('parameters', nargs='*',
- help='List of keyword args, positional args, '
- 'and optional args for the action.')
- self.parser.add_argument('--tasks', nargs='*',
- help='Name of the workflow tasks to re-run.')
- self.parser.add_argument('--no-reset', dest='no_reset', nargs='*',
- help='Name of the with-items tasks to not reset. This only '
- 'applies to Orquesta workflows. By default, all iterations '
- 'for with-items tasks is rerun. If no reset, only failed '
- ' iterations are rerun.')
- self.parser.add_argument('-a', '--async',
- action='store_true', dest='action_async',
- help='Do not wait for action to finish.')
- self.parser.add_argument('-e', '--inherit-env',
- action='store_true', dest='inherit_env',
- help='Pass all the environment variables '
- 'which are accessible to the CLI as "env" '
- 'parameter to the action. Note: Only works '
- 'with python, local and remote runners.')
- self.parser.add_argument('-h', '--help',
- action='store_true', dest='help',
- help='Print usage for the given action.')
+ resource,
+ kwargs.pop("name", "re-run"),
+ "Re-run a particular action.",
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "id", nargs="?", metavar="id", help="ID of action execution to re-run "
+ )
+ self.parser.add_argument(
+ "parameters",
+ nargs="*",
+ help="List of keyword args, positional args, "
+ "and optional args for the action.",
+ )
+ self.parser.add_argument(
+ "--tasks", nargs="*", help="Name of the workflow tasks to re-run."
+ )
+ self.parser.add_argument(
+ "--no-reset",
+ dest="no_reset",
+ nargs="*",
+ help="Name of the with-items tasks to not reset. This only "
+ "applies to Orquesta workflows. By default, all iterations "
+ "for with-items tasks is rerun. If no reset, only failed "
+ " iterations are rerun.",
+ )
+ self.parser.add_argument(
+ "-a",
+ "--async",
+ action="store_true",
+ dest="action_async",
+ help="Do not wait for action to finish.",
+ )
+ self.parser.add_argument(
+ "-e",
+ "--inherit-env",
+ action="store_true",
+ dest="inherit_env",
+ help="Pass all the environment variables "
+ 'which are accessible to the CLI as "env" '
+ "parameter to the action. Note: Only works "
+ "with python, local and remote runners.",
+ )
+ self.parser.add_argument(
+ "-h",
+ "--help",
+ action="store_true",
+ dest="help",
+ help="Print usage for the given action.",
+ )
self._add_common_options()
@add_auth_token_to_kwargs_from_cli
@@ -1295,47 +1607,63 @@ def run(self, args, **kwargs):
existing_execution = self.manager.get_by_id(args.id, **kwargs)
if not existing_execution:
- raise resource.ResourceNotFoundError('Action execution with id "%s" cannot be found.' %
- (args.id))
+ raise resource.ResourceNotFoundError(
+ 'Action execution with id "%s" cannot be found.' % (args.id)
+ )
- action_mgr = self.app.client.managers['Action']
- runner_mgr = self.app.client.managers['RunnerType']
- action_exec_mgr = self.app.client.managers['Execution']
+ action_mgr = self.app.client.managers["Action"]
+ runner_mgr = self.app.client.managers["RunnerType"]
+ action_exec_mgr = self.app.client.managers["Execution"]
- action_ref = existing_execution.action['ref']
+ action_ref = existing_execution.action["ref"]
action = action_mgr.get_by_ref_or_id(action_ref)
runner = runner_mgr.get_by_name(action.runner_type)
- action_parameters = self._get_action_parameters_from_args(action=action, runner=runner,
- args=args)
+ action_parameters = self._get_action_parameters_from_args(
+ action=action, runner=runner, args=args
+ )
- execution = action_exec_mgr.re_run(execution_id=args.id,
- parameters=action_parameters,
- tasks=args.tasks,
- no_reset=args.no_reset,
- delay=args.delay if args.delay else 0,
- **kwargs)
+ execution = action_exec_mgr.re_run(
+ execution_id=args.id,
+ parameters=action_parameters,
+ tasks=args.tasks,
+ no_reset=args.no_reset,
+ delay=args.delay if args.delay else 0,
+ **kwargs,
+ )
- execution = self._get_execution_result(execution=execution,
- action_exec_mgr=action_exec_mgr,
- args=args, **kwargs)
+ execution = self._get_execution_result(
+ execution=execution, action_exec_mgr=action_exec_mgr, args=args, **kwargs
+ )
return execution
class ActionExecutionPauseCommand(ActionRunCommandMixin, ResourceViewCommand):
- display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status',
- 'start_timestamp', 'end_timestamp', 'result']
+ display_attributes = [
+ "id",
+ "action.ref",
+ "context.user",
+ "parameters",
+ "status",
+ "start_timestamp",
+ "end_timestamp",
+ "result",
+ ]
def __init__(self, resource, *args, **kwargs):
super(ActionExecutionPauseCommand, self).__init__(
- resource, 'pause', 'Pause %s (workflow executions only).' %
- resource.get_plural_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('ids',
- nargs='+',
- help='ID of action execution to pause.')
+ resource,
+ "pause",
+ "Pause %s (workflow executions only)."
+ % resource.get_plural_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "ids", nargs="+", help="ID of action execution to pause."
+ )
self._add_common_options()
@@ -1348,7 +1676,9 @@ def run(self, args, **kwargs):
responses.append([execution_id, response])
except resource.ResourceNotFoundError:
self.print_not_found(args.ids)
- raise ResourceNotFoundError('Execution with id %s not found.' % (execution_id))
+ raise ResourceNotFoundError(
+ "Execution with id %s not found." % (execution_id)
+ )
return responses
@@ -1367,18 +1697,30 @@ def _print_result(self, args, execution_id, execution, **kwargs):
class ActionExecutionResumeCommand(ActionRunCommandMixin, ResourceViewCommand):
- display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status',
- 'start_timestamp', 'end_timestamp', 'result']
+ display_attributes = [
+ "id",
+ "action.ref",
+ "context.user",
+ "parameters",
+ "status",
+ "start_timestamp",
+ "end_timestamp",
+ "result",
+ ]
def __init__(self, resource, *args, **kwargs):
super(ActionExecutionResumeCommand, self).__init__(
- resource, 'resume', 'Resume %s (workflow executions only).' %
- resource.get_plural_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('ids',
- nargs='+',
- help='ID of action execution to resume.')
+ resource,
+ "resume",
+ "Resume %s (workflow executions only)."
+ % resource.get_plural_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "ids", nargs="+", help="ID of action execution to resume."
+ )
self._add_common_options()
@@ -1391,7 +1733,9 @@ def run(self, args, **kwargs):
responses.append([execution_id, response])
except resource.ResourceNotFoundError:
self.print_not_found(execution_id)
- raise ResourceNotFoundError('Execution with id %s not found.' % (execution_id))
+ raise ResourceNotFoundError(
+ "Execution with id %s not found." % (execution_id)
+ )
return responses
@@ -1412,22 +1756,33 @@ def _print_result(self, args, execution, **kwargs):
class ActionExecutionTailCommand(resource.ResourceCommand):
def __init__(self, resource, *args, **kwargs):
super(ActionExecutionTailCommand, self).__init__(
- resource, kwargs.pop('name', 'tail'),
- 'Tail output of a particular execution.',
- *args, **kwargs)
-
- self.parser.add_argument('id', nargs='?',
- metavar='id',
- default='last',
- help='ID of action execution to tail.')
- self.parser.add_argument('--type', dest='output_type', action='store',
- help=('Type of output to tail for. If not provided, '
- 'defaults to all.'))
- self.parser.add_argument('--include-metadata', dest='include_metadata',
- action='store_true',
- default=False,
- help=('Include metadata (timestamp, output type) with the '
- 'output.'))
+ resource,
+ kwargs.pop("name", "tail"),
+ "Tail output of a particular execution.",
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "id",
+ nargs="?",
+ metavar="id",
+ default="last",
+ help="ID of action execution to tail.",
+ )
+ self.parser.add_argument(
+ "--type",
+ dest="output_type",
+ action="store",
+ help=("Type of output to tail for. If not provided, " "defaults to all."),
+ )
+ self.parser.add_argument(
+ "--include-metadata",
+ dest="include_metadata",
+ action="store_true",
+ default=False,
+ help=("Include metadata (timestamp, output type) with the " "output."),
+ )
def run(self, args, **kwargs):
pass
@@ -1435,45 +1790,55 @@ def run(self, args, **kwargs):
@add_auth_token_to_kwargs_from_cli
def run_and_print(self, args, **kwargs):
execution_id = args.id
- output_type = getattr(args, 'output_type', None)
+ output_type = getattr(args, "output_type", None)
include_metadata = args.include_metadata
# Special case for id "last"
- if execution_id == 'last':
+ if execution_id == "last":
executions = self.manager.query(limit=1)
if executions:
execution = executions[0]
execution_id = execution.id
else:
- print('No executions found in db.')
+ print("No executions found in db.")
return
else:
execution = self.manager.get_by_id(execution_id, **kwargs)
if not execution:
- raise ResourceNotFoundError('Execution with id %s not found.' % (args.id))
+ raise ResourceNotFoundError("Execution with id %s not found." % (args.id))
execution_manager = self.manager
- stream_manager = self.app.client.managers['Stream']
- ActionExecutionTailCommand.tail_execution(execution=execution,
- execution_manager=execution_manager,
- stream_manager=stream_manager,
- output_type=output_type,
- include_metadata=include_metadata,
- **kwargs)
+ stream_manager = self.app.client.managers["Stream"]
+ ActionExecutionTailCommand.tail_execution(
+ execution=execution,
+ execution_manager=execution_manager,
+ stream_manager=stream_manager,
+ output_type=output_type,
+ include_metadata=include_metadata,
+ **kwargs,
+ )
@classmethod
- def tail_execution(cls, execution_manager, stream_manager, execution, output_type=None,
- include_metadata=False, **kwargs):
+ def tail_execution(
+ cls,
+ execution_manager,
+ stream_manager,
+ execution,
+ output_type=None,
+ include_metadata=False,
+ **kwargs,
+ ):
execution_id = str(execution.id)
# Indicates if the execution we are tailing is a child execution in a workflow
context = cls.get_normalized_context_execution_task_event(execution.__dict__)
- has_parent_attribute = bool(getattr(execution, 'parent', None))
- has_parent_execution_id = bool(context['parent_execution_id'])
+ has_parent_attribute = bool(getattr(execution, "parent", None))
+ has_parent_execution_id = bool(context["parent_execution_id"])
- is_tailing_execution_child_execution = bool(has_parent_attribute or
- has_parent_execution_id)
+ is_tailing_execution_child_execution = bool(
+ has_parent_attribute or has_parent_execution_id
+ )
# Note: For non-workflow actions child_execution_id always matches parent_execution_id so
# we don't need to do any other checks to determine if the execution represents a workflow
@@ -1484,10 +1849,14 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ
# NOTE: This doesn't recurse down into child executions if user is tailing a workflow
# execution
if execution.status in LIVEACTION_COMPLETED_STATES:
- output = execution_manager.get_output(execution_id=execution_id,
- output_type=output_type)
+ output = execution_manager.get_output(
+ execution_id=execution_id, output_type=output_type
+ )
print(output)
- print('Execution %s has completed (status=%s).' % (execution_id, execution.status))
+ print(
+ "Execution %s has completed (status=%s)."
+ % (execution_id, execution.status)
+ )
return
# We keep track of all the workflow executions which could contain children.
@@ -1497,29 +1866,27 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ
# Retrieve parent execution object so we can keep track of any existing children
# executions (only applies to already running executions).
- filters = {
- 'params': {
- 'include_attributes': 'id,children'
- }
- }
+ filters = {"params": {"include_attributes": "id,children"}}
execution = execution_manager.get_by_id(id=execution_id, **filters)
- children_execution_ids = getattr(execution, 'children', [])
+ children_execution_ids = getattr(execution, "children", [])
workflow_execution_ids.update(children_execution_ids)
- events = ['st2.execution__update', 'st2.execution.output__create']
- for event in stream_manager.listen(events,
- end_execution_id=execution_id,
- end_event="st2.execution__update",
- **kwargs):
- status = event.get('status', None)
+ events = ["st2.execution__update", "st2.execution.output__create"]
+ for event in stream_manager.listen(
+ events,
+ end_execution_id=execution_id,
+ end_event="st2.execution__update",
+ **kwargs,
+ ):
+ status = event.get("status", None)
is_execution_event = status is not None
if is_execution_event:
context = cls.get_normalized_context_execution_task_event(event)
- task_execution_id = context['execution_id']
- task_name = context['task_name']
- task_parent_execution_id = context['parent_execution_id']
+ task_execution_id = context["execution_id"]
+ task_name = context["task_name"]
+ task_parent_execution_id = context["parent_execution_id"]
# An execution is considered a child execution if it has parent execution id
is_child_execution = bool(task_parent_execution_id)
@@ -1536,14 +1903,18 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ
if is_child_execution:
if status == LIVEACTION_STATUS_RUNNING:
- print('Child execution (task=%s) %s has started.' % (task_name,
- task_execution_id))
- print('')
+ print(
+ "Child execution (task=%s) %s has started."
+ % (task_name, task_execution_id)
+ )
+ print("")
continue
elif status in LIVEACTION_COMPLETED_STATES:
- print('')
- print('Child execution (task=%s) %s has finished (status=%s).' %
- (task_name, task_execution_id, status))
+ print("")
+ print(
+ "Child execution (task=%s) %s has finished (status=%s)."
+ % (task_name, task_execution_id, status)
+ )
if is_tailing_execution_child_execution:
# User is tailing a child execution inside a workflow, stop the command.
@@ -1556,56 +1927,69 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ
else:
# NOTE: In some situations execution update event with "running" status is
# dispatched twice so we ignore any duplicated events
- if status == LIVEACTION_STATUS_RUNNING and not event.get('children', []):
- print('Execution %s has started.' % (execution_id))
- print('')
+ if status == LIVEACTION_STATUS_RUNNING and not event.get(
+ "children", []
+ ):
+ print("Execution %s has started." % (execution_id))
+ print("")
continue
elif status in LIVEACTION_COMPLETED_STATES:
# Bail out once parent execution has finished
- print('')
- print('Execution %s has completed (status=%s).' % (execution_id, status))
+ print("")
+ print(
+ "Execution %s has completed (status=%s)."
+ % (execution_id, status)
+ )
break
else:
# We don't care about other execution events
continue
# Ignore events for executions which don't belong to the one we are tailing
- event_execution_id = event['execution_id']
+ event_execution_id = event["execution_id"]
if event_execution_id not in workflow_execution_ids:
continue
# Filter on output_type if provided
- event_output_type = event.get('output_type', None)
- if output_type != 'all' and output_type and (event_output_type != output_type):
+ event_output_type = event.get("output_type", None)
+ if (
+ output_type != "all"
+ and output_type
+ and (event_output_type != output_type)
+ ):
continue
if include_metadata:
- sys.stdout.write('[%s][%s] %s' % (event['timestamp'], event['output_type'],
- event['data']))
+ sys.stdout.write(
+ "[%s][%s] %s"
+ % (event["timestamp"], event["output_type"], event["data"])
+ )
else:
- sys.stdout.write(event['data'])
+ sys.stdout.write(event["data"])
@classmethod
def get_normalized_context_execution_task_event(cls, event):
"""
Return a dictionary with normalized context attributes for execution event or object.
"""
- context = event.get('context', {})
-
- result = {
- 'parent_execution_id': None,
- 'execution_id': None,
- 'task_name': None
- }
-
- if 'orquesta' in context:
- result['parent_execution_id'] = context.get('parent', {}).get('execution_id', None)
- result['execution_id'] = event['id']
- result['task_name'] = context.get('orquesta', {}).get('task_name', 'unknown')
+ context = event.get("context", {})
+
+ result = {"parent_execution_id": None, "execution_id": None, "task_name": None}
+
+ if "orquesta" in context:
+ result["parent_execution_id"] = context.get("parent", {}).get(
+ "execution_id", None
+ )
+ result["execution_id"] = event["id"]
+ result["task_name"] = context.get("orquesta", {}).get(
+ "task_name", "unknown"
+ )
else:
# Action chain workflow
- result['parent_execution_id'] = context.get('parent', {}).get('execution_id', None)
- result['execution_id'] = event['id']
- result['task_name'] = context.get('chain', {}).get('name', 'unknown')
+ result["parent_execution_id"] = context.get("parent", {}).get(
+ "execution_id", None
+ )
+ result["execution_id"] = event["id"]
+ result["task_name"] = context.get("chain", {}).get("name", "unknown")
return result
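
# ---------------------------------------------------------------------------
# A hedged aside (not part of the diff) on two string behaviors the hunks
# above depend on. First, a backslash at the end of a line *inside* a string
# literal, as in the old --status and --last help texts, continues the literal
# and keeps the next line's leading indentation as part of the value, so the
# rendered help contained a run of spaces; that is why those strings are best
# rewritten as implicit concatenation of adjacent literals. Second, the
# --status help applies the 7-tuple POSSIBLE_ACTION_STATUS_VALUES directly
# with the % operator, so the format string needs exactly seven %s slots.

# Equivalent of what the backslash-continued literal produced:
leaky = "Only return executions with the provided " + " " * 16 + "status."
assert "  " in leaky  # indentation leaked into the help text

# Implicit concatenation of adjacent literals yields the intended value:
clean = "Only return executions with the provided " "status."
assert clean == "Only return executions with the provided status."

POSSIBLE_ACTION_STATUS_VALUES = (
    "succeeded",
    "running",
    "scheduled",
    "paused",
    "failed",
    "canceling",
    "canceled",
)
print(
    "Possible values are '%s', '%s', '%s', '%s', '%s', '%s' or '%s'."
    % POSSIBLE_ACTION_STATUS_VALUES
)
# ---------------------------------------------------------------------------
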
diff --git a/st2client/st2client/commands/action_alias.py b/st2client/st2client/commands/action_alias.py
index 32a65776cc..d6f5fbcfc1 100644
--- a/st2client/st2client/commands/action_alias.py
+++ b/st2client/st2client/commands/action_alias.py
@@ -22,63 +22,87 @@
from st2client.formatters import table
-__all__ = [
- 'ActionAliasBranch',
- 'ActionAliasMatchCommand',
- 'ActionAliasExecuteCommand'
-]
+__all__ = ["ActionAliasBranch", "ActionAliasMatchCommand", "ActionAliasExecuteCommand"]
class ActionAliasBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(ActionAliasBranch, self).__init__(
- ActionAlias, description, app, subparsers,
- parent_parser=parent_parser, read_only=False,
- commands={
- 'list': ActionAliasListCommand,
- 'get': ActionAliasGetCommand
- })
-
- self.commands['match'] = ActionAliasMatchCommand(
- self.resource, self.app, self.subparsers,
- add_help=True)
- self.commands['execute'] = ActionAliasExecuteCommand(
- self.resource, self.app, self.subparsers,
- add_help=True)
+ ActionAlias,
+ description,
+ app,
+ subparsers,
+ parent_parser=parent_parser,
+ read_only=False,
+ commands={"list": ActionAliasListCommand, "get": ActionAliasGetCommand},
+ )
+
+ self.commands["match"] = ActionAliasMatchCommand(
+ self.resource, self.app, self.subparsers, add_help=True
+ )
+ self.commands["execute"] = ActionAliasExecuteCommand(
+ self.resource, self.app, self.subparsers, add_help=True
+ )
class ActionAliasListCommand(resource.ContentPackResourceListCommand):
- display_attributes = ['ref', 'pack', 'description', 'enabled']
+ display_attributes = ["ref", "pack", "description", "enabled"]
class ActionAliasGetCommand(resource.ContentPackResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'description',
- 'enabled', 'action_ref', 'formats']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "description",
+ "enabled",
+ "action_ref",
+ "formats",
+ ]
class ActionAliasMatchCommand(resource.ResourceCommand):
- display_attributes = ['name', 'description']
+ display_attributes = ["name", "description"]
def __init__(self, resource, *args, **kwargs):
super(ActionAliasMatchCommand, self).__init__(
- resource, 'match',
- 'Get the %s that match the command text.' %
- resource.get_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('match_text',
- metavar='command',
- help=('Get the %s that match the command text.' %
- resource.get_display_name().lower()))
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ resource,
+ "match",
+ "Get the %s that match the command text."
+ % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "match_text",
+ metavar="command",
+ help=(
+ "Get the %s that match the command text."
+ % resource.get_display_name().lower()
+ ),
+ )
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -90,40 +114,62 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
instances = self.run(args, **kwargs)
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
class ActionAliasExecuteCommand(resource.ResourceCommand):
- display_attributes = ['name']
+ display_attributes = ["name"]
def __init__(self, resource, *args, **kwargs):
super(ActionAliasExecuteCommand, self).__init__(
- resource, 'execute',
- ('Execute the command text by finding a matching %s.' %
- resource.get_display_name().lower()), *args, **kwargs)
-
- self.parser.add_argument('command_text',
- metavar='command',
- help=('Execute the command text by finding a matching %s.' %
- resource.get_display_name().lower()))
- self.parser.add_argument('-u', '--user', type=str, default=None,
- help='User under which to run the action (admins only).')
+ resource,
+ "execute",
+ (
+ "Execute the command text by finding a matching %s."
+ % resource.get_display_name().lower()
+ ),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "command_text",
+ metavar="command",
+ help=(
+ "Execute the command text by finding a matching %s."
+ % resource.get_display_name().lower()
+ ),
+ )
+ self.parser.add_argument(
+ "-u",
+ "--user",
+ type=str,
+ default=None,
+ help="User under which to run the action (admins only).",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
payload = core.Resource()
payload.command = args.command_text
payload.user = args.user or ""
- payload.source_channel = 'cli'
+ payload.source_channel = "cli"
- alias_execution_mgr = self.app.client.managers['ActionAliasExecution']
+ alias_execution_mgr = self.app.client.managers["ActionAliasExecution"]
execution = alias_execution_mgr.match_and_execute(payload)
return execution
def run_and_print(self, args, **kwargs):
execution = self.run(args, **kwargs)
- print("Matching Action-alias: '%s'" % execution.actionalias['ref'])
- print("To get the results, execute:\n st2 execution get %s" %
- (execution.execution['id']))
+ print("Matching Action-alias: '%s'" % execution.actionalias["ref"])
+ print(
+ "To get the results, execute:\n st2 execution get %s"
+ % (execution.execution["id"])
+ )
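
# ---------------------------------------------------------------------------
# A hedged aside (not part of the diff) on the mechanical changes repeated
# throughout these hunks, sketched as runnable code rather than a definitive
# statement of black's rules: quotes are normalized to double quotes (single
# quotes survive when the literal itself contains double quotes, e.g.
# 'Action "%s" is not found.'), long calls are exploded one argument per line
# with a trailing comma, and that comma is now emitted even after **kwargs.
# That last form is Python 3 only syntax (3.6+, to the best of my knowledge),
# which is consistent with this reformat targeting Python 3.

def forward(
    *args,
    **kwargs,  # trailing comma after **kwargs requires Python 3.6+
):
    return args, kwargs

assert forward(1, a=2) == ((1,), {"a": 2})

# Quote normalization never changes runtime values:
assert "execute" == 'execute'
# ---------------------------------------------------------------------------
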
diff --git a/st2client/st2client/commands/auth.py b/st2client/st2client/commands/auth.py
index 40066d1a5e..5b0507f324 100644
--- a/st2client/st2client/commands/auth.py
+++ b/st2client/st2client/commands/auth.py
@@ -39,36 +39,54 @@
class TokenCreateCommand(resource.ResourceCommand):
- display_attributes = ['user', 'token', 'expiry']
+ display_attributes = ["user", "token", "expiry"]
def __init__(self, resource, *args, **kwargs):
- kwargs['has_token_opt'] = False
+ kwargs["has_token_opt"] = False
super(TokenCreateCommand, self).__init__(
- resource, kwargs.pop('name', 'create'),
- 'Authenticate user and acquire access token.',
- *args, **kwargs)
-
- self.parser.add_argument('username',
- help='Name of the user to authenticate.')
-
- self.parser.add_argument('-p', '--password', dest='password',
- help='Password for the user. If password is not provided, '
- 'it will be prompted for.')
- self.parser.add_argument('-l', '--ttl', type=int, dest='ttl', default=None,
- help='The life span of the token in seconds. '
- 'Max TTL configured by the admin supersedes this.')
- self.parser.add_argument('-t', '--only-token', action='store_true', dest='only_token',
- default=False,
- help='On successful authentication, print only token to the '
- 'console.')
+ resource,
+ kwargs.pop("name", "create"),
+ "Authenticate user and acquire access token.",
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument("username", help="Name of the user to authenticate.")
+
+ self.parser.add_argument(
+ "-p",
+ "--password",
+ dest="password",
+ help="Password for the user. If password is not provided, "
+ "it will be prompted for.",
+ )
+ self.parser.add_argument(
+ "-l",
+ "--ttl",
+ type=int,
+ dest="ttl",
+ default=None,
+ help="The life span of the token in seconds. "
+ "Max TTL configured by the admin supersedes this.",
+ )
+ self.parser.add_argument(
+ "-t",
+ "--only-token",
+ action="store_true",
+ dest="only_token",
+ default=False,
+ help="On successful authentication, print only token to the " "console.",
+ )
def run(self, args, **kwargs):
if not args.password:
args.password = getpass.getpass()
instance = self.resource(ttl=args.ttl) if args.ttl else self.resource()
- return self.manager.create(instance, auth=(args.username, args.password), **kwargs)
+ return self.manager.create(
+ instance, auth=(args.username, args.password), **kwargs
+ )
def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
@@ -76,35 +94,57 @@ def run_and_print(self, args, **kwargs):
if args.only_token:
print(instance.token)
else:
- self.print_output(instance, table.PropertyValueTable,
- attributes=self.display_attributes, json=args.json, yaml=args.yaml)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=self.display_attributes,
+ json=args.json,
+ yaml=args.yaml,
+ )
class LoginCommand(resource.ResourceCommand):
- display_attributes = ['user', 'token', 'expiry']
+ display_attributes = ["user", "token", "expiry"]
def __init__(self, resource, *args, **kwargs):
- kwargs['has_token_opt'] = False
+ kwargs["has_token_opt"] = False
super(LoginCommand, self).__init__(
- resource, kwargs.pop('name', 'create'),
- 'Authenticate user, acquire access token, and update CLI config directory',
- *args, **kwargs)
-
- self.parser.add_argument('username',
- help='Name of the user to authenticate.')
-
- self.parser.add_argument('-p', '--password', dest='password',
- help='Password for the user. If password is not provided, '
- 'it will be prompted for.')
- self.parser.add_argument('-l', '--ttl', type=int, dest='ttl', default=None,
- help='The life span of the token in seconds. '
- 'Max TTL configured by the admin supersedes this.')
- self.parser.add_argument('-w', '--write-password', action='store_true', default=False,
- dest='write_password',
- help='Write the password in plain text to the config file '
- '(default is to omit it)')
+ resource,
+ kwargs.pop("name", "create"),
+ "Authenticate user, acquire access token, and update CLI config directory",
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument("username", help="Name of the user to authenticate.")
+
+ self.parser.add_argument(
+ "-p",
+ "--password",
+ dest="password",
+ help="Password for the user. If password is not provided, "
+ "it will be prompted for.",
+ )
+ self.parser.add_argument(
+ "-l",
+ "--ttl",
+ type=int,
+ dest="ttl",
+ default=None,
+ help="The life span of the token in seconds. "
+ "Max TTL configured by the admin supersedes this.",
+ )
+ self.parser.add_argument(
+ "-w",
+ "--write-password",
+ action="store_true",
+ default=False,
+ dest="write_password",
+ help="Write the password in plain text to the config file "
+ "(default is to omit it)",
+ )
def run(self, args, **kwargs):
@@ -122,7 +162,9 @@ def run(self, args, **kwargs):
config_file = config_parser.ST2_CONFIG_PATH
# Retrieve token
- manager = self.manager.create(instance, auth=(args.username, args.password), **kwargs)
+ manager = self.manager.create(
+ instance, auth=(args.username, args.password), **kwargs
+ )
cli._cache_auth_token(token_obj=manager)
# Update existing configuration with new credentials
@@ -130,18 +172,18 @@ def run(self, args, **kwargs):
config.read(config_file)
# Modify config (and optionally populate with password)
- if not config.has_section('credentials'):
- config.add_section('credentials')
+ if not config.has_section("credentials"):
+ config.add_section("credentials")
- config.set('credentials', 'username', args.username)
+ config.set("credentials", "username", args.username)
if args.write_password:
- config.set('credentials', 'password', args.password)
+ config.set("credentials", "password", args.password)
else:
# Remove any existing password from config
- config.remove_option('credentials', 'password')
+ config.remove_option("credentials", "password")
config_existed = os.path.exists(config_file)
- with open(config_file, 'w') as cfg_file_out:
+ with open(config_file, "w") as cfg_file_out:
config.write(cfg_file_out)
# If we created the config file, correct the permissions
if not config_existed:
@@ -156,35 +198,44 @@ def run_and_print(self, args, **kwargs):
if self.app.client.debug:
raise
- raise Exception('Failed to log in as %s: %s' % (args.username, six.text_type(e)))
+ raise Exception(
+ "Failed to log in as %s: %s" % (args.username, six.text_type(e))
+ )
- print('Logged in as %s' % (args.username))
+ print("Logged in as %s" % (args.username))
if not args.write_password:
# Note: Client can't depend on and import from common so we need to hard-code this
# default value
token_expire_hours = 24
- print('')
- print('Note: You didn\'t use --write-password option so the password hasn\'t been '
- 'stored in the client config and you will need to login again in %s hours when '
- 'the auth token expires.' % (token_expire_hours))
- print('As an alternative, you can run st2 login command with the "--write-password" '
- 'flag, but keep it mind this will cause it to store the password in plain-text '
- 'in the client config file (~/.st2/config).')
+ print("")
+ print(
+ "Note: You didn't use --write-password option so the password hasn't been "
+ "stored in the client config and you will need to login again in %s hours when "
+ "the auth token expires." % (token_expire_hours)
+ )
+ print(
+ 'As an alternative, you can run st2 login command with the "--write-password" '
+ "flag, but keep it mind this will cause it to store the password in plain-text "
+ "in the client config file (~/.st2/config)."
+ )
class WhoamiCommand(resource.ResourceCommand):
- display_attributes = ['user', 'token', 'expiry']
+ display_attributes = ["user", "token", "expiry"]
def __init__(self, resource, *args, **kwargs):
- kwargs['has_token_opt'] = False
+ kwargs["has_token_opt"] = False
super(WhoamiCommand, self).__init__(
- resource, kwargs.pop('name', 'create'),
- 'Display the currently authenticated user',
- *args, **kwargs)
+ resource,
+ kwargs.pop("name", "create"),
+ "Display the currently authenticated user",
+ *args,
+ **kwargs,
+ )
def run(self, args, **kwargs):
user_info = self.app.client.get_user_info(**kwargs)
@@ -194,119 +245,157 @@ def run_and_print(self, args, **kwargs):
try:
user_info = self.run(args, **kwargs)
except Exception as e:
- response = getattr(e, 'response', None)
- status_code = getattr(response, 'status_code', None)
- is_unathorized_error = (status_code == http_client.UNAUTHORIZED)
+ response = getattr(e, "response", None)
+ status_code = getattr(response, "status_code", None)
+ is_unathorized_error = status_code == http_client.UNAUTHORIZED
if response and is_unathorized_error:
- print('Not authenticated')
+ print("Not authenticated")
else:
- print('Unable to retrieve currently logged-in user')
+ print("Unable to retrieve currently logged-in user")
if self.app.client.debug:
raise
return
- print('Currently logged in as "%s".' % (user_info['username']))
- print('')
- print('Authentication method: %s' % (user_info['authentication']['method']))
+ print('Currently logged in as "%s".' % (user_info["username"]))
+ print("")
+ print("Authentication method: %s" % (user_info["authentication"]["method"]))
- if user_info['authentication']['method'] == 'authentication token':
- print('Authentication token expire time: %s' %
- (user_info['authentication']['token_expire']))
+ if user_info["authentication"]["method"] == "authentication token":
+ print(
+ "Authentication token expire time: %s"
+ % (user_info["authentication"]["token_expire"])
+ )
- print('')
- print('RBAC:')
- print(' - Enabled: %s' % (user_info['rbac']['enabled']))
- print(' - Roles: %s' % (', '.join(user_info['rbac']['roles'])))
+ print("")
+ print("RBAC:")
+ print(" - Enabled: %s" % (user_info["rbac"]["enabled"]))
+ print(" - Roles: %s" % (", ".join(user_info["rbac"]["roles"])))
class ApiKeyBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(ApiKeyBranch, self).__init__(
- models.ApiKey, description, app, subparsers,
+ models.ApiKey,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
commands={
- 'list': ApiKeyListCommand,
- 'get': ApiKeyGetCommand,
- 'create': ApiKeyCreateCommand,
- 'update': NoopCommand,
- 'delete': ApiKeyDeleteCommand
- })
-
- self.commands['enable'] = ApiKeyEnableCommand(self.resource, self.app, self.subparsers)
- self.commands['disable'] = ApiKeyDisableCommand(self.resource, self.app, self.subparsers)
- self.commands['load'] = ApiKeyLoadCommand(self.resource, self.app, self.subparsers)
+ "list": ApiKeyListCommand,
+ "get": ApiKeyGetCommand,
+ "create": ApiKeyCreateCommand,
+ "update": NoopCommand,
+ "delete": ApiKeyDeleteCommand,
+ },
+ )
+
+ self.commands["enable"] = ApiKeyEnableCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["disable"] = ApiKeyDisableCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["load"] = ApiKeyLoadCommand(
+ self.resource, self.app, self.subparsers
+ )
class ApiKeyListCommand(resource.ResourceListCommand):
- detail_display_attributes = ['all']
- display_attributes = ['id', 'user', 'metadata']
+ detail_display_attributes = ["all"]
+ display_attributes = ["id", "user", "metadata"]
def __init__(self, resource, *args, **kwargs):
super(ApiKeyListCommand, self).__init__(resource, *args, **kwargs)
- self.parser.add_argument('-u', '--user', type=str,
- help='Only return ApiKeys belonging to the provided user')
- self.parser.add_argument('-d', '--detail', action='store_true',
- help='Full list of attributes.')
- self.parser.add_argument('--show-secrets', action='store_true',
- help='Full list of attributes.')
+ self.parser.add_argument(
+ "-u",
+ "--user",
+ type=str,
+ help="Only return ApiKeys belonging to the provided user",
+ )
+ self.parser.add_argument(
+ "-d", "--detail", action="store_true", help="Full list of attributes."
+ )
+ self.parser.add_argument(
+ "--show-secrets", action="store_true", help="Full list of attributes."
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
filters = {}
- filters['user'] = args.user
+ filters["user"] = args.user
filters.update(**kwargs)
# show_secrets is not a filter but a query param. There is some special
# handling for filters in the get method which requires this odd hack.
if args.show_secrets:
- params = filters.get('params', {})
- params['show_secrets'] = True
- filters['params'] = params
+ params = filters.get("params", {})
+ params["show_secrets"] = True
+ filters["params"] = params
return self.manager.get_all(**filters)
def run_and_print(self, args, **kwargs):
instances = self.run(args, **kwargs)
attr = self.detail_display_attributes if args.detail else args.attr
- self.print_output(instances, table.MultiColumnTable,
- attributes=attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
class ApiKeyGetCommand(resource.ResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'user', 'metadata']
+ display_attributes = ["all"]
+ attribute_display_order = ["id", "user", "metadata"]
- pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK
+ pk_argument_name = "key_or_id" # name of the attribute which stores resource PK
class ApiKeyCreateCommand(resource.ResourceCommand):
-
def __init__(self, resource, *args, **kwargs):
super(ApiKeyCreateCommand, self).__init__(
- resource, 'create', 'Create a new %s.' % resource.get_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('-u', '--user', type=str,
- help='User for which to create API Keys.',
- default='')
- self.parser.add_argument('-m', '--metadata', type=json.loads,
- help='Optional metadata to associate with the API Keys.',
- default={})
- self.parser.add_argument('-k', '--only-key', action='store_true', dest='only_key',
- default=False,
- help='Only print API Key to the console on creation.')
+ resource,
+ "create",
+ "Create a new %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "-u",
+ "--user",
+ type=str,
+ help="User for which to create API Keys.",
+ default="",
+ )
+ self.parser.add_argument(
+ "-m",
+ "--metadata",
+ type=json.loads,
+ help="Optional metadata to associate with the API Keys.",
+ default={},
+ )
+ self.parser.add_argument(
+ "-k",
+ "--only-key",
+ action="store_true",
+ dest="only_key",
+ default=False,
+ help="Only print API Key to the console on creation.",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
data = {}
if args.user:
- data['user'] = args.user
+ data["user"] = args.user
if args.metadata:
- data['metadata'] = args.metadata
+ data["metadata"] = args.metadata
instance = self.resource.deserialize(data)
return self.manager.create(instance, **kwargs)
@@ -314,39 +403,59 @@ def run_and_print(self, args, **kwargs):
try:
instance = self.run(args, **kwargs)
if not instance:
- raise Exception('Server did not create instance.')
+ raise Exception("Server did not create instance.")
except Exception as e:
message = six.text_type(e)
- print('ERROR: %s' % (message))
+ print("ERROR: %s" % (message))
raise OperationFailureException(message)
if args.only_key:
print(instance.key)
else:
- self.print_output(instance, table.PropertyValueTable,
- attributes=['all'], json=args.json, yaml=args.yaml)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=["all"],
+ json=args.json,
+ yaml=args.yaml,
+ )
class ApiKeyLoadCommand(resource.ResourceCommand):
-
def __init__(self, resource, *args, **kwargs):
super(ApiKeyLoadCommand, self).__init__(
- resource, 'load', 'Load %s from a file.' % resource.get_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('file',
- help=('JSON/YAML file containing the %s(s) to load.'
- % resource.get_display_name().lower()),
- default='')
-
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ resource,
+ "load",
+ "Load %s from a file." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "file",
+ help=(
+ "JSON/YAML file containing the %s(s) to load."
+ % resource.get_display_name().lower()
+ ),
+ default="",
+ )
+
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
resources = resource.load_meta_file(args.file)
if not resources:
- print('No %s found in %s.' % (self.resource.get_display_name().lower(), args.file))
+ print(
+ "No %s found in %s."
+ % (self.resource.get_display_name().lower(), args.file)
+ )
return None
if not isinstance(resources, list):
resources = [resources]
@@ -354,14 +463,14 @@ def run(self, args, **kwargs):
for res in resources:
# pick only the meaningful properties.
data = {
- 'user': res['user'], # required
- 'key_hash': res['key_hash'], # required
- 'metadata': res.get('metadata', {}),
- 'enabled': res.get('enabled', False)
+ "user": res["user"], # required
+ "key_hash": res["key_hash"], # required
+ "metadata": res.get("metadata", {}),
+ "enabled": res.get("enabled", False),
}
- if 'id' in res:
- data['id'] = res['id']
+ if "id" in res:
+ data["id"] = res["id"]
instance = self.resource.deserialize(data)
@@ -381,19 +490,23 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
instances = self.run(args, **kwargs)
if instances:
- self.print_output(instances, table.MultiColumnTable,
- attributes=ApiKeyListCommand.display_attributes,
- widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=ApiKeyListCommand.display_attributes,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
class ApiKeyDeleteCommand(resource.ResourceDeleteCommand):
- pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK
+ pk_argument_name = "key_or_id" # name of the attribute which stores resource PK
class ApiKeyEnableCommand(resource.ResourceEnableCommand):
- pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK
+ pk_argument_name = "key_or_id" # name of the attribute which stores resource PK
class ApiKeyDisableCommand(resource.ResourceDisableCommand):
- pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK
+ pk_argument_name = "key_or_id" # name of the attribute which stores resource PK
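Aside: the show_secrets handling in ApiKeyListCommand.run above is the one non-mechanical bit in this file. show_secrets is a query parameter rather than a filter, so it has to ride inside the nested "params" dict that the manager forwards to the API. A minimal standalone sketch of that assembly (build_filters and the literal values are illustrative, not part of the st2client API):

def build_filters(user=None, show_secrets=False, **kwargs):
    # Mirrors ApiKeyListCommand.run above: plain filters go at the top
    # level, while show_secrets is smuggled in via the "params" dict.
    filters = {"user": user}
    filters.update(**kwargs)
    if show_secrets:
        params = filters.get("params", {})
        params["show_secrets"] = True
        filters["params"] = params
    return filters

assert build_filters(user="stanley", show_secrets=True) == {
    "user": "stanley",
    "params": {"show_secrets": True},
}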
diff --git a/st2client/st2client/commands/inquiry.py b/st2client/st2client/commands/inquiry.py
index 250c86b3d7..d9395a5c54 100644
--- a/st2client/st2client/commands/inquiry.py
+++ b/st2client/st2client/commands/inquiry.py
@@ -25,60 +25,81 @@
LOG = logging.getLogger(__name__)
-DEFAULT_SCOPE = 'system'
+DEFAULT_SCOPE = "system"
class InquiryBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(InquiryBranch, self).__init__(
- Inquiry, description, app, subparsers,
- parent_parser=parent_parser, read_only=True,
- commands={'list': InquiryListCommand,
- 'get': InquiryGetCommand})
+ Inquiry,
+ description,
+ app,
+ subparsers,
+ parent_parser=parent_parser,
+ read_only=True,
+ commands={"list": InquiryListCommand, "get": InquiryGetCommand},
+ )
# Register extended commands
- self.commands['respond'] = InquiryRespondCommand(
- self.resource, self.app, self.subparsers)
+ self.commands["respond"] = InquiryRespondCommand(
+ self.resource, self.app, self.subparsers
+ )
class InquiryListCommand(resource.ResourceCommand):
# Omitting "schema" and "response", as it doesn't really show up in a table well.
# The user can drill into a specific Inquiry to get this
- display_attributes = [
- 'id',
- 'roles',
- 'users',
- 'route',
- 'ttl'
- ]
+ display_attributes = ["id", "roles", "users", "route", "ttl"]
def __init__(self, resource, *args, **kwargs):
self.default_limit = 50
super(InquiryListCommand, self).__init__(
- resource, 'list', 'Get the list of the %s most recent %s.' %
- (self.default_limit, resource.get_plural_display_name().lower()),
- *args, **kwargs)
+ resource,
+ "list",
+ "Get the list of the %s most recent %s."
+ % (self.default_limit, resource.get_plural_display_name().lower()),
+ *args,
+ **kwargs,
+ )
self.resource_name = resource.get_plural_display_name().lower()
- self.parser.add_argument('-n', '--last', type=int, dest='last',
- default=self.default_limit,
- help=('List N most recent %s. Use -n -1 to fetch the full result \
- set.' % self.resource_name))
+ self.parser.add_argument(
+ "-n",
+ "--last",
+ type=int,
+ dest="last",
+ default=self.default_limit,
+ help=(
+ "List N most recent %s. Use -n -1 to fetch the full result \
+ set."
+ % self.resource_name
+ ),
+ )
# Display options
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -87,17 +108,21 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
instances, count = self.run(args, **kwargs)
- self.print_output(reversed(instances), table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json,
- yaml=args.yaml)
+ self.print_output(
+ reversed(instances),
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
if args.last and count and count > args.last:
table.SingleRowTable.note_box(self.resource_name, args.last)
class InquiryGetCommand(resource.ResourceGetCommand):
- pk_argument_name = 'id'
- display_attributes = ['id', 'roles', 'users', 'route', 'ttl', 'schema']
+ pk_argument_name = "id"
+ display_attributes = ["id", "roles", "users", "route", "ttl", "schema"]
def __init__(self, kv_resource, *args, **kwargs):
super(InquiryGetCommand, self).__init__(kv_resource, *args, **kwargs)
@@ -109,22 +134,28 @@ def run(self, args, **kwargs):
class InquiryRespondCommand(resource.ResourceCommand):
- display_attributes = ['id', 'response']
+ display_attributes = ["id", "response"]
def __init__(self, resource, *args, **kwargs):
super(InquiryRespondCommand, self).__init__(
- resource, 'respond',
- 'Respond to an %s.' % resource.get_display_name().lower(),
- *args, **kwargs
+ resource,
+ "respond",
+ "Respond to an %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
)
- self.parser.add_argument('id',
- metavar='id',
- help='Inquiry ID')
- self.parser.add_argument('-r', '--response', type=str, dest='response',
- default=None,
- help=('Entire response payload as JSON string '
- '(bypass interactive mode)'))
+ self.parser.add_argument("id", metavar="id", help="Inquiry ID")
+ self.parser.add_argument(
+ "-r",
+ "--response",
+ type=str,
+ dest="response",
+ default=None,
+ help=(
+ "Entire response payload as JSON string " "(bypass interactive mode)"
+ ),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -135,12 +166,13 @@ def run(self, args, **kwargs):
instance.response = json.loads(args.response)
else:
response = InteractiveForm(
- inquiry.schema.get('properties')).initiate_dialog()
+ inquiry.schema.get("properties")
+ ).initiate_dialog()
instance.response = response
- return self.manager.respond(inquiry_id=instance.id,
- inquiry_response=instance.response,
- **kwargs)
+ return self.manager.respond(
+ inquiry_id=instance.id, inquiry_response=instance.response, **kwargs
+ )
def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
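For orientation, InquiryRespondCommand.run above boils down to one branch: an explicit -r/--response payload is parsed as JSON, otherwise the inquiry schema's properties drive an interactive form. A hedged sketch of just that branch, with the InteractiveForm dialog replaced by a caller-supplied prompt stub (build_response and the stub are hypothetical, not st2client API):

import json

def build_response(response_arg, schema, prompt=None):
    # Sketch of the branch in InquiryRespondCommand.run above.
    if response_arg:
        return json.loads(response_arg)
    # prompt stands in for InteractiveForm(schema["properties"]).initiate_dialog()
    return prompt(schema.get("properties", {}))

# Non-interactive path:
assert build_response('{"approved": true}', {}) == {"approved": True}
# Interactive path with a stub dialog:
assert build_response(
    None,
    {"properties": {"approved": {}}},
    prompt=lambda props: {key: True for key in props},
) == {"approved": True}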
diff --git a/st2client/st2client/commands/keyvalue.py b/st2client/st2client/commands/keyvalue.py
index 9c0d06f806..e87f6afa35 100644
--- a/st2client/st2client/commands/keyvalue.py
+++ b/st2client/st2client/commands/keyvalue.py
@@ -31,83 +31,125 @@
LOG = logging.getLogger(__name__)
-DEFAULT_LIST_SCOPE = 'all'
-DEFAULT_GET_SCOPE = 'system'
-DEFAULT_CUD_SCOPE = 'system'
+DEFAULT_LIST_SCOPE = "all"
+DEFAULT_GET_SCOPE = "system"
+DEFAULT_CUD_SCOPE = "system"
class KeyValuePairBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(KeyValuePairBranch, self).__init__(
- KeyValuePair, description, app, subparsers,
+ KeyValuePair,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
commands={
- 'list': KeyValuePairListCommand,
- 'get': KeyValuePairGetCommand,
- 'delete': KeyValuePairDeleteCommand,
- 'create': NoopCommand,
- 'update': NoopCommand
- })
+ "list": KeyValuePairListCommand,
+ "get": KeyValuePairGetCommand,
+ "delete": KeyValuePairDeleteCommand,
+ "create": NoopCommand,
+ "update": NoopCommand,
+ },
+ )
# Registers extended commands
- self.commands['set'] = KeyValuePairSetCommand(self.resource, self.app,
- self.subparsers)
- self.commands['load'] = KeyValuePairLoadCommand(
- self.resource, self.app, self.subparsers)
- self.commands['delete_by_prefix'] = KeyValuePairDeleteByPrefixCommand(
- self.resource, self.app, self.subparsers)
+ self.commands["set"] = KeyValuePairSetCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["load"] = KeyValuePairLoadCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["delete_by_prefix"] = KeyValuePairDeleteByPrefixCommand(
+ self.resource, self.app, self.subparsers
+ )
# Remove unsupported commands
# TODO: Refactor parent class and make it nicer
- del self.commands['create']
- del self.commands['update']
+ del self.commands["create"]
+ del self.commands["update"]
class KeyValuePairListCommand(resource.ResourceTableCommand):
- display_attributes = ['name', 'value', 'secret', 'encrypted', 'scope', 'user',
- 'expire_timestamp']
+ display_attributes = [
+ "name",
+ "value",
+ "secret",
+ "encrypted",
+ "scope",
+ "user",
+ "expire_timestamp",
+ ]
attribute_transform_functions = {
- 'expire_timestamp': format_isodate_for_user_timezone,
+ "expire_timestamp": format_isodate_for_user_timezone,
}
def __init__(self, resource, *args, **kwargs):
self.default_limit = 50
- super(KeyValuePairListCommand, self).__init__(resource, 'list',
- 'Get the list of the %s most recent %s.' %
- (self.default_limit,
- resource.get_plural_display_name().lower()),
- *args, **kwargs)
+ super(KeyValuePairListCommand, self).__init__(
+ resource,
+ "list",
+ "Get the list of the %s most recent %s."
+ % (self.default_limit, resource.get_plural_display_name().lower()),
+ *args,
+ **kwargs,
+ )
self.resource_name = resource.get_plural_display_name().lower()
# Filter options
- self.parser.add_argument('--prefix', help=('Only return values with names starting with '
- 'the provided prefix.'))
- self.parser.add_argument('-d', '--decrypt', action='store_true',
- help='Decrypt secrets and displays plain text.')
- self.parser.add_argument('-s', '--scope', default=DEFAULT_LIST_SCOPE, dest='scope',
- help='Scope item is under. Example: "user".')
- self.parser.add_argument('-u', '--user', dest='user', default=None,
- help='User for user scoped items (admin only).')
- self.parser.add_argument('-n', '--last', type=int, dest='last',
- default=self.default_limit,
- help=('List N most recent %s. Use -n -1 to fetch the full result \
- set.' % self.resource_name))
+ self.parser.add_argument(
+ "--prefix",
+ help=(
+ "Only return values with names starting with " "the provided prefix."
+ ),
+ )
+ self.parser.add_argument(
+ "-d",
+ "--decrypt",
+ action="store_true",
+ help="Decrypt secrets and displays plain text.",
+ )
+ self.parser.add_argument(
+ "-s",
+ "--scope",
+ default=DEFAULT_LIST_SCOPE,
+ dest="scope",
+ help='Scope the item is under. Example: "user".',
+ )
+ self.parser.add_argument(
+ "-u",
+ "--user",
+ dest="user",
+ default=None,
+ help="User for user scoped items (admin only).",
+ )
+ self.parser.add_argument(
+ "-n",
+ "--last",
+ type=int,
+ dest="last",
+ default=self.default_limit,
+ help=(
+ "List N most recent %s. Use -n -1 to fetch the full result \
+ set."
+ % self.resource_name
+ ),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# Filtering options
if args.prefix:
- kwargs['prefix'] = args.prefix
+ kwargs["prefix"] = args.prefix
- decrypt = getattr(args, 'decrypt', False)
- kwargs['params'] = {'decrypt': str(decrypt).lower()}
- scope = getattr(args, 'scope', DEFAULT_LIST_SCOPE)
- kwargs['params']['scope'] = scope
+ decrypt = getattr(args, "decrypt", False)
+ kwargs["params"] = {"decrypt": str(decrypt).lower()}
+ scope = getattr(args, "scope", DEFAULT_LIST_SCOPE)
+ kwargs["params"]["scope"] = scope
if args.user:
- kwargs['params']['user'] = args.user
- kwargs['params']['limit'] = args.last
+ kwargs["params"]["user"] = args.user
+ kwargs["params"]["limit"] = args.last
return self.manager.query_with_count(**kwargs)
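# A minimal sketch of the query-param assembly in run() above; the API
# expects lowercase "true"/"false" strings, hence str(decrypt).lower().
# Names below are illustrative only:
#
#     def build_kv_list_params(decrypt=False, scope="all", user=None, last=50):
#         params = {"decrypt": str(decrypt).lower(), "scope": scope, "limit": last}
#         if user:
#             params["user"] = user
#         return params
#
#     build_kv_list_params(decrypt=True, user="stanley")
#     # -> {"decrypt": "true", "scope": "all", "limit": 50, "user": "stanley"}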
@@ -115,73 +157,124 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
instances, count = self.run(args, **kwargs)
if args.json or args.yaml:
- self.print_output(reversed(instances), table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ reversed(instances),
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
else:
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
if args.last and count and count > args.last:
table.SingleRowTable.note_box(self.resource_name, args.last)
class KeyValuePairGetCommand(resource.ResourceGetCommand):
- pk_argument_name = 'name'
- display_attributes = ['name', 'value', 'secret', 'encrypted', 'scope', 'expire_timestamp']
+ pk_argument_name = "name"
+ display_attributes = [
+ "name",
+ "value",
+ "secret",
+ "encrypted",
+ "scope",
+ "expire_timestamp",
+ ]
def __init__(self, kv_resource, *args, **kwargs):
super(KeyValuePairGetCommand, self).__init__(kv_resource, *args, **kwargs)
- self.parser.add_argument('-d', '--decrypt', action='store_true',
- help='Decrypt secret if encrypted and show plain text.')
- self.parser.add_argument('-s', '--scope', default=DEFAULT_GET_SCOPE, dest='scope',
- help='Scope item is under. Example: "user".')
+ self.parser.add_argument(
+ "-d",
+ "--decrypt",
+ action="store_true",
+ help="Decrypt secret if encrypted and show plain text.",
+ )
+ self.parser.add_argument(
+ "-s",
+ "--scope",
+ default=DEFAULT_GET_SCOPE,
+ dest="scope",
+ help='Scope the item is under. Example: "user".',
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
resource_name = getattr(args, self.pk_argument_name, None)
- decrypt = getattr(args, 'decrypt', False)
- scope = getattr(args, 'scope', DEFAULT_GET_SCOPE)
- kwargs['params'] = {'decrypt': str(decrypt).lower()}
- kwargs['params']['scope'] = scope
+ decrypt = getattr(args, "decrypt", False)
+ scope = getattr(args, "scope", DEFAULT_GET_SCOPE)
+ kwargs["params"] = {"decrypt": str(decrypt).lower()}
+ kwargs["params"]["scope"] = scope
return self.get_resource_by_id(id=resource_name, **kwargs)
class KeyValuePairSetCommand(resource.ResourceCommand):
- display_attributes = ['name', 'value', 'scope', 'expire_timestamp']
+ display_attributes = ["name", "value", "scope", "expire_timestamp"]
def __init__(self, resource, *args, **kwargs):
super(KeyValuePairSetCommand, self).__init__(
- resource, 'set',
- 'Set an existing %s.' % resource.get_display_name().lower(),
- *args, **kwargs
+ resource,
+ "set",
+ "Set an existing %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
)
# --encrypt and --encrypted options are mutually exclusive.
# --encrypt implies provided value is plain text and should be encrypted whereas
# --encrypted implies value is already encrypted and should be treated as-is.
encryption_group = self.parser.add_mutually_exclusive_group()
- encryption_group.add_argument('-e', '--encrypt', dest='secret',
- action='store_true',
- help='Encrypt value before saving.')
- encryption_group.add_argument('--encrypted', dest='encrypted',
- action='store_true',
- help=('Value provided is already encrypted with the '
- 'instance crypto key and should be stored as-is.'))
-
- self.parser.add_argument('name',
- metavar='name',
- help='Name of the key value pair.')
- self.parser.add_argument('value', help='Value paired with the key.')
- self.parser.add_argument('-l', '--ttl', dest='ttl', type=int, default=None,
- help='TTL (in seconds) for this value.')
- self.parser.add_argument('-s', '--scope', dest='scope', default=DEFAULT_CUD_SCOPE,
- help='Specify the scope under which you want ' +
- 'to place the item.')
- self.parser.add_argument('-u', '--user', dest='user', default=None,
- help='User for user scoped items (admin only).')
+ encryption_group.add_argument(
+ "-e",
+ "--encrypt",
+ dest="secret",
+ action="store_true",
+ help="Encrypt value before saving.",
+ )
+ encryption_group.add_argument(
+ "--encrypted",
+ dest="encrypted",
+ action="store_true",
+ help=(
+ "Value provided is already encrypted with the "
+ "instance crypto key and should be stored as-is."
+ ),
+ )
+
+ self.parser.add_argument(
+ "name", metavar="name", help="Name of the key value pair."
+ )
+ self.parser.add_argument("value", help="Value paired with the key.")
+ self.parser.add_argument(
+ "-l",
+ "--ttl",
+ dest="ttl",
+ type=int,
+ default=None,
+ help="TTL (in seconds) for this value.",
+ )
+ self.parser.add_argument(
+ "-s",
+ "--scope",
+ dest="scope",
+ default=DEFAULT_CUD_SCOPE,
+ help="Specify the scope under which you want " + "to place the item.",
+ )
+ self.parser.add_argument(
+ "-u",
+ "--user",
+ dest="user",
+ default=None,
+ help="User for user scoped items (admin only).",
+ )
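# For readers new to the pattern above, a minimal self-contained argparse
# sketch of a mutually exclusive pair (flag names reused purely for
# illustration):
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     group = parser.add_mutually_exclusive_group()
#     group.add_argument("-e", "--encrypt", dest="secret", action="store_true")
#     group.add_argument("--encrypted", action="store_true")
#
#     parser.parse_args(["-e"])                   # fine
#     parser.parse_args(["-e", "--encrypted"])    # SystemExit: not allowed together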
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -205,35 +298,49 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
- self.print_output(instance, table.PropertyValueTable,
- attributes=self.display_attributes, json=args.json,
- yaml=args.yaml)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=self.display_attributes,
+ json=args.json,
+ yaml=args.yaml,
+ )
class KeyValuePairDeleteCommand(resource.ResourceDeleteCommand):
- pk_argument_name = 'name'
+ pk_argument_name = "name"
def __init__(self, resource, *args, **kwargs):
super(KeyValuePairDeleteCommand, self).__init__(resource, *args, **kwargs)
- self.parser.add_argument('-s', '--scope', dest='scope', default=DEFAULT_CUD_SCOPE,
- help='Specify the scope under which you want ' +
- 'to place the item.')
- self.parser.add_argument('-u', '--user', dest='user', default=None,
- help='User for user scoped items (admin only).')
+ self.parser.add_argument(
+ "-s",
+ "--scope",
+ dest="scope",
+ default=DEFAULT_CUD_SCOPE,
+ help="Specify the scope under which you want " + "to place the item.",
+ )
+ self.parser.add_argument(
+ "-u",
+ "--user",
+ dest="user",
+ default=None,
+ help="User for user scoped items (admin only).",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
resource_id = getattr(args, self.pk_argument_name, None)
- scope = getattr(args, 'scope', DEFAULT_CUD_SCOPE)
- kwargs['params'] = {}
- kwargs['params']['scope'] = scope
- kwargs['params']['user'] = args.user
+ scope = getattr(args, "scope", DEFAULT_CUD_SCOPE)
+ kwargs["params"] = {}
+ kwargs["params"]["scope"] = scope
+ kwargs["params"]["user"] = args.user
instance = self.get_resource(resource_id, **kwargs)
if not instance:
- raise resource.ResourceNotFoundError('KeyValuePair with id "%s" not found'
- % resource_id)
+ raise resource.ResourceNotFoundError(
+ 'KeyValuePair with id "%s" not found' % resource_id
+ )
instance.id = resource_id # TODO: refactor and get rid of id
self.manager.delete(instance, **kwargs)
@@ -244,14 +351,23 @@ class KeyValuePairDeleteByPrefixCommand(resource.ResourceCommand):
Command which deletes all the key value pairs which match the provided
prefix.
"""
+
def __init__(self, resource, *args, **kwargs):
- super(KeyValuePairDeleteByPrefixCommand, self).__init__(resource, 'delete_by_prefix',
- 'Delete KeyValue pairs which \
- match the provided prefix',
- *args, **kwargs)
+ super(KeyValuePairDeleteByPrefixCommand, self).__init__(
+ resource,
+ "delete_by_prefix",
+ "Delete KeyValue pairs which \
+ match the provided prefix",
+ *args,
+ **kwargs,
+ )
- self.parser.add_argument('-p', '--prefix', required=True,
- help='Name prefix (e.g. twitter.TwitterSensor:)')
+ self.parser.add_argument(
+ "-p",
+ "--prefix",
+ required=True,
+ help="Name prefix (e.g. twitter.TwitterSensor:)",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -276,27 +392,39 @@ def run_and_print(self, args, **kwargs):
deleted = self.run(args, **kwargs)
key_ids = [key_pair.id for key_pair in deleted]
- print('Deleted %s keys' % (len(deleted)))
- print('Deleted key ids: %s' % (', '.join(key_ids)))
+ print("Deleted %s keys" % (len(deleted)))
+ print("Deleted key ids: %s" % (", ".join(key_ids)))
class KeyValuePairLoadCommand(resource.ResourceCommand):
- pk_argument_name = 'name'
- display_attributes = ['name', 'value']
+ pk_argument_name = "name"
+ display_attributes = ["name", "value"]
def __init__(self, resource, *args, **kwargs):
- help_text = ('Load a list of %s from file.' %
- resource.get_plural_display_name().lower())
- super(KeyValuePairLoadCommand, self).__init__(resource, 'load',
- help_text, *args, **kwargs)
-
- self.parser.add_argument('-c', '--convert', action='store_true',
- help=('Convert non-string types (hash, array, boolean,'
- ' int, float) to a JSON string before loading it'
- ' into the datastore.'))
+ help_text = (
+ "Load a list of %s from file." % resource.get_plural_display_name().lower()
+ )
+ super(KeyValuePairLoadCommand, self).__init__(
+ resource, "load", help_text, *args, **kwargs
+ )
+
+ self.parser.add_argument(
+ "-c",
+ "--convert",
+ action="store_true",
+ help=(
+ "Convert non-string types (hash, array, boolean,"
+ " int, float) to a JSON string before loading it"
+ " into the datastore."
+ ),
+ )
self.parser.add_argument(
- 'file', help=('JSON/YAML file containing the %s(s) to load'
- % resource.get_plural_display_name().lower()))
+ "file",
+ help=(
+ "JSON/YAML file containing the %s(s) to load"
+ % resource.get_plural_display_name().lower()
+ ),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -318,15 +446,15 @@ def run(self, args, **kwargs):
for item in kvps:
# parse required KeyValuePair properties
- name = item['name']
- value = item['value']
+ name = item["name"]
+ value = item["value"]
# parse optional KeyValuePair properties
- scope = item.get('scope', DEFAULT_CUD_SCOPE)
- user = item.get('user', None)
- encrypted = item.get('encrypted', False)
- secret = item.get('secret', False)
- ttl = item.get('ttl', None)
+ scope = item.get("scope", DEFAULT_CUD_SCOPE)
+ user = item.get("user", None)
+ encrypted = item.get("encrypted", False)
+ secret = item.get("secret", False)
+ ttl = item.get("ttl", None)
# if the value is not a string, convert it to JSON
# all values in the datastore must be strings
@@ -334,10 +462,15 @@ def run(self, args, **kwargs):
if args.convert:
value = json.dumps(value)
else:
- raise ValueError(("Item '%s' has a value that is not a string."
- " Either pass in the -c/--convert option to convert"
- " non-string types to JSON strings automatically, or"
- " convert the data to a string in the file") % name)
+ raise ValueError(
+ (
+ "Item '%s' has a value that is not a string."
+ " Either pass in the -c/--convert option to convert"
+ " non-string types to JSON strings automatically, or"
+ " convert the data to a string in the file"
+ )
+ % name
+ )
# create the KeyValuePair instance
instance = KeyValuePair()
@@ -368,7 +501,10 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
instances = self.run(args, **kwargs)
- self.print_output(instances, table.MultiColumnTable,
- attributes=['name', 'value', 'secret', 'scope', 'user', 'ttl'],
- json=args.json,
- yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=["name", "value", "secret", "scope", "user", "ttl"],
+ json=args.json,
+ yaml=args.yaml,
+ )
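Before moving on: the -c/--convert handling in KeyValuePairLoadCommand.run above is the one piece of real logic in that command, since every datastore value must be a string. A self-contained sketch of just that rule (coerce_value is an illustrative name, not st2client API):

import json

def coerce_value(name, value, convert=False):
    # Mirrors the value handling in KeyValuePairLoadCommand.run above:
    # strings pass through, everything else is JSON-encoded or rejected.
    if isinstance(value, str):
        return value
    if convert:
        return json.dumps(value)
    raise ValueError(
        "Item '%s' has a value that is not a string."
        " Either pass in the -c/--convert option to convert"
        " non-string types to JSON strings automatically, or"
        " convert the data to a string in the file" % name
    )

assert coerce_value("k1", "plain") == "plain"
assert coerce_value("k2", {"a": 1}, convert=True) == '{"a": 1}'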
diff --git a/st2client/st2client/commands/pack.py b/st2client/st2client/commands/pack.py
index 827db663df..8d05fc88dc 100644
--- a/st2client/st2client/commands/pack.py
+++ b/st2client/st2client/commands/pack.py
@@ -34,43 +34,56 @@
from st2client.utils import interactive
-LIVEACTION_STATUS_REQUESTED = 'requested'
-LIVEACTION_STATUS_SCHEDULED = 'scheduled'
-LIVEACTION_STATUS_DELAYED = 'delayed'
-LIVEACTION_STATUS_RUNNING = 'running'
-LIVEACTION_STATUS_SUCCEEDED = 'succeeded'
-LIVEACTION_STATUS_FAILED = 'failed'
-LIVEACTION_STATUS_TIMED_OUT = 'timeout'
-LIVEACTION_STATUS_ABANDONED = 'abandoned'
-LIVEACTION_STATUS_CANCELING = 'canceling'
-LIVEACTION_STATUS_CANCELED = 'canceled'
+LIVEACTION_STATUS_REQUESTED = "requested"
+LIVEACTION_STATUS_SCHEDULED = "scheduled"
+LIVEACTION_STATUS_DELAYED = "delayed"
+LIVEACTION_STATUS_RUNNING = "running"
+LIVEACTION_STATUS_SUCCEEDED = "succeeded"
+LIVEACTION_STATUS_FAILED = "failed"
+LIVEACTION_STATUS_TIMED_OUT = "timeout"
+LIVEACTION_STATUS_ABANDONED = "abandoned"
+LIVEACTION_STATUS_CANCELING = "canceling"
+LIVEACTION_STATUS_CANCELED = "canceled"
LIVEACTION_COMPLETED_STATES = [
LIVEACTION_STATUS_SUCCEEDED,
LIVEACTION_STATUS_FAILED,
LIVEACTION_STATUS_TIMED_OUT,
LIVEACTION_STATUS_CANCELED,
- LIVEACTION_STATUS_ABANDONED
+ LIVEACTION_STATUS_ABANDONED,
]
class PackBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(PackBranch, self).__init__(
- Pack, description, app, subparsers,
+ Pack,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': PackListCommand,
- 'get': PackGetCommand
- })
-
- self.commands['show'] = PackShowCommand(self.resource, self.app, self.subparsers)
- self.commands['search'] = PackSearchCommand(self.resource, self.app, self.subparsers)
- self.commands['install'] = PackInstallCommand(self.resource, self.app, self.subparsers)
- self.commands['remove'] = PackRemoveCommand(self.resource, self.app, self.subparsers)
- self.commands['register'] = PackRegisterCommand(self.resource, self.app, self.subparsers)
- self.commands['config'] = PackConfigCommand(self.resource, self.app, self.subparsers)
+ commands={"list": PackListCommand, "get": PackGetCommand},
+ )
+
+ self.commands["show"] = PackShowCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["search"] = PackSearchCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["install"] = PackInstallCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["remove"] = PackRemoveCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["register"] = PackRegisterCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["config"] = PackConfigCommand(
+ self.resource, self.app, self.subparsers
+ )
class PackResourceCommand(resource.ResourceCommand):
@@ -79,13 +92,18 @@ def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
if not instance:
raise resource.ResourceNotFoundError("No matching items found")
- self.print_output(instance, table.PropertyValueTable,
- attributes=['all'], json=args.json, yaml=args.yaml)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=["all"],
+ json=args.json,
+ yaml=args.yaml,
+ )
except resource.ResourceNotFoundError:
print("No matching items found")
except Exception as e:
message = six.text_type(e)
- print('ERROR: %s' % (message))
+ print("ERROR: %s" % (message))
raise OperationFailureException(message)
@@ -93,48 +111,72 @@ class PackAsyncCommand(ActionRunCommandMixin, resource.ResourceCommand):
def __init__(self, *args, **kwargs):
super(PackAsyncCommand, self).__init__(*args, **kwargs)
- self.parser.add_argument('-w', '--width', nargs='+', type=int, default=None,
- help='Set the width of columns in output.')
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help="Set the width of columns in output.",
+ )
detail_arg_grp = self.parser.add_mutually_exclusive_group()
- detail_arg_grp.add_argument('--attr', nargs='+',
- default=['ref', 'name', 'description', 'version', 'author'],
- help=('List of attributes to include in the '
- 'output. "all" or unspecified will '
- 'return all attributes.'))
- detail_arg_grp.add_argument('-d', '--detail', action='store_true',
- help='Display full detail of the execution in table format.')
+ detail_arg_grp.add_argument(
+ "--attr",
+ nargs="+",
+ default=["ref", "name", "description", "version", "author"],
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" or unspecified will '
+ "return all attributes."
+ ),
+ )
+ detail_arg_grp.add_argument(
+ "-d",
+ "--detail",
+ action="store_true",
+ help="Display full detail of the execution in table format.",
+ )
@add_auth_token_to_kwargs_from_cli
def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
if not instance:
- raise Exception('Server did not create instance.')
+ raise Exception("Server did not create instance.")
parent_id = instance.execution_id
- stream_mgr = self.app.client.managers['Stream']
+ stream_mgr = self.app.client.managers["Stream"]
execution = None
with term.TaskIndicator() as indicator:
- events = ['st2.execution__create', 'st2.execution__update']
- for event in stream_mgr.listen(events, end_execution_id=parent_id,
- end_event="st2.execution__update", **kwargs):
+ events = ["st2.execution__create", "st2.execution__update"]
+ for event in stream_mgr.listen(
+ events,
+ end_execution_id=parent_id,
+ end_event="st2.execution__update",
+ **kwargs,
+ ):
execution = Execution(**event)
- if execution.id == parent_id \
- and execution.status in LIVEACTION_COMPLETED_STATES:
+ if (
+ execution.id == parent_id
+ and execution.status in LIVEACTION_COMPLETED_STATES
+ ):
break
# Suppress intermediate output in case output formatter is requested
if args.json or args.yaml:
continue
- if getattr(execution, 'parent', None) == parent_id:
+ if getattr(execution, "parent", None) == parent_id:
status = execution.status
- name = execution.context['orquesta']['task_name'] \
- if 'orquesta' in execution.context else execution.context['chain']['name']
+ name = (
+ execution.context["orquesta"]["task_name"]
+ if "orquesta" in execution.context
+ else execution.context["chain"]["name"]
+ )
if status == LIVEACTION_STATUS_SCHEDULED:
indicator.add_stage(status, name)
@@ -148,31 +190,48 @@ def run_and_print(self, args, **kwargs):
self._print_execution_details(execution=execution, args=args, **kwargs)
sys.exit(1)
- return self.app.client.managers['Execution'].get_by_id(parent_id, **kwargs)
+ return self.app.client.managers["Execution"].get_by_id(parent_id, **kwargs)
class PackListCommand(resource.ResourceListCommand):
- display_attributes = ['ref', 'name', 'description', 'version', 'author']
- attribute_display_order = ['ref', 'name', 'description', 'version', 'author']
+ display_attributes = ["ref", "name", "description", "version", "author"]
+ attribute_display_order = ["ref", "name", "description", "version", "author"]
class PackGetCommand(resource.ResourceGetCommand):
- pk_argument_name = 'ref'
- display_attributes = ['name', 'version', 'author', 'email', 'keywords', 'description']
- attribute_display_order = ['name', 'version', 'author', 'email', 'keywords', 'description']
- help_string = 'Get information about an installed pack.'
+ pk_argument_name = "ref"
+ display_attributes = [
+ "name",
+ "version",
+ "author",
+ "email",
+ "keywords",
+ "description",
+ ]
+ attribute_display_order = [
+ "name",
+ "version",
+ "author",
+ "email",
+ "keywords",
+ "description",
+ ]
+ help_string = "Get information about an installed pack."
class PackShowCommand(PackResourceCommand):
def __init__(self, resource, *args, **kwargs):
- help_string = ('Get information about an available %s from the index.' %
- resource.get_display_name().lower())
- super(PackShowCommand, self).__init__(resource, 'show', help_string,
- *args, **kwargs)
-
- self.parser.add_argument('pack',
- help='Name of the %s to show.' %
- resource.get_display_name().lower())
+ help_string = (
+ "Get information about an available %s from the index."
+ % resource.get_display_name().lower()
+ )
+ super(PackShowCommand, self).__init__(
+ resource, "show", help_string, *args, **kwargs
+ )
+
+ self.parser.add_argument(
+ "pack", help="Name of the %s to show." % resource.get_display_name().lower()
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -181,27 +240,39 @@ def run(self, args, **kwargs):
class PackInstallCommand(PackAsyncCommand):
def __init__(self, resource, *args, **kwargs):
- super(PackInstallCommand, self).__init__(resource, 'install', 'Install new %s.'
- % resource.get_plural_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('packs',
- nargs='+',
- metavar='pack',
- help='Name of the %s in Exchange, or a git repo URL.' %
- resource.get_plural_display_name().lower())
- self.parser.add_argument('--python3',
- action='store_true',
- default=False,
- help='Use Python 3 binary for pack virtual environment.')
- self.parser.add_argument('--force',
- action='store_true',
- default=False,
- help='Force pack installation.')
- self.parser.add_argument('--skip-dependencies',
- action='store_true',
- default=False,
- help='Skip pack dependency installation.')
+ super(PackInstallCommand, self).__init__(
+ resource,
+ "install",
+ "Install new %s." % resource.get_plural_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "packs",
+ nargs="+",
+ metavar="pack",
+ help="Name of the %s in Exchange, or a git repo URL."
+ % resource.get_plural_display_name().lower(),
+ )
+ self.parser.add_argument(
+ "--python3",
+ action="store_true",
+ default=False,
+ help="Use Python 3 binary for pack virtual environment.",
+ )
+ self.parser.add_argument(
+ "--force",
+ action="store_true",
+ default=False,
+ help="Force pack installation.",
+ )
+ self.parser.add_argument(
+ "--skip-dependencies",
+ action="store_true",
+ default=False,
+ help="Skip pack dependency installation.",
+ )
def run(self, args, **kwargs):
is_structured_output = args.json or args.yaml
@@ -212,30 +283,42 @@ def run(self, args, **kwargs):
self._get_content_counts_for_pack(args, **kwargs)
if args.python3:
- warnings.warn('DEPRECATION WARNING: --python3 flag is ignored and will be removed '
- 'in v3.5.0 as StackStorm now runs with python3 only')
-
- return self.manager.install(args.packs, force=args.force,
- skip_dependencies=args.skip_dependencies, **kwargs)
+ warnings.warn(
+ "DEPRECATION WARNING: --python3 flag is ignored and will be removed "
+ "in v3.5.0 as StackStorm now runs with python3 only"
+ )
+
+ return self.manager.install(
+ args.packs,
+ force=args.force,
+ skip_dependencies=args.skip_dependencies,
+ **kwargs,
+ )
def _get_content_counts_for_pack(self, args, **kwargs):
# Global content list, excluding "tests"
# Note: We skip this step for local packs
- pack_content = {'actions': 0, 'rules': 0, 'sensors': 0, 'aliases': 0, 'triggers': 0}
+ pack_content = {
+ "actions": 0,
+ "rules": 0,
+ "sensors": 0,
+ "aliases": 0,
+ "triggers": 0,
+ }
if len(args.packs) == 1:
args.pack = args.packs[0]
- if args.pack.startswith('file://'):
+ if args.pack.startswith("file://"):
return
pack_info = self.manager.search(args, ignore_errors=True, **kwargs)
- content = getattr(pack_info, 'content', {})
+ content = getattr(pack_info, "content", {})
if content:
for entity in content.keys():
if entity in pack_content:
- pack_content[entity] += content[entity]['count']
+ pack_content[entity] += content[entity]["count"]
self._print_pack_content(args.packs, pack_content)
else:
@@ -246,122 +329,165 @@ def _get_content_counts_for_pack(self, args, **kwargs):
# args.pack required for search
args.pack = pack
- if args.pack.startswith('file://'):
+ if args.pack.startswith("file://"):
return
pack_info = self.manager.search(args, ignore_errors=True, **kwargs)
- content = getattr(pack_info, 'content', {})
+ content = getattr(pack_info, "content", {})
if content:
for entity in content.keys():
if entity in pack_content:
- pack_content[entity] += content[entity]['count']
+ pack_content[entity] += content[entity]["count"]
if content:
self._print_pack_content(args.packs, pack_content)
@staticmethod
def _print_pack_content(pack_name, pack_content):
- print('\nFor the "%s" %s, the following content will be registered:\n'
- % (', '.join(pack_name), 'pack' if len(pack_name) == 1 else 'packs'))
+ print(
+ '\nFor the "%s" %s, the following content will be registered:\n'
+ % (", ".join(pack_name), "pack" if len(pack_name) == 1 else "packs")
+ )
for item, count in pack_content.items():
- print('%-10s| %s' % (item, count))
- print('\nInstallation may take a while for packs with many items.')
+ print("%-10s| %s" % (item, count))
+ print("\nInstallation may take a while for packs with many items.")
@add_auth_token_to_kwargs_from_cli
def run_and_print(self, args, **kwargs):
instance = super(PackInstallCommand, self).run_and_print(args, **kwargs)
# Hack to get a list of resolved references of installed packs
- packs = instance.result['output']['packs_list']
+ packs = instance.result["output"]["packs_list"]
if len(packs) == 1:
- pack_instance = self.app.client.managers['Pack'].get_by_ref_or_id(packs[0], **kwargs)
- self.print_output(pack_instance, table.PropertyValueTable,
- attributes=args.attr, json=args.json, yaml=args.yaml,
- attribute_display_order=self.attribute_display_order)
+ pack_instance = self.app.client.managers["Pack"].get_by_ref_or_id(
+ packs[0], **kwargs
+ )
+ self.print_output(
+ pack_instance,
+ table.PropertyValueTable,
+ attributes=args.attr,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_display_order=self.attribute_display_order,
+ )
else:
- all_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs)
+ all_pack_instances = self.app.client.managers["Pack"].get_all(**kwargs)
pack_instances = []
for pack in all_pack_instances:
if pack.name in packs or pack.ref in packs:
pack_instances.append(pack)
- self.print_output(pack_instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ pack_instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
- warnings = instance.result['output']['warning_list']
+ warnings = instance.result["output"]["warning_list"]
for warning in warnings:
print(warning)
class PackRemoveCommand(PackAsyncCommand):
def __init__(self, resource, *args, **kwargs):
- super(PackRemoveCommand, self).__init__(resource, 'remove', 'Remove %s.'
- % resource.get_plural_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('packs',
- nargs='+',
- metavar='pack',
- help='Name of the %s to remove.' %
- resource.get_plural_display_name().lower())
+ super(PackRemoveCommand, self).__init__(
+ resource,
+ "remove",
+ "Remove %s." % resource.get_plural_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "packs",
+ nargs="+",
+ metavar="pack",
+ help="Name of the %s to remove."
+ % resource.get_plural_display_name().lower(),
+ )
def run(self, args, **kwargs):
return self.manager.remove(args.packs, **kwargs)
@add_auth_token_to_kwargs_from_cli
def run_and_print(self, args, **kwargs):
- all_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs)
+ all_pack_instances = self.app.client.managers["Pack"].get_all(**kwargs)
super(PackRemoveCommand, self).run_and_print(args, **kwargs)
packs = args.packs
if len(packs) == 1:
- pack_instance = self.app.client.managers['Pack'].get_by_ref_or_id(packs[0], **kwargs)
+ pack_instance = self.app.client.managers["Pack"].get_by_ref_or_id(
+ packs[0], **kwargs
+ )
if pack_instance:
- raise OperationFailureException('Pack %s has not been removed properly' % packs[0])
-
- removed_pack_instance = next((pack for pack in all_pack_instances
- if pack.name == packs[0]), None)
-
- self.print_output(removed_pack_instance, table.PropertyValueTable,
- attributes=args.attr, json=args.json, yaml=args.yaml,
- attribute_display_order=self.attribute_display_order)
+ raise OperationFailureException(
+ "Pack %s has not been removed properly" % packs[0]
+ )
+
+ removed_pack_instance = next(
+ (pack for pack in all_pack_instances if pack.name == packs[0]), None
+ )
+
+ self.print_output(
+ removed_pack_instance,
+ table.PropertyValueTable,
+ attributes=args.attr,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_display_order=self.attribute_display_order,
+ )
else:
- remaining_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs)
+ remaining_pack_instances = self.app.client.managers["Pack"].get_all(
+ **kwargs
+ )
pack_instances = []
for pack in all_pack_instances:
if pack.name in packs or pack.ref in packs:
pack_instances.append(pack)
if pack in remaining_pack_instances:
- raise OperationFailureException('Pack %s has not been removed properly'
- % pack.name)
+ raise OperationFailureException(
+ "Pack %s has not been removed properly" % pack.name
+ )
- self.print_output(pack_instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ pack_instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
class PackRegisterCommand(PackResourceCommand):
def __init__(self, resource, *args, **kwargs):
- super(PackRegisterCommand, self).__init__(resource, 'register',
- 'Register %s(s): sync all file changes with DB.'
- % resource.get_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('packs',
- nargs='*',
- metavar='pack',
- help='Name of the %s(s) to register.' %
- resource.get_display_name().lower())
-
- self.parser.add_argument('--types',
- nargs='+',
- help='Types of content to register.')
+ super(PackRegisterCommand, self).__init__(
+ resource,
+ "register",
+ "Register %s(s): sync all file changes with DB."
+ % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "packs",
+ nargs="*",
+ metavar="pack",
+ help="Name of the %s(s) to register." % resource.get_display_name().lower(),
+ )
+
+ self.parser.add_argument(
+ "--types", nargs="+", help="Types of content to register."
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -369,18 +495,21 @@ def run(self, args, **kwargs):
class PackSearchCommand(resource.ResourceTableCommand):
- display_attributes = ['name', 'description', 'version', 'author']
- attribute_display_order = ['name', 'description', 'version', 'author']
+ display_attributes = ["name", "description", "version", "author"]
+ attribute_display_order = ["name", "description", "version", "author"]
def __init__(self, resource, *args, **kwargs):
- super(PackSearchCommand, self).__init__(resource, 'search',
- 'Search the index for a %s with any attribute \
- matching the query.'
- % resource.get_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('query',
- help='Search query.')
+ super(PackSearchCommand, self).__init__(
+ resource,
+ "search",
+ "Search the index for a %s with any attribute \
+ matching the query."
+ % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument("query", help="Search query.")
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -389,31 +518,41 @@ def run(self, args, **kwargs):
class PackConfigCommand(resource.ResourceCommand):
def __init__(self, resource, *args, **kwargs):
- super(PackConfigCommand, self).__init__(resource, 'config',
- 'Configure a %s based on config schema.'
- % resource.get_display_name().lower(),
- *args, **kwargs)
-
- self.parser.add_argument('name',
- help='Name of the %s(s) to configure.' %
- resource.get_display_name().lower())
+ super(PackConfigCommand, self).__init__(
+ resource,
+ "config",
+ "Configure a %s based on config schema."
+ % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "name",
+ help="Name of the %s(s) to configure."
+ % resource.get_display_name().lower(),
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
- schema = self.app.client.managers['ConfigSchema'].get_by_ref_or_id(args.name, **kwargs)
+ schema = self.app.client.managers["ConfigSchema"].get_by_ref_or_id(
+ args.name, **kwargs
+ )
if not schema:
- msg = '%s "%s" doesn\'t exist or doesn\'t have a config schema defined.'
- raise resource.ResourceNotFoundError(msg % (self.resource.get_display_name(),
- args.name))
+ msg = "%s \"%s\" doesn't exist or doesn't have a config schema defined."
+ raise resource.ResourceNotFoundError(
+ msg % (self.resource.get_display_name(), args.name)
+ )
config = interactive.InteractiveForm(schema.attributes).initiate_dialog()
- message = '---\nDo you want to preview the config in an editor before saving?'
- description = 'Secrets will be shown in plain text.'
- preview_dialog = interactive.Question(message, {'default': 'y',
- 'description': description})
- if preview_dialog.read() == 'y':
+ message = "---\nDo you want to preview the config in an editor before saving?"
+ description = "Secrets will be shown in plain text."
+ preview_dialog = interactive.Question(
+ message, {"default": "y", "description": description}
+ )
+ if preview_dialog.read() == "y":
try:
contents = yaml.safe_dump(config, indent=4, default_flow_style=False)
modified = editor.edit(contents=contents)
@@ -421,13 +560,13 @@ def run(self, args, **kwargs):
except editor.EditorError as e:
print(six.text_type(e))
- message = '---\nDo you want me to save it?'
- save_dialog = interactive.Question(message, {'default': 'y'})
- if save_dialog.read() == 'n':
- raise OperationFailureException('Interrupted')
+ message = "---\nDo you want me to save it?"
+ save_dialog = interactive.Question(message, {"default": "y"})
+ if save_dialog.read() == "n":
+ raise OperationFailureException("Interrupted")
config_item = Config(pack=args.name, values=config)
- result = self.app.client.managers['Config'].update(config_item, **kwargs)
+ result = self.app.client.managers["Config"].update(config_item, **kwargs)
return result
@@ -436,14 +575,19 @@ def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
if not instance:
raise Exception("Configuration failed")
- self.print_output(instance, table.PropertyValueTable,
- attributes=['all'], json=args.json, yaml=args.yaml)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=["all"],
+ json=args.json,
+ yaml=args.yaml,
+ )
except (KeyboardInterrupt, SystemExit):
- raise OperationFailureException('Interrupted')
+ raise OperationFailureException("Interrupted")
except Exception as e:
if self.app.client.debug:
raise
message = six.text_type(e)
- print('ERROR: %s' % (message))
+ print("ERROR: %s" % (message))
raise OperationFailureException(message)
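The stream-listening loop in PackAsyncCommand.run_and_print above got noticeably longer under black; its control flow is simply: consume execution events until the parent execution reaches a completed state. A stripped-down sketch with the stream manager replaced by a plain list of event dicts (wait_for_parent and the sample events are illustrative):

LIVEACTION_COMPLETED_STATES = [
    "succeeded",
    "failed",
    "timeout",
    "canceled",
    "abandoned",
]

def wait_for_parent(events, parent_id):
    # Stands in for the stream_mgr.listen(...) loop above: stop once the
    # parent execution reports a terminal status.
    for event in events:
        if event["id"] == parent_id and event["status"] in LIVEACTION_COMPLETED_STATES:
            return event["status"]
    return None

events = [
    {"id": "child-1", "status": "running"},
    {"id": "parent-1", "status": "running"},
    {"id": "parent-1", "status": "succeeded"},
]
assert wait_for_parent(events, "parent-1") == "succeeded"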
diff --git a/st2client/st2client/commands/policy.py b/st2client/st2client/commands/policy.py
index de6c8ba997..cd891bc3a8 100644
--- a/st2client/st2client/commands/policy.py
+++ b/st2client/st2client/commands/policy.py
@@ -25,31 +25,36 @@
class PolicyTypeBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(PolicyTypeBranch, self).__init__(
- models.PolicyType, description, app, subparsers,
+ models.PolicyType,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': PolicyTypeListCommand,
- 'get': PolicyTypeGetCommand
- })
+ commands={"list": PolicyTypeListCommand, "get": PolicyTypeGetCommand},
+ )
class PolicyTypeListCommand(resource.ResourceListCommand):
- display_attributes = ['id', 'resource_type', 'name', 'description']
+ display_attributes = ["id", "resource_type", "name", "description"]
def __init__(self, resource, *args, **kwargs):
super(PolicyTypeListCommand, self).__init__(resource, *args, **kwargs)
- self.parser.add_argument('-r', '--resource-type', type=str, dest='resource_type',
- help='Return policy types for the resource type.')
+ self.parser.add_argument(
+ "-r",
+ "--resource-type",
+ type=str,
+ dest="resource_type",
+ help="Return policy types for the resource type.",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
if args.resource_type:
- filters = {'resource_type': args.resource_type}
+ filters = {"resource_type": args.resource_type}
filters.update(**kwargs)
instances = self.manager.query(**filters)
return instances
@@ -58,36 +63,49 @@ def run(self, args, **kwargs):
class PolicyTypeGetCommand(resource.ResourceGetCommand):
- pk_argument_name = 'ref_or_id'
+ pk_argument_name = "ref_or_id"
def get_resource(self, ref_or_id, **kwargs):
return self.get_resource_by_ref_or_id(ref_or_id=ref_or_id, **kwargs)
class PolicyBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(PolicyBranch, self).__init__(
- models.Policy, description, app, subparsers,
+ models.Policy,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
commands={
- 'list': PolicyListCommand,
- 'get': PolicyGetCommand,
- 'update': PolicyUpdateCommand,
- 'delete': PolicyDeleteCommand
- })
+ "list": PolicyListCommand,
+ "get": PolicyGetCommand,
+ "update": PolicyUpdateCommand,
+ "delete": PolicyDeleteCommand,
+ },
+ )
class PolicyListCommand(resource.ContentPackResourceListCommand):
- display_attributes = ['ref', 'resource_ref', 'policy_type', 'enabled']
+ display_attributes = ["ref", "resource_ref", "policy_type", "enabled"]
def __init__(self, resource, *args, **kwargs):
super(PolicyListCommand, self).__init__(resource, *args, **kwargs)
- self.parser.add_argument('-r', '--resource-ref', type=str, dest='resource_ref',
- help='Return policies for the resource ref.')
- self.parser.add_argument('-pt', '--policy-type', type=str, dest='policy_type',
- help='Return policies of the policy type.')
+ self.parser.add_argument(
+ "-r",
+ "--resource-ref",
+ type=str,
+ dest="resource_ref",
+ help="Return policies for the resource ref.",
+ )
+ self.parser.add_argument(
+ "-pt",
+ "--policy-type",
+ type=str,
+ dest="policy_type",
+ help="Return policies of the policy type.",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -95,10 +113,10 @@ def run(self, args, **kwargs):
filters = {}
if args.resource_ref:
- filters['resource_ref'] = args.resource_ref
+ filters["resource_ref"] = args.resource_ref
if args.policy_type:
- filters['policy_type'] = args.policy_type
+ filters["policy_type"] = args.policy_type
filters.update(**kwargs)
@@ -108,10 +126,18 @@ def run(self, args, **kwargs):
class PolicyGetCommand(resource.ContentPackResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'description',
- 'enabled', 'resource_ref', 'policy_type',
- 'parameters']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "description",
+ "enabled",
+ "resource_ref",
+ "policy_type",
+ "parameters",
+ ]
class PolicyUpdateCommand(resource.ContentPackResourceUpdateCommand):
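The list commands throughout this patch all build their query the same way: each optional CLI flag becomes a key in a filters dict only when the user actually sets it. A self-contained argparse sketch of that pattern (flag names borrowed from PolicyListCommand above, sample value hypothetical):

    # Minimal sketch of the optional-filter pattern used by the list commands.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--resource-ref", dest="resource_ref")
    parser.add_argument("-pt", "--policy-type", dest="policy_type")
    args = parser.parse_args(["-r", "core.local"])

    filters = {}
    if args.resource_ref:
        filters["resource_ref"] = args.resource_ref
    if args.policy_type:
        filters["policy_type"] = args.policy_type

    # In st2client the dict is then forwarded: self.manager.query(**filters)
    print(filters)  # {'resource_ref': 'core.local'}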
diff --git a/st2client/st2client/commands/rbac.py b/st2client/st2client/commands/rbac.py
index 0d7ea7f400..9a9e8c274b 100644
--- a/st2client/st2client/commands/rbac.py
+++ b/st2client/st2client/commands/rbac.py
@@ -20,58 +20,77 @@
from st2client.models.rbac import Role
from st2client.models.rbac import UserRoleAssignment
-__all__ = [
- 'RoleBranch',
- 'RoleAssignmentBranch'
+__all__ = ["RoleBranch", "RoleAssignmentBranch"]
+
+ROLE_ATTRIBUTE_DISPLAY_ORDER = ["id", "name", "system", "permission_grants"]
+ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER = [
+ "id",
+ "role",
+ "user",
+ "is_remote",
+ "description",
]
-ROLE_ATTRIBUTE_DISPLAY_ORDER = ['id', 'name', 'system', 'permission_grants']
-ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER = ['id', 'role', 'user', 'is_remote', 'description']
-
class RoleBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(RoleBranch, self).__init__(
- Role, description, app, subparsers,
+ Role,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': RoleListCommand,
- 'get': RoleGetCommand
- })
+ commands={"list": RoleListCommand, "get": RoleGetCommand},
+ )
class RoleListCommand(resource.ResourceCommand):
- display_attributes = ['id', 'name', 'system', 'description']
+ display_attributes = ["id", "name", "system", "description"]
attribute_display_order = ROLE_ATTRIBUTE_DISPLAY_ORDER
def __init__(self, resource, *args, **kwargs):
super(RoleListCommand, self).__init__(
- resource, 'list', 'Get the list of the %s.' %
- resource.get_plural_display_name().lower(),
- *args, **kwargs)
+ resource,
+ "list",
+ "Get the list of the %s." % resource.get_plural_display_name().lower(),
+ *args,
+ **kwargs,
+ )
self.group = self.parser.add_mutually_exclusive_group()
# Filter options
- self.group.add_argument('-s', '--system', action='store_true',
- help='Only display system roles.')
+ self.group.add_argument(
+ "-s", "--system", action="store_true", help="Only display system roles."
+ )
# Display options
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# Filtering options
if args.system:
- kwargs['system'] = args.system
+ kwargs["system"] = args.system
if args.system:
result = self.manager.query(**kwargs)
@@ -82,67 +101,93 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
instances = self.run(args, **kwargs)
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
class RoleGetCommand(resource.ResourceGetCommand):
- display_attributes = ['all']
+ display_attributes = ["all"]
attribute_display_order = ROLE_ATTRIBUTE_DISPLAY_ORDER
- pk_argument_name = 'id'
+ pk_argument_name = "id"
class RoleAssignmentBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(RoleAssignmentBranch, self).__init__(
- UserRoleAssignment, description, app, subparsers,
+ UserRoleAssignment,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
commands={
- 'list': RoleAssignmentListCommand,
- 'get': RoleAssignmentGetCommand
- })
+ "list": RoleAssignmentListCommand,
+ "get": RoleAssignmentGetCommand,
+ },
+ )
class RoleAssignmentListCommand(resource.ResourceCommand):
- display_attributes = ['id', 'role', 'user', 'is_remote', 'source', 'description']
+ display_attributes = ["id", "role", "user", "is_remote", "source", "description"]
attribute_display_order = ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER
def __init__(self, resource, *args, **kwargs):
super(RoleAssignmentListCommand, self).__init__(
- resource, 'list', 'Get the list of the %s.' %
- resource.get_plural_display_name().lower(),
- *args, **kwargs)
+ resource,
+ "list",
+ "Get the list of the %s." % resource.get_plural_display_name().lower(),
+ *args,
+ **kwargs,
+ )
# Filter options
- self.parser.add_argument('-r', '--role', help='Role to filter on.')
- self.parser.add_argument('-u', '--user', help='User to filter on.')
- self.parser.add_argument('-s', '--source', help='Source to filter on.')
- self.parser.add_argument('--remote', action='store_true',
- help='Only display remote role assignments.')
+ self.parser.add_argument("-r", "--role", help="Role to filter on.")
+ self.parser.add_argument("-u", "--user", help="User to filter on.")
+ self.parser.add_argument("-s", "--source", help="Source to filter on.")
+ self.parser.add_argument(
+ "--remote",
+ action="store_true",
+ help="Only display remote role assignments.",
+ )
# Display options
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# Filtering options
if args.role:
- kwargs['role'] = args.role
+ kwargs["role"] = args.role
if args.user:
- kwargs['user'] = args.user
+ kwargs["user"] = args.user
if args.source:
- kwargs['source'] = args.source
+ kwargs["source"] = args.source
if args.remote:
- kwargs['remote'] = args.remote
+ kwargs["remote"] = args.remote
if args.role or args.user or args.remote or args.source:
result = self.manager.query(**kwargs)
@@ -153,12 +198,17 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
instances = self.run(args, **kwargs)
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
class RoleAssignmentGetCommand(resource.ResourceGetCommand):
- display_attributes = ['all']
+ display_attributes = ["all"]
attribute_display_order = ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER
- pk_argument_name = 'id'
+ pk_argument_name = "id"
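RoleListCommand places -s/--system in a mutually exclusive group even though it is currently the only member, so any future sibling filter flag will conflict with it automatically at parse time. A stdlib-only illustration (the --user-defined flag is hypothetical):

    # How argparse mutually exclusive groups behave (stdlib only).
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group()
    group.add_argument("-s", "--system", action="store_true")
    group.add_argument("--user-defined", action="store_true")  # hypothetical sibling

    print(parser.parse_args(["-s"]).system)        # True
    # parser.parse_args(["-s", "--user-defined"])  # would exit: not allowed together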
diff --git a/st2client/st2client/commands/resource.py b/st2client/st2client/commands/resource.py
index 15ca68bb09..da0fbc85e3 100644
--- a/st2client/st2client/commands/resource.py
+++ b/st2client/st2client/commands/resource.py
@@ -32,8 +32,8 @@
from st2client.formatters import table
from st2client.utils.types import OrderedSet
-ALLOWED_EXTS = ['.json', '.yaml', '.yml']
-PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load}
+ALLOWED_EXTS = [".json", ".yaml", ".yml"]
+PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load}
LOG = logging.getLogger(__name__)
@@ -41,11 +41,12 @@ def add_auth_token_to_kwargs_from_cli(func):
@wraps(func)
def decorate(*args, **kwargs):
ns = args[1]
- if getattr(ns, 'token', None):
- kwargs['token'] = ns.token
- if getattr(ns, 'api_key', None):
- kwargs['api_key'] = ns.api_key
+ if getattr(ns, "token", None):
+ kwargs["token"] = ns.token
+ if getattr(ns, "api_key", None):
+ kwargs["api_key"] = ns.api_key
return func(*args, **kwargs)
+
return decorate
@@ -58,20 +59,34 @@ class ResourceNotFoundError(Exception):
class ResourceBranch(commands.Branch):
-
- def __init__(self, resource, description, app, subparsers,
- parent_parser=None, read_only=False, commands=None,
- has_disable=False):
+ def __init__(
+ self,
+ resource,
+ description,
+ app,
+ subparsers,
+ parent_parser=None,
+ read_only=False,
+ commands=None,
+ has_disable=False,
+ ):
self.resource = resource
super(ResourceBranch, self).__init__(
- self.resource.get_alias().lower(), description,
- app, subparsers, parent_parser=parent_parser)
+ self.resource.get_alias().lower(),
+ description,
+ app,
+ subparsers,
+ parent_parser=parent_parser,
+ )
# Registers subcommands for managing the resource type.
self.subparsers = self.parser.add_subparsers(
- help=('List of commands for managing %s.' %
- self.resource.get_plural_display_name().lower()))
+ help=(
+ "List of commands for managing %s."
+ % self.resource.get_plural_display_name().lower()
+ )
+ )
# Resolves if commands need to be overridden.
commands = commands or {}
@@ -82,7 +97,7 @@ def __init__(self, resource, description, app, subparsers,
"update": ResourceUpdateCommand,
"delete": ResourceDeleteCommand,
"enable": ResourceEnableCommand,
- "disable": ResourceDisableCommand
+ "disable": ResourceDisableCommand,
}
for cmd, cmd_class in cmd_map.items():
if cmd not in commands:
@@ -90,17 +105,17 @@ def __init__(self, resource, description, app, subparsers,
# Instantiate commands.
args = [self.resource, self.app, self.subparsers]
- self.commands['list'] = commands['list'](*args)
- self.commands['get'] = commands['get'](*args)
+ self.commands["list"] = commands["list"](*args)
+ self.commands["get"] = commands["get"](*args)
if not read_only:
- self.commands['create'] = commands['create'](*args)
- self.commands['update'] = commands['update'](*args)
- self.commands['delete'] = commands['delete'](*args)
+ self.commands["create"] = commands["create"](*args)
+ self.commands["update"] = commands["update"](*args)
+ self.commands["delete"] = commands["delete"](*args)
if has_disable:
- self.commands['enable'] = commands['enable'](*args)
- self.commands['disable'] = commands['disable'](*args)
+ self.commands["enable"] = commands["enable"](*args)
+ self.commands["disable"] = commands["disable"](*args)
@six.add_metaclass(abc.ABCMeta)
@@ -109,29 +124,44 @@ class ResourceCommand(commands.Command):
def __init__(self, resource, *args, **kwargs):
- has_token_opt = kwargs.pop('has_token_opt', True)
+ has_token_opt = kwargs.pop("has_token_opt", True)
super(ResourceCommand, self).__init__(*args, **kwargs)
self.resource = resource
if has_token_opt:
- self.parser.add_argument('-t', '--token', dest='token',
- help='Access token for user authentication. '
- 'Get ST2_AUTH_TOKEN from the environment '
- 'variables by default.')
- self.parser.add_argument('--api-key', dest='api_key',
- help='Api Key for user authentication. '
- 'Get ST2_API_KEY from the environment '
- 'variables by default.')
+ self.parser.add_argument(
+ "-t",
+ "--token",
+ dest="token",
+ help="Access token for user authentication. "
+ "Get ST2_AUTH_TOKEN from the environment "
+ "variables by default.",
+ )
+ self.parser.add_argument(
+ "--api-key",
+ dest="api_key",
+ help="Api Key for user authentication. "
+ "Get ST2_API_KEY from the environment "
+ "variables by default.",
+ )
# Formatter flags
- self.parser.add_argument('-j', '--json',
- action='store_true', dest='json',
- help='Print output in JSON format.')
- self.parser.add_argument('-y', '--yaml',
- action='store_true', dest='yaml',
- help='Print output in YAML format.')
+ self.parser.add_argument(
+ "-j",
+ "--json",
+ action="store_true",
+ dest="json",
+ help="Print output in JSON format.",
+ )
+ self.parser.add_argument(
+ "-y",
+ "--yaml",
+ action="store_true",
+ dest="yaml",
+ help="Print output in YAML format.",
+ )
@property
def manager(self):
@@ -140,18 +170,17 @@ def manager(self):
@property
def arg_name_for_resource_id(self):
resource_name = self.resource.get_display_name().lower()
- return '%s-id' % resource_name.replace(' ', '-')
+ return "%s-id" % resource_name.replace(" ", "-")
def print_not_found(self, name):
- print('%s "%s" is not found.\n' %
- (self.resource.get_display_name(), name))
+ print('%s "%s" is not found.\n' % (self.resource.get_display_name(), name))
def get_resource(self, name_or_id, **kwargs):
pk_argument_name = self.pk_argument_name
- if pk_argument_name == 'name_or_id':
+ if pk_argument_name == "name_or_id":
instance = self.get_resource_by_name_or_id(name_or_id=name_or_id, **kwargs)
- elif pk_argument_name == 'ref_or_id':
+ elif pk_argument_name == "ref_or_id":
instance = self.get_resource_by_ref_or_id(ref_or_id=name_or_id, **kwargs)
else:
instance = self.get_resource_by_pk(pk=name_or_id, **kwargs)
@@ -167,8 +196,8 @@ def get_resource_by_pk(self, pk, **kwargs):
except Exception as e:
traceback.print_exc()
# Hack for "Unauthorized" exceptions, we do want to propagate those
- response = getattr(e, 'response', None)
- status_code = getattr(response, 'status_code', None)
+ response = getattr(e, "response", None)
+ status_code = getattr(response, "status_code", None)
if status_code and status_code == http_client.UNAUTHORIZED:
raise e
@@ -180,7 +209,7 @@ def get_resource_by_id(self, id, **kwargs):
instance = self.get_resource_by_pk(pk=id, **kwargs)
if not instance:
- message = ('Resource with id "%s" doesn\'t exist.' % (id))
+ message = 'Resource with id "%s" doesn\'t exist.' % (id)
raise ResourceNotFoundError(message)
return instance
@@ -197,8 +226,7 @@ def get_resource_by_name_or_id(self, name_or_id, **kwargs):
instance = self.get_resource_by_pk(pk=name_or_id, **kwargs)
if not instance:
- message = ('Resource with id or name "%s" doesn\'t exist.' %
- (name_or_id))
+ message = 'Resource with id or name "%s" doesn\'t exist.' % (name_or_id)
raise ResourceNotFoundError(message)
return instance
@@ -206,8 +234,7 @@ def get_resource_by_ref_or_id(self, ref_or_id, **kwargs):
instance = self.manager.get_by_ref_or_id(ref_or_id=ref_or_id, **kwargs)
if not instance:
- message = ('Resource with id or reference "%s" doesn\'t exist.' %
- (ref_or_id))
+ message = 'Resource with id or reference "%s" doesn\'t exist.' % (ref_or_id)
raise ResourceNotFoundError(message)
return instance
@@ -220,18 +247,18 @@ def run_and_print(self, args, **kwargs):
raise NotImplementedError
def _get_metavar_for_argument(self, argument):
- return argument.replace('_', '-')
+ return argument.replace("_", "-")
def _get_help_for_argument(self, resource, argument):
argument_display_name = argument.title()
resource_display_name = resource.get_display_name().lower()
- if 'ref' in argument:
- result = ('Reference or ID of the %s.' % (resource_display_name))
- elif 'name_or_id' in argument:
- result = ('Name or ID of the %s.' % (resource_display_name))
+ if "ref" in argument:
+ result = "Reference or ID of the %s." % (resource_display_name)
+ elif "name_or_id" in argument:
+ result = "Name or ID of the %s." % (resource_display_name)
else:
- result = ('%s of the %s.' % (argument_display_name, resource_display_name))
+ result = "%s of the %s." % (argument_display_name, resource_display_name)
return result
@@ -263,7 +290,7 @@ def _get_include_attributes(cls, args, extra_attributes=None):
# into account
# Special case for "all"
- if 'all' in args.attr:
+ if "all" in args.attr:
return None
for attr in args.attr:
@@ -272,7 +299,7 @@ def _get_include_attributes(cls, args, extra_attributes=None):
if include_attributes:
return include_attributes
- display_attributes = getattr(cls, 'display_attributes', [])
+ display_attributes = getattr(cls, "display_attributes", [])
if display_attributes:
include_attributes += display_attributes
@@ -283,97 +310,129 @@ def _get_include_attributes(cls, args, extra_attributes=None):
class ResourceTableCommand(ResourceViewCommand):
- display_attributes = ['id', 'name', 'description']
+ display_attributes = ["id", "name", "description"]
def __init__(self, resource, name, description, *args, **kwargs):
- super(ResourceTableCommand, self).__init__(resource, name, description,
- *args, **kwargs)
-
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ super(ResourceTableCommand, self).__init__(
+ resource, name, description, *args, **kwargs
+ )
+
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
include_attributes = self._get_include_attributes(args=args)
if include_attributes:
- include_attributes = ','.join(include_attributes)
- kwargs['params'] = {'include_attributes': include_attributes}
+ include_attributes = ",".join(include_attributes)
+ kwargs["params"] = {"include_attributes": include_attributes}
return self.manager.get_all(**kwargs)
def run_and_print(self, args, **kwargs):
instances = self.run(args, **kwargs)
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
class ResourceListCommand(ResourceTableCommand):
def __init__(self, resource, *args, **kwargs):
super(ResourceListCommand, self).__init__(
- resource, 'list', 'Get the list of %s.' % resource.get_plural_display_name().lower(),
- *args, **kwargs)
+ resource,
+ "list",
+ "Get the list of %s." % resource.get_plural_display_name().lower(),
+ *args,
+ **kwargs,
+ )
class ContentPackResourceListCommand(ResourceListCommand):
"""
Base command class for use with resources which belong to a content pack.
"""
+
def __init__(self, resource, *args, **kwargs):
- super(ContentPackResourceListCommand, self).__init__(resource,
- *args, **kwargs)
+ super(ContentPackResourceListCommand, self).__init__(resource, *args, **kwargs)
- self.parser.add_argument('-p', '--pack', type=str,
- help=('Only return resources belonging to the'
- ' provided pack'))
+ self.parser.add_argument(
+ "-p",
+ "--pack",
+ type=str,
+ help=("Only return resources belonging to the" " provided pack"),
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
- filters = {'pack': args.pack}
+ filters = {"pack": args.pack}
filters.update(**kwargs)
include_attributes = self._get_include_attributes(args=args)
if include_attributes:
- include_attributes = ','.join(include_attributes)
- filters['params'] = {'include_attributes': include_attributes}
+ include_attributes = ",".join(include_attributes)
+ filters["params"] = {"include_attributes": include_attributes}
return self.manager.get_all(**filters)
class ResourceGetCommand(ResourceViewCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'name', 'description']
+ display_attributes = ["all"]
+ attribute_display_order = ["id", "name", "description"]
- pk_argument_name = 'name_or_id' # name of the attribute which stores resource PK
+ pk_argument_name = "name_or_id" # name of the attribute which stores resource PK
help_string = None
def __init__(self, resource, *args, **kwargs):
super(ResourceGetCommand, self).__init__(
- resource, 'get',
- self.help_string or 'Get individual %s.' % resource.get_display_name().lower(),
- *args, **kwargs
+ resource,
+ "get",
+ self.help_string
+ or "Get individual %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
)
argument = self.pk_argument_name
metavar = self._get_metavar_for_argument(argument=self.pk_argument_name)
- help = self._get_help_for_argument(resource=resource,
- argument=self.pk_argument_name)
-
- self.parser.add_argument(argument,
- metavar=metavar,
- help=help)
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" or unspecified will '
- 'return all attributes.'))
+ help = self._get_help_for_argument(
+ resource=resource, argument=self.pk_argument_name
+ )
+
+ self.parser.add_argument(argument, metavar=metavar, help=help)
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" or unspecified will '
+ "return all attributes."
+ ),
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -383,13 +442,18 @@ def run(self, args, **kwargs):
def run_and_print(self, args, **kwargs):
try:
instance = self.run(args, **kwargs)
- self.print_output(instance, table.PropertyValueTable,
- attributes=args.attr, json=args.json, yaml=args.yaml,
- attribute_display_order=self.attribute_display_order)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=args.attr,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_display_order=self.attribute_display_order,
+ )
except ResourceNotFoundError:
resource_id = getattr(args, self.pk_argument_name, None)
self.print_not_found(resource_id)
- raise OperationFailureException('Resource %s not found.' % resource_id)
+ raise OperationFailureException("Resource %s not found." % resource_id)
class ContentPackResourceGetCommand(ResourceGetCommand):
@@ -400,24 +464,31 @@ class ContentPackResourceGetCommand(ResourceGetCommand):
retrieved by a reference or by an id.
"""
- attribute_display_order = ['id', 'pack', 'name', 'description']
+ attribute_display_order = ["id", "pack", "name", "description"]
- pk_argument_name = 'ref_or_id'
+ pk_argument_name = "ref_or_id"
def get_resource(self, ref_or_id, **kwargs):
return self.get_resource_by_ref_or_id(ref_or_id=ref_or_id, **kwargs)
class ResourceCreateCommand(ResourceCommand):
-
def __init__(self, resource, *args, **kwargs):
- super(ResourceCreateCommand, self).__init__(resource, 'create',
- 'Create a new %s.' % resource.get_display_name().lower(),
- *args, **kwargs)
+ super(ResourceCreateCommand, self).__init__(
+ resource,
+ "create",
+ "Create a new %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
- self.parser.add_argument('file',
- help=('JSON/YAML file containing the %s to create.'
- % resource.get_display_name().lower()))
+ self.parser.add_argument(
+ "file",
+ help=(
+ "JSON/YAML file containing the %s to create."
+ % resource.get_display_name().lower()
+ ),
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -429,34 +500,46 @@ def run_and_print(self, args, **kwargs):
try:
instance = self.run(args, **kwargs)
if not instance:
- raise Exception('Server did not create instance.')
- self.print_output(instance, table.PropertyValueTable,
- attributes=['all'], json=args.json, yaml=args.yaml)
+ raise Exception("Server did not create instance.")
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=["all"],
+ json=args.json,
+ yaml=args.yaml,
+ )
except Exception as e:
message = six.text_type(e)
- print('ERROR: %s' % (message))
+ print("ERROR: %s" % (message))
raise OperationFailureException(message)
class ResourceUpdateCommand(ResourceCommand):
- pk_argument_name = 'name_or_id'
+ pk_argument_name = "name_or_id"
def __init__(self, resource, *args, **kwargs):
- super(ResourceUpdateCommand, self).__init__(resource, 'update',
- 'Updating an existing %s.' % resource.get_display_name().lower(),
- *args, **kwargs)
+ super(ResourceUpdateCommand, self).__init__(
+ resource,
+ "update",
+ "Updating an existing %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
argument = self.pk_argument_name
metavar = self._get_metavar_for_argument(argument=self.pk_argument_name)
- help = self._get_help_for_argument(resource=resource,
- argument=self.pk_argument_name)
+ help = self._get_help_for_argument(
+ resource=resource, argument=self.pk_argument_name
+ )
- self.parser.add_argument(argument,
- metavar=metavar,
- help=help)
- self.parser.add_argument('file',
- help=('JSON/YAML file containing the %s to update.'
- % resource.get_display_name().lower()))
+ self.parser.add_argument(argument, metavar=metavar, help=help)
+ self.parser.add_argument(
+ "file",
+ help=(
+ "JSON/YAML file containing the %s to update."
+ % resource.get_display_name().lower()
+ ),
+ )
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -465,46 +548,55 @@ def run(self, args, **kwargs):
data = load_meta_file(args.file)
modified_instance = self.resource.deserialize(data)
- if not getattr(modified_instance, 'id', None):
+ if not getattr(modified_instance, "id", None):
modified_instance.id = instance.id
else:
if modified_instance.id != instance.id:
- raise Exception('The value for the %s id in the JSON/YAML file '
- 'does not match the ID provided in the '
- 'command line arguments.' %
- self.resource.get_display_name().lower())
+ raise Exception(
+ "The value for the %s id in the JSON/YAML file "
+ "does not match the ID provided in the "
+ "command line arguments." % self.resource.get_display_name().lower()
+ )
return self.manager.update(modified_instance, **kwargs)
def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
try:
- self.print_output(instance, table.PropertyValueTable,
- attributes=['all'], json=args.json, yaml=args.yaml)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=["all"],
+ json=args.json,
+ yaml=args.yaml,
+ )
except Exception as e:
- print('ERROR: %s' % (six.text_type(e)))
+ print("ERROR: %s" % (six.text_type(e)))
raise OperationFailureException(six.text_type(e))
class ContentPackResourceUpdateCommand(ResourceUpdateCommand):
- pk_argument_name = 'ref_or_id'
+ pk_argument_name = "ref_or_id"
class ResourceEnableCommand(ResourceCommand):
- pk_argument_name = 'name_or_id'
+ pk_argument_name = "name_or_id"
def __init__(self, resource, *args, **kwargs):
- super(ResourceEnableCommand, self).__init__(resource, 'enable',
- 'Enable an existing %s.' % resource.get_display_name().lower(),
- *args, **kwargs)
+ super(ResourceEnableCommand, self).__init__(
+ resource,
+ "enable",
+ "Enable an existing %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
argument = self.pk_argument_name
metavar = self._get_metavar_for_argument(argument=self.pk_argument_name)
- help = self._get_help_for_argument(resource=resource,
- argument=self.pk_argument_name)
+ help = self._get_help_for_argument(
+ resource=resource, argument=self.pk_argument_name
+ )
- self.parser.add_argument(argument,
- metavar=metavar,
- help=help)
+ self.parser.add_argument(argument, metavar=metavar, help=help)
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -513,40 +605,48 @@ def run(self, args, **kwargs):
data = instance.serialize()
- if 'ref' in data:
- del data['ref']
+ if "ref" in data:
+ del data["ref"]
- data['enabled'] = True
+ data["enabled"] = True
modified_instance = self.resource.deserialize(data)
return self.manager.update(modified_instance, **kwargs)
def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
- self.print_output(instance, table.PropertyValueTable,
- attributes=['all'], json=args.json, yaml=args.yaml)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=["all"],
+ json=args.json,
+ yaml=args.yaml,
+ )
class ContentPackResourceEnableCommand(ResourceEnableCommand):
- pk_argument_name = 'ref_or_id'
+ pk_argument_name = "ref_or_id"
class ResourceDisableCommand(ResourceCommand):
- pk_argument_name = 'name_or_id'
+ pk_argument_name = "name_or_id"
def __init__(self, resource, *args, **kwargs):
- super(ResourceDisableCommand, self).__init__(resource, 'disable',
- 'Disable an existing %s.' % resource.get_display_name().lower(),
- *args, **kwargs)
+ super(ResourceDisableCommand, self).__init__(
+ resource,
+ "disable",
+ "Disable an existing %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
argument = self.pk_argument_name
metavar = self._get_metavar_for_argument(argument=self.pk_argument_name)
- help = self._get_help_for_argument(resource=resource,
- argument=self.pk_argument_name)
+ help = self._get_help_for_argument(
+ resource=resource, argument=self.pk_argument_name
+ )
- self.parser.add_argument(argument,
- metavar=metavar,
- help=help)
+ self.parser.add_argument(argument, metavar=metavar, help=help)
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -555,40 +655,48 @@ def run(self, args, **kwargs):
data = instance.serialize()
- if 'ref' in data:
- del data['ref']
+ if "ref" in data:
+ del data["ref"]
- data['enabled'] = False
+ data["enabled"] = False
modified_instance = self.resource.deserialize(data)
return self.manager.update(modified_instance, **kwargs)
def run_and_print(self, args, **kwargs):
instance = self.run(args, **kwargs)
- self.print_output(instance, table.PropertyValueTable,
- attributes=['all'], json=args.json, yaml=args.yaml)
+ self.print_output(
+ instance,
+ table.PropertyValueTable,
+ attributes=["all"],
+ json=args.json,
+ yaml=args.yaml,
+ )
class ContentPackResourceDisableCommand(ResourceDisableCommand):
- pk_argument_name = 'ref_or_id'
+ pk_argument_name = "ref_or_id"
class ResourceDeleteCommand(ResourceCommand):
- pk_argument_name = 'name_or_id'
+ pk_argument_name = "name_or_id"
def __init__(self, resource, *args, **kwargs):
- super(ResourceDeleteCommand, self).__init__(resource, 'delete',
- 'Delete an existing %s.' % resource.get_display_name().lower(),
- *args, **kwargs)
+ super(ResourceDeleteCommand, self).__init__(
+ resource,
+ "delete",
+ "Delete an existing %s." % resource.get_display_name().lower(),
+ *args,
+ **kwargs,
+ )
argument = self.pk_argument_name
metavar = self._get_metavar_for_argument(argument=self.pk_argument_name)
- help = self._get_help_for_argument(resource=resource,
- argument=self.pk_argument_name)
+ help = self._get_help_for_argument(
+ resource=resource, argument=self.pk_argument_name
+ )
- self.parser.add_argument(argument,
- metavar=metavar,
- help=help)
+ self.parser.add_argument(argument, metavar=metavar, help=help)
@add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -601,10 +709,12 @@ def run_and_print(self, args, **kwargs):
try:
self.run(args, **kwargs)
- print('Resource with id "%s" has been successfully deleted.' % (resource_id))
+ print(
+ 'Resource with id "%s" has been successfully deleted.' % (resource_id)
+ )
except ResourceNotFoundError:
self.print_not_found(resource_id)
- raise OperationFailureException('Resource %s not found.' % resource_id)
+ raise OperationFailureException("Resource %s not found." % resource_id)
class ContentPackResourceDeleteCommand(ResourceDeleteCommand):
@@ -612,7 +722,7 @@ class ContentPackResourceDeleteCommand(ResourceDeleteCommand):
Base command class for deleting a resource which belongs to a content pack.
"""
- pk_argument_name = 'ref_or_id'
+ pk_argument_name = "ref_or_id"
def load_meta_file(file_path):
@@ -621,8 +731,10 @@ def load_meta_file(file_path):
file_name, file_ext = os.path.splitext(file_path)
if file_ext not in ALLOWED_EXTS:
- raise Exception('Unsupported meta type %s, file %s. Allowed: %s' %
- (file_ext, file_path, ALLOWED_EXTS))
+ raise Exception(
+ "Unsupported meta type %s, file %s. Allowed: %s"
+ % (file_ext, file_path, ALLOWED_EXTS)
+ )
- with open(file_path, 'r') as f:
+ with open(file_path, "r") as f:
return PARSER_FUNCS[file_ext](f)
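The load_meta_file helper at the end of resource.py dispatches on file extension through the PARSER_FUNCS table. Trimmed to its essentials, it is runnable as-is (requires PyYAML):

    # Trimmed equivalent of load_meta_file: pick a parser by file extension.
    import json
    import os

    import yaml  # PyYAML

    ALLOWED_EXTS = [".json", ".yaml", ".yml"]
    PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load}

    def load_meta_file(file_path):
        _, file_ext = os.path.splitext(file_path)
        if file_ext not in ALLOWED_EXTS:
            raise Exception(
                "Unsupported meta type %s, file %s. Allowed: %s"
                % (file_ext, file_path, ALLOWED_EXTS)
            )
        with open(file_path, "r") as f:
            return PARSER_FUNCS[file_ext](f)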
diff --git a/st2client/st2client/commands/rule.py b/st2client/st2client/commands/rule.py
index 7f0f5e58db..cbab939e10 100644
--- a/st2client/st2client/commands/rule.py
+++ b/st2client/st2client/commands/rule.py
@@ -21,99 +21,143 @@
class RuleBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(RuleBranch, self).__init__(
- models.Rule, description, app, subparsers,
+ models.Rule,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
commands={
- 'list': RuleListCommand,
- 'get': RuleGetCommand,
- 'update': RuleUpdateCommand,
- 'delete': RuleDeleteCommand
- })
+ "list": RuleListCommand,
+ "get": RuleGetCommand,
+ "update": RuleUpdateCommand,
+ "delete": RuleDeleteCommand,
+ },
+ )
- self.commands['enable'] = RuleEnableCommand(self.resource, self.app, self.subparsers)
- self.commands['disable'] = RuleDisableCommand(self.resource, self.app, self.subparsers)
+ self.commands["enable"] = RuleEnableCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["disable"] = RuleDisableCommand(
+ self.resource, self.app, self.subparsers
+ )
class RuleListCommand(resource.ResourceTableCommand):
- display_attributes = ['ref', 'pack', 'description', 'enabled']
- display_attributes_iftt = ['ref', 'trigger.ref', 'action.ref', 'enabled']
+ display_attributes = ["ref", "pack", "description", "enabled"]
+ display_attributes_iftt = ["ref", "trigger.ref", "action.ref", "enabled"]
def __init__(self, resource, *args, **kwargs):
self.default_limit = 50
- super(RuleListCommand, self).__init__(resource, 'list',
- 'Get the list of the %s most recent %s.' %
- (self.default_limit,
- resource.get_plural_display_name().lower()),
- *args, **kwargs)
+ super(RuleListCommand, self).__init__(
+ resource,
+ "list",
+ "Get the list of the %s most recent %s."
+ % (self.default_limit, resource.get_plural_display_name().lower()),
+ *args,
+ **kwargs,
+ )
self.resource_name = resource.get_plural_display_name().lower()
self.group = self.parser.add_argument_group()
- self.parser.add_argument('-n', '--last', type=int, dest='last',
- default=self.default_limit,
- help=('List N most recent %s. Use -n -1 to fetch the full result \
- set.' % self.resource_name))
- self.parser.add_argument('--iftt', action='store_true',
- help='Show trigger and action in display list.')
- self.parser.add_argument('-p', '--pack', type=str,
- help=('Only return resources belonging to the'
- ' provided pack'))
- self.group.add_argument('-c', '--action',
- help='Action reference to filter the list.')
- self.group.add_argument('-g', '--trigger',
- help='Trigger type reference to filter the list.')
+ self.parser.add_argument(
+ "-n",
+ "--last",
+ type=int,
+ dest="last",
+ default=self.default_limit,
+ help=(
+ "List N most recent %s. Use -n -1 to fetch the full result \
+ set."
+ % self.resource_name
+ ),
+ )
+ self.parser.add_argument(
+ "--iftt",
+ action="store_true",
+ help="Show trigger and action in display list.",
+ )
+ self.parser.add_argument(
+ "-p",
+ "--pack",
+ type=str,
+ help=("Only return resources belonging to the" " provided pack"),
+ )
+ self.group.add_argument(
+ "-c", "--action", help="Action reference to filter the list."
+ )
+ self.group.add_argument(
+ "-g", "--trigger", help="Trigger type reference to filter the list."
+ )
self.enabled_filter_group = self.parser.add_mutually_exclusive_group()
- self.enabled_filter_group.add_argument('--enabled', action='store_true',
- help='Show rules that are enabled.')
- self.enabled_filter_group.add_argument('--disabled', action='store_true',
- help='Show rules that are disabled.')
+ self.enabled_filter_group.add_argument(
+ "--enabled", action="store_true", help="Show rules that are enabled."
+ )
+ self.enabled_filter_group.add_argument(
+ "--disabled", action="store_true", help="Show rules that are disabled."
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# Filtering options
if args.pack:
- kwargs['pack'] = args.pack
+ kwargs["pack"] = args.pack
if args.action:
- kwargs['action'] = args.action
+ kwargs["action"] = args.action
if args.trigger:
- kwargs['trigger'] = args.trigger
+ kwargs["trigger"] = args.trigger
if args.enabled:
- kwargs['enabled'] = True
+ kwargs["enabled"] = True
if args.disabled:
- kwargs['enabled'] = False
+ kwargs["enabled"] = False
if args.iftt:
# switch attr to display the trigger and action
args.attr = self.display_attributes_iftt
include_attributes = self._get_include_attributes(args=args)
if include_attributes:
- include_attributes = ','.join(include_attributes)
- kwargs['params'] = {'include_attributes': include_attributes}
+ include_attributes = ",".join(include_attributes)
+ kwargs["params"] = {"include_attributes": include_attributes}
return self.manager.query_with_count(limit=args.last, **kwargs)
def run_and_print(self, args, **kwargs):
instances, count = self.run(args, **kwargs)
if args.json or args.yaml:
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
else:
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ )
if args.last and count and count > args.last:
table.SingleRowTable.note_box(self.resource_name, args.last)
class RuleGetCommand(resource.ContentPackResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'description',
- 'enabled']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "uid",
+ "ref",
+ "pack",
+ "name",
+ "description",
+ "enabled",
+ ]
class RuleUpdateCommand(resource.ContentPackResourceUpdateCommand):
@@ -121,15 +165,29 @@ class RuleUpdateCommand(resource.ContentPackResourceUpdateCommand):
class RuleEnableCommand(resource.ContentPackResourceEnableCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'description',
- 'enabled']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "enabled",
+ "description",
+ "enabled",
+ ]
class RuleDisableCommand(resource.ContentPackResourceDisableCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'description',
- 'enabled']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "enabled",
+ "description",
+ "enabled",
+ ]
class RuleDeleteCommand(resource.ContentPackResourceDeleteCommand):
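RuleListCommand, like the enforcement list further down, fetches at most -n/--last records via query_with_count and prints a note box when the server holds more. A schematic version of that truncation check, with query_with_count stubbed as a plain function rather than the real manager method:

    # Schematic version of the "-n/--last" truncation note used by list commands.
    def query_with_count(limit, records):
        # Stand-in for manager.query_with_count: one page plus the total count.
        return records[:limit], len(records)

    records = ["rule-%d" % i for i in range(75)]
    last = 50
    instances, count = query_with_count(limit=last, records=records)

    for instance in instances:
        print(instance)
    if last and count and count > last:
        # st2client renders this via table.SingleRowTable.note_box(...)
        print("Note: only the first %s of %s results are shown." % (last, count))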
diff --git a/st2client/st2client/commands/rule_enforcement.py b/st2client/st2client/commands/rule_enforcement.py
index ecebba2b07..dd624d4a72 100644
--- a/st2client/st2client/commands/rule_enforcement.py
+++ b/st2client/st2client/commands/rule_enforcement.py
@@ -22,24 +22,39 @@
class RuleEnforcementBranch(resource.ResourceBranch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(RuleEnforcementBranch, self).__init__(
- models.RuleEnforcement, description, app, subparsers,
+ models.RuleEnforcement,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
commands={
- 'list': RuleEnforcementListCommand,
- 'get': RuleEnforcementGetCommand,
- })
+ "list": RuleEnforcementListCommand,
+ "get": RuleEnforcementGetCommand,
+ },
+ )
class RuleEnforcementGetCommand(resource.ResourceGetCommand):
- display_attributes = ['id', 'rule.ref', 'trigger_instance_id',
- 'execution_id', 'failure_reason', 'enforced_at']
- attribute_display_order = ['id', 'rule.ref', 'trigger_instance_id',
- 'execution_id', 'failure_reason', 'enforced_at']
-
- pk_argument_name = 'id'
+ display_attributes = [
+ "id",
+ "rule.ref",
+ "trigger_instance_id",
+ "execution_id",
+ "failure_reason",
+ "enforced_at",
+ ]
+ attribute_display_order = [
+ "id",
+ "rule.ref",
+ "trigger_instance_id",
+ "execution_id",
+ "failure_reason",
+ "enforced_at",
+ ]
+
+ pk_argument_name = "id"
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -48,84 +63,137 @@ def run(self, args, **kwargs):
class RuleEnforcementListCommand(resource.ResourceCommand):
- display_attributes = ['id', 'rule.ref', 'trigger_instance_id',
- 'execution_id', 'enforced_at']
- attribute_display_order = ['id', 'rule.ref', 'trigger_instance_id',
- 'execution_id', 'enforced_at']
-
- attribute_transform_functions = {
- 'enforced_at': format_isodate_for_user_timezone
- }
+ display_attributes = [
+ "id",
+ "rule.ref",
+ "trigger_instance_id",
+ "execution_id",
+ "enforced_at",
+ ]
+ attribute_display_order = [
+ "id",
+ "rule.ref",
+ "trigger_instance_id",
+ "execution_id",
+ "enforced_at",
+ ]
+
+ attribute_transform_functions = {"enforced_at": format_isodate_for_user_timezone}
def __init__(self, resource, *args, **kwargs):
self.default_limit = 50
super(RuleEnforcementListCommand, self).__init__(
- resource, 'list', 'Get the list of the %s most recent %s.' %
- (self.default_limit, resource.get_plural_display_name().lower()),
- *args, **kwargs)
+ resource,
+ "list",
+ "Get the list of the %s most recent %s."
+ % (self.default_limit, resource.get_plural_display_name().lower()),
+ *args,
+ **kwargs,
+ )
self.resource_name = resource.get_plural_display_name().lower()
self.group = self.parser.add_argument_group()
- self.parser.add_argument('-n', '--last', type=int, dest='last',
- default=self.default_limit,
- help=('List N most recent %s. Use -n -1 to fetch the full result \
- set.' % self.resource_name))
+ self.parser.add_argument(
+ "-n",
+ "--last",
+ type=int,
+ dest="last",
+ default=self.default_limit,
+ help=(
+ "List N most recent %s. Use -n -1 to fetch the full result \
+ set."
+ % self.resource_name
+ ),
+ )
# Filter options
- self.group.add_argument('--trigger-instance',
- help='Trigger instance id to filter the list.')
-
- self.group.add_argument('--execution',
- help='Execution id to filter the list.')
- self.group.add_argument('--rule',
- help='Rule ref to filter the list.')
-
- self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt',
- default=None,
- help=('Only return enforcements with enforced_at '
- 'greater than the one provided. '
- 'Use time in the format 2000-01-01T12:00:00.000Z'))
- self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt',
- default=None,
- help=('Only return enforcements with enforced_at '
- 'lower than the one provided. '
- 'Use time in the format 2000-01-01T12:00:00.000Z'))
+ self.group.add_argument(
+ "--trigger-instance", help="Trigger instance id to filter the list."
+ )
+
+ self.group.add_argument("--execution", help="Execution id to filter the list.")
+ self.group.add_argument("--rule", help="Rule ref to filter the list.")
+
+ self.parser.add_argument(
+ "-tg",
+ "--timestamp-gt",
+ type=str,
+ dest="timestamp_gt",
+ default=None,
+ help=(
+ "Only return enforcements with enforced_at "
+ "greater than the one provided. "
+ "Use time in the format 2000-01-01T12:00:00.000Z"
+ ),
+ )
+ self.parser.add_argument(
+ "-tl",
+ "--timestamp-lt",
+ type=str,
+ dest="timestamp_lt",
+ default=None,
+ help=(
+ "Only return enforcements with enforced_at "
+ "lower than the one provided. "
+ "Use time in the format 2000-01-01T12:00:00.000Z"
+ ),
+ )
# Display options
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# Filtering options
if args.trigger_instance:
- kwargs['trigger_instance'] = args.trigger_instance
+ kwargs["trigger_instance"] = args.trigger_instance
if args.execution:
- kwargs['execution'] = args.execution
+ kwargs["execution"] = args.execution
if args.rule:
- kwargs['rule_ref'] = args.rule
+ kwargs["rule_ref"] = args.rule
if args.timestamp_gt:
- kwargs['enforced_at_gt'] = args.timestamp_gt
+ kwargs["enforced_at_gt"] = args.timestamp_gt
if args.timestamp_lt:
- kwargs['enforced_at_lt'] = args.timestamp_lt
+ kwargs["enforced_at_lt"] = args.timestamp_lt
return self.manager.query_with_count(limit=args.last, **kwargs)
def run_and_print(self, args, **kwargs):
instances, count = self.run(args, **kwargs)
if args.json or args.yaml:
- self.print_output(reversed(instances), table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ reversed(instances),
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
else:
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ )
if args.last and count and count > args.last:
table.SingleRowTable.note_box(self.resource_name, args.last)
diff --git a/st2client/st2client/commands/sensor.py b/st2client/st2client/commands/sensor.py
index 0d729c8c02..ca4cc33563 100644
--- a/st2client/st2client/commands/sensor.py
+++ b/st2client/st2client/commands/sensor.py
@@ -22,35 +22,67 @@
class SensorBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(SensorBranch, self).__init__(
- Sensor, description, app, subparsers,
+ Sensor,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': SensorListCommand,
- 'get': SensorGetCommand
- })
+ commands={"list": SensorListCommand, "get": SensorGetCommand},
+ )
- self.commands['enable'] = SensorEnableCommand(self.resource, self.app, self.subparsers)
- self.commands['disable'] = SensorDisableCommand(self.resource, self.app, self.subparsers)
+ self.commands["enable"] = SensorEnableCommand(
+ self.resource, self.app, self.subparsers
+ )
+ self.commands["disable"] = SensorDisableCommand(
+ self.resource, self.app, self.subparsers
+ )
class SensorListCommand(resource.ContentPackResourceListCommand):
- display_attributes = ['ref', 'pack', 'description', 'enabled']
+ display_attributes = ["ref", "pack", "description", "enabled"]
class SensorGetCommand(resource.ContentPackResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'enabled', 'entry_point',
- 'artifact_uri', 'trigger_types']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "uid",
+ "ref",
+ "pack",
+ "name",
+ "enabled",
+ "entry_point",
+ "artifact_uri",
+ "trigger_types",
+ ]
class SensorEnableCommand(resource.ContentPackResourceEnableCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'poll_interval',
- 'entry_point', 'artifact_uri', 'trigger_types']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "enabled",
+ "poll_interval",
+ "entry_point",
+ "artifact_uri",
+ "trigger_types",
+ ]
class SensorDisableCommand(resource.ContentPackResourceDisableCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'poll_interval',
- 'entry_point', 'artifact_uri', 'trigger_types']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "enabled",
+ "poll_interval",
+ "entry_point",
+ "artifact_uri",
+ "trigger_types",
+ ]
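Sensor enable/disable inherit the round-trip defined in resource.py above: serialize the instance, strip the computed ref, flip enabled, deserialize, update. Acted out on a plain dict:

    # The enable/disable round-trip from resource.py, on a plain dict.
    instance = {"id": "abc123", "ref": "examples.sample_sensor", "enabled": True}

    data = dict(instance)      # instance.serialize() in st2client
    if "ref" in data:
        del data["ref"]        # ref is computed server-side, so it is stripped
    data["enabled"] = False    # ResourceDisableCommand sets False; enable sets True

    # st2client then rebuilds the model and PUTs it back:
    #   modified_instance = self.resource.deserialize(data)
    #   self.manager.update(modified_instance, **kwargs)
    print(data)  # {'id': 'abc123', 'enabled': False}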
diff --git a/st2client/st2client/commands/service_registry.py b/st2client/st2client/commands/service_registry.py
index b609e051a9..6b9bff60b9 100644
--- a/st2client/st2client/commands/service_registry.py
+++ b/st2client/st2client/commands/service_registry.py
@@ -25,76 +25,87 @@
class ServiceRegistryBranch(commands.Branch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(ServiceRegistryBranch, self).__init__(
- 'service-registry', description,
- app, subparsers, parent_parser=parent_parser)
+ "service-registry",
+ description,
+ app,
+ subparsers,
+ parent_parser=parent_parser,
+ )
self.subparsers = self.parser.add_subparsers(
- help=('List of commands for managing service registry.'))
+ help=("List of commands for managing service registry.")
+ )
# Instantiate commands
- args_groups = ['Manage service registry groups', self.app, self.subparsers]
- args_members = ['Manage service registry members', self.app, self.subparsers]
+ args_groups = ["Manage service registry groups", self.app, self.subparsers]
+ args_members = ["Manage service registry members", self.app, self.subparsers]
- self.commands['groups'] = ServiceRegistryGroupsBranch(*args_groups)
- self.commands['members'] = ServiceRegistryMembersBranch(*args_members)
+ self.commands["groups"] = ServiceRegistryGroupsBranch(*args_groups)
+ self.commands["members"] = ServiceRegistryMembersBranch(*args_members)
class ServiceRegistryGroupsBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(ServiceRegistryGroupsBranch, self).__init__(
- ServiceRegistryGroup, description, app, subparsers,
+ ServiceRegistryGroup,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': ServiceRegistryListGroupsCommand,
- 'get': NoopCommand
- })
+ commands={"list": ServiceRegistryListGroupsCommand, "get": NoopCommand},
+ )
- del self.commands['get']
+ del self.commands["get"]
class ServiceRegistryMembersBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(ServiceRegistryMembersBranch, self).__init__(
- ServiceRegistryMember, description, app, subparsers,
+ ServiceRegistryMember,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': ServiceRegistryListMembersCommand,
- 'get': NoopCommand
- })
+ commands={"list": ServiceRegistryListMembersCommand, "get": NoopCommand},
+ )
- del self.commands['get']
+ del self.commands["get"]
class ServiceRegistryListGroupsCommand(resource.ResourceListCommand):
- display_attributes = ['group_id']
- attribute_display_order = ['group_id']
+ display_attributes = ["group_id"]
+ attribute_display_order = ["group_id"]
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
- manager = self.app.client.managers['ServiceRegistryGroups']
+ manager = self.app.client.managers["ServiceRegistryGroups"]
groups = manager.list()
return groups
class ServiceRegistryListMembersCommand(resource.ResourceListCommand):
- display_attributes = ['group_id', 'member_id', 'capabilities']
- attribute_display_order = ['group_id', 'member_id', 'capabilities']
+ display_attributes = ["group_id", "member_id", "capabilities"]
+ attribute_display_order = ["group_id", "member_id", "capabilities"]
def __init__(self, resource, *args, **kwargs):
super(ServiceRegistryListMembersCommand, self).__init__(
resource, *args, **kwargs
)
- self.parser.add_argument('--group-id', dest='group_id', default=None,
- help='If provided only retrieve members for the specified group.')
+ self.parser.add_argument(
+ "--group-id",
+ dest="group_id",
+ default=None,
+ help="If provided only retrieve members for the specified group.",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
- groups_manager = self.app.client.managers['ServiceRegistryGroups']
- members_manager = self.app.client.managers['ServiceRegistryMembers']
+ groups_manager = self.app.client.managers["ServiceRegistryGroups"]
+ members_manager = self.app.client.managers["ServiceRegistryMembers"]
# If group ID is provided only retrieve members for that group, otherwise retrieve members
# for all groups
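For ServiceRegistryListMembersCommand, the rest of run() (beyond this hunk) fans out over every group when no --group-id is given. A hedged sketch of that fan-out; GroupsManager and MembersManager here are stand-ins, not st2client's actual manager classes:

    # Hedged sketch of the members fan-out; the manager classes are stand-ins.
    class GroupsManager:
        def list(self):
            return ["st2actionrunner", "st2api"]

    class MembersManager:
        def list(self, group_id):
            return [{"group_id": group_id, "member_id": "%s-1" % group_id}]

    def list_members(group_id=None):
        groups_manager, members_manager = GroupsManager(), MembersManager()
        # If a group ID is provided, only retrieve members for that group;
        # otherwise retrieve members for all groups.
        group_ids = [group_id] if group_id else groups_manager.list()
        members = []
        for gid in group_ids:
            members.extend(members_manager.list(group_id=gid))
        return members

    print(list_members())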
diff --git a/st2client/st2client/commands/timer.py b/st2client/st2client/commands/timer.py
index e3fc9e223f..c183367291 100644
--- a/st2client/st2client/commands/timer.py
+++ b/st2client/st2client/commands/timer.py
@@ -22,30 +22,39 @@
class TimerBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(TimerBranch, self).__init__(
- Timer, description, app, subparsers,
+ Timer,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': TimerListCommand,
- 'get': TimerGetCommand
- })
+ commands={"list": TimerListCommand, "get": TimerGetCommand},
+ )
class TimerListCommand(resource.ResourceListCommand):
- display_attributes = ['id', 'uid', 'pack', 'name', 'type', 'parameters']
+ display_attributes = ["id", "uid", "pack", "name", "type", "parameters"]
def __init__(self, resource, *args, **kwargs):
super(TimerListCommand, self).__init__(resource, *args, **kwargs)
- self.parser.add_argument('-ty', '--timer-type', type=str, dest='timer_type',
- help=("List %s type, example: 'core.st2.IntervalTimer', \
- 'core.st2.DateTimer', 'core.st2.CronTimer'." %
- resource.get_plural_display_name().lower()), required=False)
+ self.parser.add_argument(
+ "-ty",
+ "--timer-type",
+ type=str,
+ dest="timer_type",
+ help=(
+ "List %s type, example: 'core.st2.IntervalTimer', \
+ 'core.st2.DateTimer', 'core.st2.CronTimer'."
+ % resource.get_plural_display_name().lower()
+ ),
+ required=False,
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
if args.timer_type:
- kwargs['timer_type'] = args.timer_type
+ kwargs["timer_type"] = args.timer_type
if kwargs:
return self.manager.query(**kwargs)
@@ -54,5 +63,5 @@ def run(self, args, **kwargs):
class TimerGetCommand(resource.ResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['type', 'pack', 'name', 'description', 'parameters']
+ display_attributes = ["all"]
+ attribute_display_order = ["type", "pack", "name", "description", "parameters"]
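Reduced to a self-contained argparse sketch, the -ty flag's journey from the command line to query kwargs looks like this; the manager call itself is omitted, and only the dest-to-kwargs plumbing mirrors TimerListCommand.run():

import argparse

parser = argparse.ArgumentParser(prog="st2 timer list")
parser.add_argument(
    "-ty", "--timer-type", type=str, dest="timer_type", required=False,
    help="e.g. 'core.st2.CronTimer'",
)

args = parser.parse_args(["-ty", "core.st2.CronTimer"])

kwargs = {}
if args.timer_type:
    # Mirrors TimerListCommand.run(): the flag becomes a query kwarg
    # that is then passed on to self.manager.query(**kwargs).
    kwargs["timer_type"] = args.timer_type

print(kwargs)  # {'timer_type': 'core.st2.CronTimer'}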
diff --git a/st2client/st2client/commands/trace.py b/st2client/st2client/commands/trace.py
index b5e59c2cf1..ac8de676c2 100644
--- a/st2client/st2client/commands/trace.py
+++ b/st2client/st2client/commands/trace.py
@@ -23,53 +23,62 @@
from st2client.utils.date import format_isodate_for_user_timezone
-TRACE_ATTRIBUTE_DISPLAY_ORDER = ['id', 'trace_tag', 'action_executions', 'rules',
- 'trigger_instances', 'start_timestamp']
+TRACE_ATTRIBUTE_DISPLAY_ORDER = [
+ "id",
+ "trace_tag",
+ "action_executions",
+ "rules",
+ "trigger_instances",
+ "start_timestamp",
+]
-TRACE_HEADER_DISPLAY_ORDER = ['id', 'trace_tag', 'start_timestamp']
+TRACE_HEADER_DISPLAY_ORDER = ["id", "trace_tag", "start_timestamp"]
-TRACE_COMPONENT_DISPLAY_LABELS = ['id', 'type', 'ref', 'updated_at']
+TRACE_COMPONENT_DISPLAY_LABELS = ["id", "type", "ref", "updated_at"]
-TRACE_DISPLAY_ATTRIBUTES = ['all']
+TRACE_DISPLAY_ATTRIBUTES = ["all"]
TRIGGER_INSTANCE_DISPLAY_OPTIONS = [
- 'all',
- 'trigger-instances',
- 'trigger_instances',
- 'triggerinstances',
- 'triggers'
+ "all",
+ "trigger-instances",
+ "trigger_instances",
+ "triggerinstances",
+ "triggers",
]
ACTION_EXECUTION_DISPLAY_OPTIONS = [
- 'all',
- 'executions',
- 'action-executions',
- 'action_executions',
- 'actionexecutions',
- 'actions'
+ "all",
+ "executions",
+ "action-executions",
+ "action_executions",
+ "actionexecutions",
+ "actions",
]
class TraceBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(TraceBranch, self).__init__(
- Trace, description, app, subparsers,
+ Trace,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': TraceListCommand,
- 'get': TraceGetCommand
- })
+ commands={"list": TraceListCommand, "get": TraceGetCommand},
+ )
class SingleTraceDisplayMixin(object):
-
def print_trace_details(self, trace, args, **kwargs):
- options = {'attributes': TRACE_ATTRIBUTE_DISPLAY_ORDER if args.json else
- TRACE_HEADER_DISPLAY_ORDER}
- options['json'] = args.json
- options['yaml'] = args.yaml
- options['attribute_transform_functions'] = self.attribute_transform_functions
+ options = {
+ "attributes": TRACE_ATTRIBUTE_DISPLAY_ORDER
+ if args.json
+ else TRACE_HEADER_DISPLAY_ORDER
+ }
+ options["json"] = args.json
+ options["yaml"] = args.yaml
+ options["attribute_transform_functions"] = self.attribute_transform_functions
formatter = execution_formatter.ExecutionResult
@@ -81,35 +90,63 @@ def print_trace_details(self, trace, args, **kwargs):
components = []
if any(attr in args.attr for attr in TRIGGER_INSTANCE_DISPLAY_OPTIONS):
- components.extend([Resource(**{'id': trigger_instance['object_id'],
- 'type': TriggerInstance._alias.lower(),
- 'ref': trigger_instance['ref'],
- 'updated_at': trigger_instance['updated_at']})
- for trigger_instance in trace.trigger_instances])
- if any(attr in args.attr for attr in ['all', 'rules']):
- components.extend([Resource(**{'id': rule['object_id'],
- 'type': Rule._alias.lower(),
- 'ref': rule['ref'],
- 'updated_at': rule['updated_at']})
- for rule in trace.rules])
+ components.extend(
+ [
+ Resource(
+ **{
+ "id": trigger_instance["object_id"],
+ "type": TriggerInstance._alias.lower(),
+ "ref": trigger_instance["ref"],
+ "updated_at": trigger_instance["updated_at"],
+ }
+ )
+ for trigger_instance in trace.trigger_instances
+ ]
+ )
+ if any(attr in args.attr for attr in ["all", "rules"]):
+ components.extend(
+ [
+ Resource(
+ **{
+ "id": rule["object_id"],
+ "type": Rule._alias.lower(),
+ "ref": rule["ref"],
+ "updated_at": rule["updated_at"],
+ }
+ )
+ for rule in trace.rules
+ ]
+ )
if any(attr in args.attr for attr in ACTION_EXECUTION_DISPLAY_OPTIONS):
- components.extend([Resource(**{'id': execution['object_id'],
- 'type': Execution._alias.lower(),
- 'ref': execution['ref'],
- 'updated_at': execution['updated_at']})
- for execution in trace.action_executions])
+ components.extend(
+ [
+ Resource(
+ **{
+ "id": execution["object_id"],
+ "type": Execution._alias.lower(),
+ "ref": execution["ref"],
+ "updated_at": execution["updated_at"],
+ }
+ )
+ for execution in trace.action_executions
+ ]
+ )
if components:
components.sort(key=lambda resource: resource.updated_at)
- self.print_output(components, table.MultiColumnTable,
- attributes=TRACE_COMPONENT_DISPLAY_LABELS,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ components,
+ table.MultiColumnTable,
+ attributes=TRACE_COMPONENT_DISPLAY_LABELS,
+ json=args.json,
+ yaml=args.yaml,
+ )
class TraceListCommand(resource.ResourceCommand, SingleTraceDisplayMixin):
- display_attributes = ['id', 'uid', 'trace_tag', 'start_timestamp']
+ display_attributes = ["id", "uid", "trace_tag", "start_timestamp"]
attribute_transform_functions = {
- 'start_timestamp': format_isodate_for_user_timezone
+ "start_timestamp": format_isodate_for_user_timezone
}
attribute_display_order = TRACE_ATTRIBUTE_DISPLAY_ORDER
@@ -119,55 +156,90 @@ def __init__(self, resource, *args, **kwargs):
self.default_limit = 50
super(TraceListCommand, self).__init__(
- resource, 'list', 'Get the list of the %s most recent %s.' %
- (self.default_limit, resource.get_plural_display_name().lower()),
- *args, **kwargs)
+ resource,
+ "list",
+ "Get the list of the %s most recent %s."
+ % (self.default_limit, resource.get_plural_display_name().lower()),
+ *args,
+ **kwargs,
+ )
self.resource_name = resource.get_plural_display_name().lower()
self.group = self.parser.add_mutually_exclusive_group()
- self.parser.add_argument('-n', '--last', type=int, dest='last',
- default=self.default_limit,
- help=('List N most recent %s. Use -n -1 to fetch the full result \
- set.' % self.resource_name))
- self.parser.add_argument('-s', '--sort', type=str, dest='sort_order',
- default='descending',
- help=('Sort %s by start timestamp, '
- 'asc|ascending (earliest first) '
- 'or desc|descending (latest first)' % self.resource_name))
+ self.parser.add_argument(
+ "-n",
+ "--last",
+ type=int,
+ dest="last",
+ default=self.default_limit,
+ help=(
+ "List N most recent %s. Use -n -1 to fetch the full result \
+ set."
+ % self.resource_name
+ ),
+ )
+ self.parser.add_argument(
+ "-s",
+ "--sort",
+ type=str,
+ dest="sort_order",
+ default="descending",
+ help=(
+ "Sort %s by start timestamp, "
+ "asc|ascending (earliest first) "
+ "or desc|descending (latest first)" % self.resource_name
+ ),
+ )
# Filter options
- self.group.add_argument('-c', '--trace-tag', help='Trace-tag to filter the list.')
- self.group.add_argument('-e', '--execution', help='Execution to filter the list.')
- self.group.add_argument('-r', '--rule', help='Rule to filter the list.')
- self.group.add_argument('-g', '--trigger-instance',
- help='TriggerInstance to filter the list.')
+ self.group.add_argument(
+ "-c", "--trace-tag", help="Trace-tag to filter the list."
+ )
+ self.group.add_argument(
+ "-e", "--execution", help="Execution to filter the list."
+ )
+ self.group.add_argument("-r", "--rule", help="Rule to filter the list.")
+ self.group.add_argument(
+ "-g", "--trigger-instance", help="TriggerInstance to filter the list."
+ )
# Display options
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# Filtering options
if args.trace_tag:
- kwargs['trace_tag'] = args.trace_tag
+ kwargs["trace_tag"] = args.trace_tag
if args.trigger_instance:
- kwargs['trigger_instance'] = args.trigger_instance
+ kwargs["trigger_instance"] = args.trigger_instance
if args.execution:
- kwargs['execution'] = args.execution
+ kwargs["execution"] = args.execution
if args.rule:
- kwargs['rule'] = args.rule
+ kwargs["rule"] = args.rule
if args.sort_order:
- if args.sort_order in ['asc', 'ascending']:
- kwargs['sort_asc'] = True
- elif args.sort_order in ['desc', 'descending']:
- kwargs['sort_desc'] = True
+ if args.sort_order in ["asc", "ascending"]:
+ kwargs["sort_asc"] = True
+ elif args.sort_order in ["desc", "descending"]:
+ kwargs["sort_desc"] = True
return self.manager.query_with_count(limit=args.last, **kwargs)
def run_and_print(self, args, **kwargs):
@@ -177,7 +249,7 @@ def run_and_print(self, args, **kwargs):
# For a single Trace we must include the components unless
# user has overridden the attributes to display
if args.attr == self.display_attributes:
- args.attr = ['all']
+ args.attr = ["all"]
self.print_trace_details(trace=instances[0], args=args)
if not args.json and not args.yaml:
@@ -185,27 +257,36 @@ def run_and_print(self, args, **kwargs):
table.SingleRowTable.note_box(self.resource_name, 1)
else:
if args.json or args.yaml:
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
else:
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
if args.last and count and count > args.last:
table.SingleRowTable.note_box(self.resource_name, args.last)
class TraceGetCommand(resource.ResourceGetCommand, SingleTraceDisplayMixin):
- display_attributes = ['all']
+ display_attributes = ["all"]
attribute_display_order = TRACE_ATTRIBUTE_DISPLAY_ORDER
attribute_transform_functions = {
- 'start_timestamp': format_isodate_for_user_timezone
+ "start_timestamp": format_isodate_for_user_timezone
}
- pk_argument_name = 'id'
+ pk_argument_name = "id"
def __init__(self, resource, *args, **kwargs):
super(TraceGetCommand, self).__init__(resource, *args, **kwargs)
@@ -213,23 +294,36 @@ def __init__(self, resource, *args, **kwargs):
# Causation chains
self.causation_group = self.parser.add_mutually_exclusive_group()
- self.causation_group.add_argument('-e', '--execution',
- help='Execution to show causation chain.')
- self.causation_group.add_argument('-r', '--rule', help='Rule to show causation chain.')
- self.causation_group.add_argument('-g', '--trigger-instance',
- help='TriggerInstance to show causation chain.')
+ self.causation_group.add_argument(
+ "-e", "--execution", help="Execution to show causation chain."
+ )
+ self.causation_group.add_argument(
+ "-r", "--rule", help="Rule to show causation chain."
+ )
+ self.causation_group.add_argument(
+ "-g", "--trigger-instance", help="TriggerInstance to show causation chain."
+ )
# display filter group
self.display_filter_group = self.parser.add_argument_group()
- self.display_filter_group.add_argument('--show-executions', action='store_true',
- help='Only show executions.')
- self.display_filter_group.add_argument('--show-rules', action='store_true',
- help='Only show rules.')
- self.display_filter_group.add_argument('--show-trigger-instances', action='store_true',
- help='Only show trigger instances.')
- self.display_filter_group.add_argument('-n', '--hide-noop-triggers', action='store_true',
- help='Hide noop trigger instances.')
+ self.display_filter_group.add_argument(
+ "--show-executions", action="store_true", help="Only show executions."
+ )
+ self.display_filter_group.add_argument(
+ "--show-rules", action="store_true", help="Only show rules."
+ )
+ self.display_filter_group.add_argument(
+ "--show-trigger-instances",
+ action="store_true",
+ help="Only show trigger instances.",
+ )
+ self.display_filter_group.add_argument(
+ "-n",
+ "--hide-noop-triggers",
+ action="store_true",
+ help="Hide noop trigger instances.",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
@@ -243,7 +337,7 @@ def run_and_print(self, args, **kwargs):
trace = self.run(args, **kwargs)
except resource.ResourceNotFoundError:
self.print_not_found(args.id)
- raise OperationFailureException('Trace %s not found.' % (args.id))
+ raise OperationFailureException("Trace %s not found." % (args.id))
# First filter for causation chains
trace = self._filter_trace_components(trace=trace, args=args)
# next filter for display purposes
@@ -266,13 +360,13 @@ def _filter_trace_components(trace, args):
# pick the right component type
if args.execution:
component_id = args.execution
- component_type = 'action_execution'
+ component_type = "action_execution"
elif args.rule:
component_id = args.rule
- component_type = 'rule'
+ component_type = "rule"
elif args.trigger_instance:
component_id = args.trigger_instance
- component_type = 'trigger_instance'
+ component_type = "trigger_instance"
# Initialize collection to use
action_executions = []
@@ -284,13 +378,13 @@ def _filter_trace_components(trace, args):
while search_target_found:
components_list = []
- if component_type == 'action_execution':
+ if component_type == "action_execution":
components_list = trace.action_executions
to_update_list = action_executions
- elif component_type == 'rule':
+ elif component_type == "rule":
components_list = trace.rules
to_update_list = rules
- elif component_type == 'trigger_instance':
+ elif component_type == "trigger_instance":
components_list = trace.trigger_instances
to_update_list = trigger_instances
# Look for search_target in the right collection and
@@ -300,22 +394,25 @@ def _filter_trace_components(trace, args):
# init to default value
component_caused_by_id = None
for component in components_list:
- test_id = component['object_id']
+ test_id = component["object_id"]
if test_id == component_id:
- caused_by = component.get('caused_by', {})
- component_id = caused_by.get('id', None)
- component_type = caused_by.get('type', None)
+ caused_by = component.get("caused_by", {})
+ component_id = caused_by.get("id", None)
+ component_type = caused_by.get("type", None)
# If provided the component_caused_by_id must match as well. This is mostly
# applicable for rules since the same rule may appear multiple times and can
# only be distinguished by the causing TriggerInstance.
- if component_caused_by_id and component_caused_by_id != component_id:
+ if (
+ component_caused_by_id
+ and component_caused_by_id != component_id
+ ):
continue
component_caused_by_id = None
to_update_list.append(component)
# In some cases the component_id and the causing component are combined to
# provide the complete causation chain. Think rule + triggerinstance
- if component_id and ':' in component_id:
- component_id_split = component_id.split(':')
+ if component_id and ":" in component_id:
+ component_id_split = component_id.split(":")
component_id = component_id_split[0]
component_caused_by_id = component_id_split[1]
search_target_found = True
@@ -333,19 +430,21 @@ def _apply_display_filters(trace, args):
should be displayed.
"""
# If all the filters are false nothing is to be filtered.
- all_component_types = not(args.show_executions or
- args.show_rules or
- args.show_trigger_instances)
+ all_component_types = not (
+ args.show_executions or args.show_rules or args.show_trigger_instances
+ )
# check if noop_triggers are to be hidden. This check applies whenever TriggerInstances
# are to be shown.
- if (all_component_types or args.show_trigger_instances) and args.hide_noop_triggers:
+ if (
+ all_component_types or args.show_trigger_instances
+ ) and args.hide_noop_triggers:
filtered_trigger_instances = []
for trigger_instance in trace.trigger_instances:
is_noop_trigger_instance = True
for rule in trace.rules:
- caused_by_id = rule.get('caused_by', {}).get('id', None)
- if caused_by_id == trigger_instance['object_id']:
+ caused_by_id = rule.get("caused_by", {}).get("id", None)
+ if caused_by_id == trigger_instance["object_id"]:
is_noop_trigger_instance = False
if not is_noop_trigger_instance:
filtered_trigger_instances.append(trigger_instance)
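A compact, runnable rendering of the caused_by walk implemented in _filter_trace_components above, including the combined "id:caused_by_id" convention the comment describes (the sample components are made up):

# Each trace component carries a caused_by pointer; a rule's cause may be
# keyed as "component_id:caused_by_id" so the same rule can be told apart
# by the TriggerInstance that fired it.
components = {
    "action_execution": [
        {"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}},
    ],
    "rule": [
        {"object_id": "r1", "caused_by": {"id": "t1", "type": "trigger_instance"}},
    ],
    "trigger_instance": [
        {"object_id": "t1", "caused_by": {}},
    ],
}

def causation_chain(component_type, component_id):
    chain = []
    while component_type:
        next_id = next_type = None
        for component in components[component_type]:
            if component["object_id"] == component_id:
                chain.append(component)
                caused_by = component.get("caused_by", {})
                next_id = caused_by.get("id")
                next_type = caused_by.get("type")
                break
        if next_id and ":" in next_id:
            # Combined id: first half names the component, second the cause.
            next_id = next_id.split(":")[0]
        component_type, component_id = next_type, next_id
    return chain

print([c["object_id"] for c in causation_chain("action_execution", "e1")])
# ['e1', 'r1', 't1']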
diff --git a/st2client/st2client/commands/trigger.py b/st2client/st2client/commands/trigger.py
index 2fd966261c..3a960fddc8 100644
--- a/st2client/st2client/commands/trigger.py
+++ b/st2client/st2client/commands/trigger.py
@@ -23,29 +23,40 @@
class TriggerTypeBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(TriggerTypeBranch, self).__init__(
- TriggerType, description, app, subparsers,
+ TriggerType,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
commands={
- 'list': TriggerTypeListCommand,
- 'get': TriggerTypeGetCommand,
- 'update': TriggerTypeUpdateCommand,
- 'delete': TriggerTypeDeleteCommand
- })
+ "list": TriggerTypeListCommand,
+ "get": TriggerTypeGetCommand,
+ "update": TriggerTypeUpdateCommand,
+ "delete": TriggerTypeDeleteCommand,
+ },
+ )
# Registers extended commands
- self.commands['getspecs'] = TriggerTypeSubTriggerCommand(
- self.resource, self.app, self.subparsers,
- add_help=False)
+ self.commands["getspecs"] = TriggerTypeSubTriggerCommand(
+ self.resource, self.app, self.subparsers, add_help=False
+ )
class TriggerTypeListCommand(resource.ContentPackResourceListCommand):
- display_attributes = ['ref', 'pack', 'description']
+ display_attributes = ["ref", "pack", "description"]
class TriggerTypeGetCommand(resource.ContentPackResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'ref', 'pack', 'name', 'description',
- 'parameters_schema', 'payload_schema']
+ display_attributes = ["all"]
+ attribute_display_order = [
+ "id",
+ "ref",
+ "pack",
+ "name",
+ "description",
+ "parameters_schema",
+ "payload_schema",
+ ]
class TriggerTypeUpdateCommand(resource.ContentPackResourceUpdateCommand):
@@ -57,29 +68,45 @@ class TriggerTypeDeleteCommand(resource.ContentPackResourceDeleteCommand):
class TriggerTypeSubTriggerCommand(resource.ResourceCommand):
- attribute_display_order = ['id', 'ref', 'context', 'parameters', 'status',
- 'start_timestamp', 'result']
+ attribute_display_order = [
+ "id",
+ "ref",
+ "context",
+ "parameters",
+ "status",
+ "start_timestamp",
+ "result",
+ ]
def __init__(self, resource, *args, **kwargs):
super(TriggerTypeSubTriggerCommand, self).__init__(
- resource, kwargs.pop('name', 'getspecs'),
- 'Return Trigger Specifications of a Trigger.',
- *args, **kwargs)
-
- self.parser.add_argument('ref', nargs='?',
- metavar='ref',
- help='Fully qualified name (pack.trigger_name) ' +
- 'of the trigger.')
-
- self.parser.add_argument('-h', '--help',
- action='store_true', dest='help',
- help='Print usage for the given action.')
+ resource,
+ kwargs.pop("name", "getspecs"),
+ "Return Trigger Specifications of a Trigger.",
+ *args,
+ **kwargs,
+ )
+
+ self.parser.add_argument(
+ "ref",
+ nargs="?",
+ metavar="ref",
+ help="Fully qualified name (pack.trigger_name) " + "of the trigger.",
+ )
+
+ self.parser.add_argument(
+ "-h",
+ "--help",
+ action="store_true",
+ dest="help",
+ help="Print usage for the given action.",
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
- trigger_mgr = self.app.client.managers['Trigger']
- return trigger_mgr.query(**{'type': args.ref})
+ trigger_mgr = self.app.client.managers["Trigger"]
+ return trigger_mgr.query(**{"type": args.ref})
@resource.add_auth_token_to_kwargs_from_cli
def run_and_print(self, args, **kwargs):
@@ -87,5 +114,6 @@ def run_and_print(self, args, **kwargs):
self.parser.print_help()
return
instances = self.run(args, **kwargs)
- self.print_output(instances, table.MultiColumnTable,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances, table.MultiColumnTable, json=args.json, yaml=args.yaml
+ )
diff --git a/st2client/st2client/commands/triggerinstance.py b/st2client/st2client/commands/triggerinstance.py
index 2ac4da73da..12966fea92 100644
--- a/st2client/st2client/commands/triggerinstance.py
+++ b/st2client/st2client/commands/triggerinstance.py
@@ -25,17 +25,23 @@ class TriggerInstanceResendCommand(resource.ResourceCommand):
def __init__(self, resource, *args, **kwargs):
super(TriggerInstanceResendCommand, self).__init__(
- resource, kwargs.pop('name', 're-emit'),
- 'Re-emit a particular trigger instance.',
- *args, **kwargs)
+ resource,
+ kwargs.pop("name", "re-emit"),
+ "Re-emit a particular trigger instance.",
+ *args,
+ **kwargs,
+ )
- self.parser.add_argument('id', nargs='?',
- metavar='id',
- help='ID of trigger instance to re-emit.')
self.parser.add_argument(
- '-h', '--help',
- action='store_true', dest='help',
- help='Print usage for the given command.')
+ "id", nargs="?", metavar="id", help="ID of trigger instance to re-emit."
+ )
+ self.parser.add_argument(
+ "-h",
+ "--help",
+ action="store_true",
+ dest="help",
+ help="Print usage for the given command.",
+ )
def run(self, args, **kwargs):
return self.manager.re_emit(args.id)
@@ -43,29 +49,35 @@ def run(self, args, **kwargs):
@resource.add_auth_token_to_kwargs_from_cli
def run_and_print(self, args, **kwargs):
ret = self.run(args, **kwargs)
- if 'message' in ret:
- print(ret['message'])
+ if "message" in ret:
+ print(ret["message"])
class TriggerInstanceBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(TriggerInstanceBranch, self).__init__(
- TriggerInstance, description, app, subparsers,
- parent_parser=parent_parser, read_only=True,
+ TriggerInstance,
+ description,
+ app,
+ subparsers,
+ parent_parser=parent_parser,
+ read_only=True,
commands={
- 'list': TriggerInstanceListCommand,
- 'get': TriggerInstanceGetCommand
- })
+ "list": TriggerInstanceListCommand,
+ "get": TriggerInstanceGetCommand,
+ },
+ )
- self.commands['re-emit'] = TriggerInstanceResendCommand(self.resource, self.app,
- self.subparsers, add_help=False)
+ self.commands["re-emit"] = TriggerInstanceResendCommand(
+ self.resource, self.app, self.subparsers, add_help=False
+ )
class TriggerInstanceListCommand(resource.ResourceViewCommand):
- display_attributes = ['id', 'trigger', 'occurrence_time', 'status']
+ display_attributes = ["id", "trigger", "occurrence_time", "status"]
attribute_transform_functions = {
- 'occurrence_time': format_isodate_for_user_timezone
+ "occurrence_time": format_isodate_for_user_timezone
}
def __init__(self, resource, *args, **kwargs):
@@ -73,83 +85,133 @@ def __init__(self, resource, *args, **kwargs):
self.default_limit = 50
super(TriggerInstanceListCommand, self).__init__(
- resource, 'list', 'Get the list of the %s most recent %s.' %
- (self.default_limit, resource.get_plural_display_name().lower()),
- *args, **kwargs)
+ resource,
+ "list",
+ "Get the list of the %s most recent %s."
+ % (self.default_limit, resource.get_plural_display_name().lower()),
+ *args,
+ **kwargs,
+ )
self.resource_name = resource.get_plural_display_name().lower()
self.group = self.parser.add_argument_group()
- self.parser.add_argument('-n', '--last', type=int, dest='last',
- default=self.default_limit,
- help=('List N most recent %s. Use -n -1 to fetch the full result \
- set.' % self.resource_name))
+ self.parser.add_argument(
+ "-n",
+ "--last",
+ type=int,
+ dest="last",
+ default=self.default_limit,
+ help=(
+ "List N most recent %s. Use -n -1 to fetch the full result \
+ set."
+ % self.resource_name
+ ),
+ )
# Filter options
- self.group.add_argument('--trigger', help='Trigger reference to filter the list.')
-
- self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt',
- default=None,
- help=('Only return trigger instances with occurrence_time '
- 'greater than the one provided. '
- 'Use time in the format 2000-01-01T12:00:00.000Z'))
- self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt',
- default=None,
- help=('Only return trigger instances with timestamp '
- 'lower than the one provided. '
- 'Use time in the format 2000-01-01T12:00:00.000Z'))
-
- self.group.add_argument('--status',
- help='Can be pending, processing, processed or processing_failed.')
+ self.group.add_argument(
+ "--trigger", help="Trigger reference to filter the list."
+ )
+
+ self.parser.add_argument(
+ "-tg",
+ "--timestamp-gt",
+ type=str,
+ dest="timestamp_gt",
+ default=None,
+ help=(
+ "Only return trigger instances with occurrence_time "
+ "greater than the one provided. "
+ "Use time in the format 2000-01-01T12:00:00.000Z"
+ ),
+ )
+ self.parser.add_argument(
+ "-tl",
+ "--timestamp-lt",
+ type=str,
+ dest="timestamp_lt",
+ default=None,
+ help=(
+ "Only return trigger instances with timestamp "
+ "lower than the one provided. "
+ "Use time in the format 2000-01-01T12:00:00.000Z"
+ ),
+ )
+
+ self.group.add_argument(
+ "--status",
+ help="Can be pending, processing, processed or processing_failed.",
+ )
# Display options
- self.parser.add_argument('-a', '--attr', nargs='+',
- default=self.display_attributes,
- help=('List of attributes to include in the '
- 'output. "all" will return all '
- 'attributes.'))
- self.parser.add_argument('-w', '--width', nargs='+', type=int,
- default=None,
- help=('Set the width of columns in output.'))
+ self.parser.add_argument(
+ "-a",
+ "--attr",
+ nargs="+",
+ default=self.display_attributes,
+ help=(
+ "List of attributes to include in the "
+ 'output. "all" will return all '
+ "attributes."
+ ),
+ )
+ self.parser.add_argument(
+ "-w",
+ "--width",
+ nargs="+",
+ type=int,
+ default=None,
+ help=("Set the width of columns in output."),
+ )
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
# Filtering options
if args.trigger:
- kwargs['trigger'] = args.trigger
+ kwargs["trigger"] = args.trigger
if args.timestamp_gt:
- kwargs['timestamp_gt'] = args.timestamp_gt
+ kwargs["timestamp_gt"] = args.timestamp_gt
if args.timestamp_lt:
- kwargs['timestamp_lt'] = args.timestamp_lt
+ kwargs["timestamp_lt"] = args.timestamp_lt
if args.status:
- kwargs['status'] = args.status
+ kwargs["status"] = args.status
include_attributes = self._get_include_attributes(args=args)
if include_attributes:
- include_attributes = ','.join(include_attributes)
- kwargs['params'] = {'include_attributes': include_attributes}
+ include_attributes = ",".join(include_attributes)
+ kwargs["params"] = {"include_attributes": include_attributes}
return self.manager.query_with_count(limit=args.last, **kwargs)
def run_and_print(self, args, **kwargs):
instances, count = self.run(args, **kwargs)
if args.json or args.yaml:
- self.print_output(reversed(instances), table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ reversed(instances),
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
else:
- self.print_output(reversed(instances), table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- attribute_transform_functions=self.attribute_transform_functions)
+ self.print_output(
+ reversed(instances),
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ attribute_transform_functions=self.attribute_transform_functions,
+ )
if args.last and count and count > args.last:
table.SingleRowTable.note_box(self.resource_name, args.last)
class TriggerInstanceGetCommand(resource.ResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['id', 'trigger', 'occurrence_time', 'payload']
+ display_attributes = ["all"]
+ attribute_display_order = ["id", "trigger", "occurrence_time", "payload"]
- pk_argument_name = 'id'
+ pk_argument_name = "id"
@resource.add_auth_token_to_kwargs_from_cli
def run(self, args, **kwargs):
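A minimal sketch of how TriggerInstanceListCommand.run assembles its query: only the filters the user actually set become kwargs, and the selected display attributes collapse into one comma-separated include_attributes parameter (the sample values below are illustrative):

# Stand-in for parsed CLI args.
filters = {
    "trigger": "core.st2.IntervalTimer",
    "timestamp_gt": "2000-01-01T12:00:00.000Z",
    "timestamp_lt": None,
    "status": "processed",
}

kwargs = {}
# Unset flags are skipped entirely rather than sent as nulls.
kwargs.update({key: value for key, value in filters.items() if value})

include_attributes = ["id", "trigger", "occurrence_time", "status"]
if include_attributes:
    kwargs["params"] = {"include_attributes": ",".join(include_attributes)}

print(kwargs)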
diff --git a/st2client/st2client/commands/webhook.py b/st2client/st2client/commands/webhook.py
index 3a48344500..4b555ac59f 100644
--- a/st2client/st2client/commands/webhook.py
+++ b/st2client/st2client/commands/webhook.py
@@ -23,37 +23,47 @@
class WebhookBranch(resource.ResourceBranch):
def __init__(self, description, app, subparsers, parent_parser=None):
super(WebhookBranch, self).__init__(
- Webhook, description, app, subparsers,
+ Webhook,
+ description,
+ app,
+ subparsers,
parent_parser=parent_parser,
read_only=True,
- commands={
- 'list': WebhookListCommand,
- 'get': WebhookGetCommand
- })
+ commands={"list": WebhookListCommand, "get": WebhookGetCommand},
+ )
class WebhookListCommand(resource.ContentPackResourceListCommand):
- display_attributes = ['url', 'type', 'description']
+ display_attributes = ["url", "type", "description"]
def run_and_print(self, args, **kwargs):
instances = self.run(args, **kwargs)
for instance in instances:
- instance.url = instance.parameters['url']
+ instance.url = instance.parameters["url"]
instances = sorted(instances, key=lambda k: k.url)
if args.json or args.yaml:
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width,
- json=args.json, yaml=args.yaml)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ json=args.json,
+ yaml=args.yaml,
+ )
else:
- self.print_output(instances, table.MultiColumnTable,
- attributes=args.attr, widths=args.width)
+ self.print_output(
+ instances,
+ table.MultiColumnTable,
+ attributes=args.attr,
+ widths=args.width,
+ )
class WebhookGetCommand(resource.ResourceGetCommand):
- display_attributes = ['all']
- attribute_display_order = ['type', 'description']
+ display_attributes = ["all"]
+ attribute_display_order = ["type", "description"]
- pk_argument_name = 'url'
+ pk_argument_name = "url"
diff --git a/st2client/st2client/commands/workflow.py b/st2client/st2client/commands/workflow.py
index 5348f76706..57f9f52a5f 100644
--- a/st2client/st2client/commands/workflow.py
+++ b/st2client/st2client/commands/workflow.py
@@ -1,4 +1,3 @@
-
# Copyright 2020 The StackStorm Authors.
# Copyright 2019 Extreme Networks, Inc.
#
@@ -28,26 +27,25 @@
class WorkflowBranch(commands.Branch):
-
def __init__(self, description, app, subparsers, parent_parser=None):
super(WorkflowBranch, self).__init__(
- 'workflow', description, app, subparsers,
- parent_parser=parent_parser
+ "workflow", description, app, subparsers, parent_parser=parent_parser
)
# Add subparser to register subcommands for managing workflows.
- help_message = 'List of commands for managing workflows.'
+ help_message = "List of commands for managing workflows."
self.subparsers = self.parser.add_subparsers(help=help_message)
# Register workflow commands.
- self.commands['inspect'] = WorkflowInspectionCommand(self.app, self.subparsers)
+ self.commands["inspect"] = WorkflowInspectionCommand(self.app, self.subparsers)
class WorkflowInspectionCommand(commands.Command):
-
def __init__(self, *args, **kwargs):
- name = 'inspect'
- description = 'Inspect workflow definition and return the list of errors if any.'
+ name = "inspect"
+ description = (
+ "Inspect workflow definition and return the list of errors if any."
+ )
args = tuple([name, description] + list(args))
super(WorkflowInspectionCommand, self).__init__(*args, **kwargs)
@@ -55,27 +53,25 @@ def __init__(self, *args, **kwargs):
arg_group = self.parser.add_mutually_exclusive_group()
arg_group.add_argument(
- '--file',
- dest='file',
- help='Local file path to the workflow definition.'
+ "--file", dest="file", help="Local file path to the workflow definition."
)
arg_group.add_argument(
- '--action',
- dest='action',
- help='Reference name for the registered action. This option works only if the file '
- 'referenced by the entry point is installed locally under /opt/stackstorm/packs.'
+ "--action",
+ dest="action",
+ help="Reference name for the registered action. This option works only if the file "
+ "referenced by the entry point is installed locally under /opt/stackstorm/packs.",
)
@property
def manager(self):
- return self.app.client.managers['Workflow']
+ return self.app.client.managers["Workflow"]
def get_file_content(self, file_path):
if not os.path.isfile(file_path):
raise Exception('File "%s" does not exist on local system.' % file_path)
- with open(file_path, 'r') as f:
+ with open(file_path, "r") as f:
content = f.read()
return content
@@ -88,13 +84,18 @@ def run(self, args, **kwargs):
# is executed locally where the content is stored.
if not wf_def_file:
action_ref = args.action
- action_manager = self.app.client.managers['Action']
+ action_manager = self.app.client.managers["Action"]
action = action_manager.get_by_ref_or_id(ref_or_id=action_ref)
if not action:
raise Exception('Unable to identify action "%s".' % action_ref)
- wf_def_file = '/opt/stackstorm/packs/' + action.pack + '/actions/' + action.entry_point
+ wf_def_file = (
+ "/opt/stackstorm/packs/"
+ + action.pack
+ + "/actions/"
+ + action.entry_point
+ )
wf_def = self.get_file_content(wf_def_file)
@@ -105,10 +106,10 @@ def run_and_print(self, args, **kwargs):
errors = self.run(args, **kwargs)
if not isinstance(errors, list):
- raise TypeError('The inspection result is not type of list: %s' % errors)
+ raise TypeError("The inspection result is not type of list: %s" % errors)
if not errors:
- print('No errors found in workflow definition.')
+ print("No errors found in workflow definition.")
return
print(yaml.safe_dump(errors, default_flow_style=False, allow_unicode=True))
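When --action is given, WorkflowInspectionCommand resolves the definition file from the local pack layout as shown above; a tiny standalone illustration (the Action stub and file name are hypothetical, the real object comes back from the Action manager via get_by_ref_or_id()):

class Action:
    pack = "examples"
    entry_point = "workflows/my-workflow.yaml"

action = Action()
wf_def_file = (
    "/opt/stackstorm/packs/" + action.pack + "/actions/" + action.entry_point
)
print(wf_def_file)
# /opt/stackstorm/packs/examples/actions/workflows/my-workflow.yaml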
diff --git a/st2client/st2client/config.py b/st2client/st2client/config.py
index 5de500aec2..c002c7f414 100644
--- a/st2client/st2client/config.py
+++ b/st2client/st2client/config.py
@@ -13,10 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'get_config',
- 'set_config'
-]
+__all__ = ["get_config", "set_config"]
# Stores parsed config dictionary
CONFIG = {}
diff --git a/st2client/st2client/config_parser.py b/st2client/st2client/config_parser.py
index e5095df3e2..ca88209f87 100644
--- a/st2client/st2client/config_parser.py
+++ b/st2client/st2client/config_parser.py
@@ -31,88 +31,38 @@
__all__ = [
- 'CLIConfigParser',
-
- 'ST2_CONFIG_DIRECTORY',
- 'ST2_CONFIG_PATH',
-
- 'CONFIG_DEFAULT_VALUES'
+ "CLIConfigParser",
+ "ST2_CONFIG_DIRECTORY",
+ "ST2_CONFIG_PATH",
+ "CONFIG_DEFAULT_VALUES",
]
-ST2_CONFIG_DIRECTORY = '~/.st2'
+ST2_CONFIG_DIRECTORY = "~/.st2"
ST2_CONFIG_DIRECTORY = os.path.abspath(os.path.expanduser(ST2_CONFIG_DIRECTORY))
-ST2_CONFIG_PATH = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, 'config'))
+ST2_CONFIG_PATH = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, "config"))
CONFIG_FILE_OPTIONS = {
- 'general': {
- 'base_url': {
- 'type': 'string',
- 'default': None
- },
- 'api_version': {
- 'type': 'string',
- 'default': None
- },
- 'cacert': {
- 'type': 'string',
- 'default': None
- },
- 'silence_ssl_warnings': {
- 'type': 'bool',
- 'default': False
- },
- 'silence_schema_output': {
- 'type': 'bool',
- 'default': True
- }
- },
- 'cli': {
- 'debug': {
- 'type': 'bool',
- 'default': False
- },
- 'cache_token': {
- 'type': 'boolean',
- 'default': True
- },
- 'timezone': {
- 'type': 'string',
- 'default': 'UTC'
- }
- },
- 'credentials': {
- 'username': {
- 'type': 'string',
- 'default': None
- },
- 'password': {
- 'type': 'string',
- 'default': None
- },
- 'api_key': {
- 'type': 'string',
- 'default': None
- }
+ "general": {
+ "base_url": {"type": "string", "default": None},
+ "api_version": {"type": "string", "default": None},
+ "cacert": {"type": "string", "default": None},
+ "silence_ssl_warnings": {"type": "bool", "default": False},
+ "silence_schema_output": {"type": "bool", "default": True},
},
- 'api': {
- 'url': {
- 'type': 'string',
- 'default': None
- }
+ "cli": {
+ "debug": {"type": "bool", "default": False},
+ "cache_token": {"type": "boolean", "default": True},
+ "timezone": {"type": "string", "default": "UTC"},
},
- 'auth': {
- 'url': {
- 'type': 'string',
- 'default': None
- }
+ "credentials": {
+ "username": {"type": "string", "default": None},
+ "password": {"type": "string", "default": None},
+ "api_key": {"type": "string", "default": None},
},
- 'stream': {
- 'url': {
- 'type': 'string',
- 'default': None
- }
- }
+ "api": {"url": {"type": "string", "default": None}},
+ "auth": {"url": {"type": "string", "default": None}},
+ "stream": {"url": {"type": "string", "default": None}},
}
CONFIG_DEFAULT_VALUES = {}
@@ -121,13 +71,18 @@
CONFIG_DEFAULT_VALUES[section] = {}
for key, options in six.iteritems(keys):
- default_value = options['default']
+ default_value = options["default"]
CONFIG_DEFAULT_VALUES[section][key] = default_value
class CLIConfigParser(object):
- def __init__(self, config_file_path, validate_config_exists=True,
- validate_config_permissions=True, log=None):
+ def __init__(
+ self,
+ config_file_path,
+ validate_config_exists=True,
+ validate_config_permissions=True,
+ log=None,
+ ):
if validate_config_exists and not os.path.isfile(config_file_path):
raise ValueError('Config file "%s" doesn\'t exist' % config_file_path)
@@ -158,37 +113,40 @@ def parse(self):
if bool(os.stat(config_dir_path).st_mode & 0o7):
self.LOG.warn(
"The StackStorm configuration directory permissions are "
- "insecure (too permissive): others have access.")
+ "insecure (too permissive): others have access."
+ )
# Make sure the setgid bit is set on the directory
if not bool(os.stat(config_dir_path).st_mode & 0o2000):
self.LOG.info(
"The SGID bit is not set on the StackStorm configuration "
- "directory.")
+ "directory."
+ )
# Make sure the file permissions == 0o660
if bool(os.stat(self.config_file_path).st_mode & 0o7):
self.LOG.warn(
"The StackStorm configuration file permissions are "
- "insecure: others have access.")
+ "insecure: others have access."
+ )
config = ConfigParser()
- with io.open(self.config_file_path, 'r', encoding='utf8') as fp:
+ with io.open(self.config_file_path, "r", encoding="utf8") as fp:
config.readfp(fp)
for section, keys in six.iteritems(CONFIG_FILE_OPTIONS):
for key, options in six.iteritems(keys):
- key_type = options['type']
- key_default_value = options['default']
+ key_type = options["type"]
+ key_default_value = options["default"]
if config.has_option(section, key):
- if key_type in ['str', 'string']:
+ if key_type in ["str", "string"]:
get_func = config.get
- elif key_type in ['int', 'integer']:
+ elif key_type in ["int", "integer"]:
get_func = config.getint
- elif key_type in ['float']:
+ elif key_type in ["float"]:
get_func = config.getfloat
- elif key_type in ['bool', 'boolean']:
+ elif key_type in ["bool", "boolean"]:
get_func = config.getboolean
else:
msg = 'Invalid type "%s" for option "%s"' % (key_type, key)
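The dispatch above maps each declared option type to the matching ConfigParser getter; a runnable Python 3 sketch of the same loop (the sample OPTIONS table and config string are trimmed-down stand-ins for CONFIG_FILE_OPTIONS and the on-disk ~/.st2/config):

import io
from configparser import ConfigParser

OPTIONS = {
    "cli": {
        "debug": {"type": "bool", "default": False},
        "timezone": {"type": "string", "default": "UTC"},
    }
}

config = ConfigParser()
config.read_file(io.StringIO("[cli]\ndebug = True\n"))

result = {}
for section, keys in OPTIONS.items():
    result[section] = {}
    for key, options in keys.items():
        key_type = options["type"]
        if not config.has_option(section, key):
            # Missing keys fall back to the declared default.
            result[section][key] = options["default"]
            continue
        if key_type in ["str", "string"]:
            get_func = config.get
        elif key_type in ["int", "integer"]:
            get_func = config.getint
        elif key_type in ["float"]:
            get_func = config.getfloat
        elif key_type in ["bool", "boolean"]:
            get_func = config.getboolean
        else:
            raise ValueError('Invalid type "%s" for option "%s"' % (key_type, key))
        result[section][key] = get_func(section, key)

print(result)  # {'cli': {'debug': True, 'timezone': 'UTC'}}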
diff --git a/st2client/st2client/exceptions/base.py b/st2client/st2client/exceptions/base.py
index f9cd343665..97c9bb8a09 100644
--- a/st2client/st2client/exceptions/base.py
+++ b/st2client/st2client/exceptions/base.py
@@ -16,7 +16,8 @@
class StackStormCLIBaseException(Exception):
"""
- The root of the exception class hierarchy for all
- StackStorm CLI exceptions.
+ The root of the exception class hierarchy for all
+ StackStorm CLI exceptions.
"""
+
pass
diff --git a/st2client/st2client/formatters/__init__.py b/st2client/st2client/formatters/__init__.py
index dcaaee3ee1..e0d8e5f718 100644
--- a/st2client/st2client/formatters/__init__.py
+++ b/st2client/st2client/formatters/__init__.py
@@ -25,10 +25,8 @@
class Formatter(six.with_metaclass(abc.ABCMeta, object)):
-
@classmethod
@abc.abstractmethod
def format(cls, subject, *args, **kwargs):
- """Override this method to customize output format for the subject.
- """
+ """Override this method to customize output format for the subject."""
raise NotImplementedError
diff --git a/st2client/st2client/formatters/doc.py b/st2client/st2client/formatters/doc.py
index ea2218dec4..5f6ca96dce 100644
--- a/st2client/st2client/formatters/doc.py
+++ b/st2client/st2client/formatters/doc.py
@@ -23,10 +23,7 @@
from st2client import formatters
from st2client.utils import jsutil
-__all__ = [
- 'JsonFormatter',
- 'YAMLFormatter'
-]
+__all__ = ["JsonFormatter", "YAMLFormatter"]
LOG = logging.getLogger(__name__)
@@ -34,25 +31,34 @@
class BaseFormatter(formatters.Formatter):
@classmethod
def format(self, subject, *args, **kwargs):
- attributes = kwargs.get('attributes', None)
+ attributes = kwargs.get("attributes", None)
if type(subject) is str:
subject = json.loads(subject)
- elif not isinstance(subject, (list, tuple)) and not hasattr(subject, '__iter__'):
+ elif not isinstance(subject, (list, tuple)) and not hasattr(
+ subject, "__iter__"
+ ):
doc = subject if isinstance(subject, dict) else subject.__dict__
- keys = list(doc.keys()) if not attributes or 'all' in attributes else attributes
+ keys = (
+ list(doc.keys())
+ if not attributes or "all" in attributes
+ else attributes
+ )
docs = jsutil.get_kvps(doc, keys)
else:
docs = []
for item in subject:
doc = item if isinstance(item, dict) else item.__dict__
- keys = list(doc.keys()) if not attributes or 'all' in attributes else attributes
+ keys = (
+ list(doc.keys())
+ if not attributes or "all" in attributes
+ else attributes
+ )
docs.append(jsutil.get_kvps(doc, keys))
return docs
class JsonFormatter(BaseFormatter):
-
@classmethod
def format(self, subject, *args, **kwargs):
docs = BaseFormatter.format(subject, *args, **kwargs)
@@ -60,7 +66,6 @@ def format(self, subject, *args, **kwargs):
class YAMLFormatter(BaseFormatter):
-
@classmethod
def format(self, subject, *args, **kwargs):
docs = BaseFormatter.format(subject, *args, **kwargs)
diff --git a/st2client/st2client/formatters/execution.py b/st2client/st2client/formatters/execution.py
index 69da8cdb41..b52527d4de 100644
--- a/st2client/st2client/formatters/execution.py
+++ b/st2client/st2client/formatters/execution.py
@@ -32,32 +32,31 @@
LOG = logging.getLogger(__name__)
-PLATFORM_MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1
+PLATFORM_MAXINT = 2 ** (struct.Struct("i").size * 8 - 1) - 1
def _print_bordered(text):
- lines = text.split('\n')
+ lines = text.split("\n")
width = max(len(s) for s in lines) + 2
- res = ['\n+' + '-' * width + '+']
+ res = ["\n+" + "-" * width + "+"]
for s in lines:
- res.append('| ' + (s + ' ' * width)[:width - 2] + ' |')
- res.append('+' + '-' * width + '+')
- return '\n'.join(res)
+ res.append("| " + (s + " " * width)[: width - 2] + " |")
+ res.append("+" + "-" * width + "+")
+ return "\n".join(res)
class ExecutionResult(formatters.Formatter):
-
@classmethod
def format(cls, entry, *args, **kwargs):
- attrs = kwargs.get('attributes', [])
- attribute_transform_functions = kwargs.get('attribute_transform_functions', {})
- key = kwargs.get('key', None)
+ attrs = kwargs.get("attributes", [])
+ attribute_transform_functions = kwargs.get("attribute_transform_functions", {})
+ key = kwargs.get("key", None)
if key:
output = jsutil.get_value(entry.result, key)
else:
# drop entry to the dict so that jsutil can operate
entry = vars(entry)
- output = ''
+ output = ""
for attr in attrs:
value = jsutil.get_value(entry, attr)
value = strutil.strip_carriage_returns(strutil.unescape(value))
@@ -65,8 +64,12 @@ def format(cls, entry, *args, **kwargs):
# if the leading character is objectish start and last character is objectish
# end but the string isn't supposed to be an object. Try/Except will catch
# this for now, but this should be improved.
- if (isinstance(value, six.string_types) and len(value) > 0 and
- value[0] in ['{', '['] and value[len(value) - 1] in ['}', ']']):
+ if (
+ isinstance(value, six.string_types)
+ and len(value) > 0
+ and value[0] in ["{", "["]
+ and value[len(value) - 1] in ["}", "]"]
+ ):
try:
new_value = ast.literal_eval(value)
except:
@@ -79,31 +82,40 @@ def format(cls, entry, *args, **kwargs):
# 2. Drop the trailing newline
# 3. Set width to maxint so pyyaml does not split text. Anything longer
# and likely we will see other issues like storage :P.
- formatted_value = yaml.safe_dump({attr: value},
- default_flow_style=False,
- width=PLATFORM_MAXINT,
- indent=2)[len(attr) + 2:-1]
- value = ('\n' if isinstance(value, dict) else '') + formatted_value
+ formatted_value = yaml.safe_dump(
+ {attr: value},
+ default_flow_style=False,
+ width=PLATFORM_MAXINT,
+ indent=2,
+ )[len(attr) + 2 : -1]
+ value = ("\n" if isinstance(value, dict) else "") + formatted_value
value = strutil.dedupe_newlines(value)
# transform the value of our attribute so things like 'status'
# and 'timestamp' are formatted nicely
- transform_function = attribute_transform_functions.get(attr,
- lambda value: value)
+ transform_function = attribute_transform_functions.get(
+ attr, lambda value: value
+ )
value = transform_function(value=value)
- output += ('\n' if output else '') + '%s: %s' % \
- (DisplayColors.colorize(attr, DisplayColors.BLUE), value)
+ output += ("\n" if output else "") + "%s: %s" % (
+ DisplayColors.colorize(attr, DisplayColors.BLUE),
+ value,
+ )
- output_schema = entry.get('action', {}).get('output_schema')
- schema_check = get_config()['general']['silence_schema_output']
- if not output_schema and kwargs.get('with_schema'):
+ output_schema = entry.get("action", {}).get("output_schema")
+ schema_check = get_config()["general"]["silence_schema_output"]
+ if not output_schema and kwargs.get("with_schema"):
rendered_schema = {
- 'output_schema': schema.render_output_schema_from_output(entry['result'])
+ "output_schema": schema.render_output_schema_from_output(
+ entry["result"]
+ )
}
- rendered_schema = yaml.safe_dump(rendered_schema, default_flow_style=False)
- output += '\n'
+ rendered_schema = yaml.safe_dump(
+ rendered_schema, default_flow_style=False
+ )
+ output += "\n"
output += _print_bordered(
"Based on the action output the following inferred schema was built:"
"\n\n"
@@ -120,7 +132,11 @@ def format(cls, entry, *args, **kwargs):
else:
# Assume Python 2
try:
- result = strutil.unescape(str(output)).decode('unicode_escape').encode('utf-8')
+ result = (
+ strutil.unescape(str(output))
+ .decode("unicode_escape")
+ .encode("utf-8")
+ )
except UnicodeDecodeError:
# String contains a value which is not a unicode escape sequence, ignore the error
result = strutil.unescape(str(output))
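The safe_dump call in ExecutionResult.format relies on a slicing trick: dump {attr: value}, then cut off the leading "attr" key and the trailing newline. A runnable demonstration (requires PyYAML, which the module already imports):

import struct
import yaml

PLATFORM_MAXINT = 2 ** (struct.Struct("i").size * 8 - 1) - 1

attr = "result"
value = {"stdout": "hello", "succeeded": True}

formatted_value = yaml.safe_dump(
    {attr: value}, default_flow_style=False, width=PLATFORM_MAXINT, indent=2
)[len(attr) + 2 : -1]  # drop the leading 'result:' key and the trailing newline
# Dicts get a leading newline so nested keys line up under the attribute name.
value = ("\n" if isinstance(value, dict) else "") + formatted_value
print("%s: %s" % (attr, value))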
diff --git a/st2client/st2client/formatters/table.py b/st2client/st2client/formatters/table.py
index 404469ce0e..91cc59e009 100644
--- a/st2client/st2client/formatters/table.py
+++ b/st2client/st2client/formatters/table.py
@@ -40,40 +40,38 @@
MIN_COL_WIDTH = 5
# Default attribute display order to use if one is not provided
-DEFAULT_ATTRIBUTE_DISPLAY_ORDER = ['id', 'name', 'pack', 'description']
+DEFAULT_ATTRIBUTE_DISPLAY_ORDER = ["id", "name", "pack", "description"]
# Attributes which contain bash escape sequences - we can't split those across multiple lines
# since things would break
COLORIZED_ATTRIBUTES = {
- 'status': {
- 'col_width': 24 # Note: len('succeed' + ' (XXXX elapsed)') <= 24
- }
+ "status": {"col_width": 24} # Note: len('succeed' + ' (XXXX elapsed)') <= 24
}
class MultiColumnTable(formatters.Formatter):
-
def __init__(self):
self._table_width = 0
@classmethod
def format(cls, entries, *args, **kwargs):
- attributes = kwargs.get('attributes', [])
- attribute_transform_functions = kwargs.get('attribute_transform_functions', {})
- widths = kwargs.get('widths', [])
+ attributes = kwargs.get("attributes", [])
+ attribute_transform_functions = kwargs.get("attribute_transform_functions", {})
+ widths = kwargs.get("widths", [])
widths = widths or []
if not widths and attributes:
# Dynamically calculate column size based on the terminal size
cols = get_terminal_size_columns()
- if attributes[0] == 'id':
+ if attributes[0] == "id":
# consume iterator and save as entries so collection is accessible later.
entries = [e for e in entries]
# first column contains id, make sure it's not broken up
- first_col_width = cls._get_required_column_width(values=[e.id for e in entries],
- minimum_width=MIN_ID_COL_WIDTH)
- cols = (cols - first_col_width)
+ first_col_width = cls._get_required_column_width(
+ values=[e.id for e in entries], minimum_width=MIN_ID_COL_WIDTH
+ )
+ cols = cols - first_col_width
col_width = int(math.floor((cols / len(attributes))))
else:
col_width = int(math.floor((cols / len(attributes))))
@@ -88,14 +86,16 @@ def format(cls, entries, *args, **kwargs):
continue
if attribute_name in COLORIZED_ATTRIBUTES:
- current_col_width = COLORIZED_ATTRIBUTES[attribute_name]['col_width']
- subtract += (current_col_width - col_width)
+ current_col_width = COLORIZED_ATTRIBUTES[attribute_name][
+ "col_width"
+ ]
+ subtract += current_col_width - col_width
else:
# Make sure we subtract the added width from the last column so we account
# for the fixed width columns and make sure table is not wider than the
# terminal width.
if index == (len(attributes) - 1) and subtract:
- current_col_width = (col_width - subtract)
+ current_col_width = col_width - subtract
if current_col_width <= MIN_COL_WIDTH:
# Make sure column width is always greater than MIN_COL_WIDTH
@@ -105,12 +105,14 @@ def format(cls, entries, *args, **kwargs):
widths.append(current_col_width)
- if not attributes or 'all' in attributes:
+ if not attributes or "all" in attributes:
entries = list(entries) if entries else []
if len(entries) >= 1:
attributes = list(entries[0].__dict__.keys())
- attributes = sorted([attr for attr in attributes if not attr.startswith('_')])
+ attributes = sorted(
+ [attr for attr in attributes if not attr.startswith("_")]
+ )
else:
# There are no entries so we can't infer available attributes
attributes = []
@@ -123,8 +125,7 @@ def format(cls, entries, *args, **kwargs):
# If only 1 width value is provided then
# apply it to all columns else fix at 28.
width = widths[0] if len(widths) == 1 else 28
- columns = zip(attributes,
- [width for i in range(0, len(attributes))])
+ columns = zip(attributes, [width for i in range(0, len(attributes))])
# Format result to table.
table = PrettyTable()
@@ -132,14 +133,14 @@ def format(cls, entries, *args, **kwargs):
table.field_names.append(column[0])
table.max_width[column[0]] = column[1]
table.padding_width = 1
- table.align = 'l'
- table.valign = 't'
+ table.align = "l"
+ table.valign = "t"
for entry in entries:
# TODO: Improve getting values of nested dict.
values = []
for field_name in table.field_names:
- if '.' in field_name:
- field_names = field_name.split('.')
+ if "." in field_name:
+ field_names = field_name.split(".")
value = getattr(entry, field_names.pop(0), {})
for name in field_names:
value = cls._get_field_value(value, name)
@@ -149,8 +150,9 @@ def format(cls, entries, *args, **kwargs):
values.append(value)
else:
value = cls._get_simple_field_value(entry, field_name)
- transform_function = attribute_transform_functions.get(field_name,
- lambda value: value)
+ transform_function = attribute_transform_functions.get(
+ field_name, lambda value: value
+ )
value = transform_function(value=value)
value = strutil.strip_carriage_returns(strutil.unescape(value))
values.append(value)
@@ -177,14 +179,14 @@ def _get_simple_field_value(entry, field_name):
"""
Format a value for a simple field.
"""
- value = getattr(entry, field_name, '')
+ value = getattr(entry, field_name, "")
if isinstance(value, (list, tuple)):
if len(value) == 0:
- value = ''
+ value = ""
elif isinstance(value[0], (str, six.text_type)):
# List contains simple string values, format it as comma
# separated string
- value = ', '.join(value)
+ value = ", ".join(value)
return value
@@ -192,10 +194,10 @@ def _get_simple_field_value(entry, field_name):
def _get_field_value(value, field_name):
r_val = value.get(field_name, None)
if r_val is None:
- return ''
+ return ""
if isinstance(r_val, list) or isinstance(r_val, dict):
- return r_val if len(r_val) > 0 else ''
+ return r_val if len(r_val) > 0 else ""
return r_val
@staticmethod
@@ -203,7 +205,7 @@ def _get_friendly_column_name(name):
if not name:
return None
- friendly_name = name.replace('_', ' ').replace('.', ' ').capitalize()
+ friendly_name = name.replace("_", " ").replace(".", " ").capitalize()
return friendly_name
@staticmethod
@@ -213,33 +215,34 @@ def _get_required_column_width(values, minimum_width=0):
class PropertyValueTable(formatters.Formatter):
-
@classmethod
def format(cls, subject, *args, **kwargs):
- attributes = kwargs.get('attributes', None)
- attribute_display_order = kwargs.get('attribute_display_order',
- DEFAULT_ATTRIBUTE_DISPLAY_ORDER)
- attribute_transform_functions = kwargs.get('attribute_transform_functions', {})
+ attributes = kwargs.get("attributes", None)
+ attribute_display_order = kwargs.get(
+ "attribute_display_order", DEFAULT_ATTRIBUTE_DISPLAY_ORDER
+ )
+ attribute_transform_functions = kwargs.get("attribute_transform_functions", {})
- if not attributes or 'all' in attributes:
- attributes = sorted([attr for attr in subject.__dict__
- if not attr.startswith('_')])
+ if not attributes or "all" in attributes:
+ attributes = sorted(
+ [attr for attr in subject.__dict__ if not attr.startswith("_")]
+ )
for attr in attribute_display_order[::-1]:
if attr in attributes:
attributes.remove(attr)
attributes = [attr] + attributes
table = PrettyTable()
- table.field_names = ['Property', 'Value']
- table.max_width['Property'] = 20
- table.max_width['Value'] = 60
+ table.field_names = ["Property", "Value"]
+ table.max_width["Property"] = 20
+ table.max_width["Value"] = 60
table.padding_width = 1
- table.align = 'l'
- table.valign = 't'
+ table.align = "l"
+ table.valign = "t"
for attribute in attributes:
- if '.' in attribute:
- field_names = attribute.split('.')
+ if "." in attribute:
+ field_names = attribute.split(".")
value = cls._get_attribute_value(subject, field_names.pop(0))
for name in field_names:
value = cls._get_attribute_value(value, name)
@@ -248,8 +251,9 @@ def format(cls, subject, *args, **kwargs):
else:
value = cls._get_attribute_value(subject, attribute)
- transform_function = attribute_transform_functions.get(attribute,
- lambda value: value)
+ transform_function = attribute_transform_functions.get(
+ attribute, lambda value: value
+ )
value = transform_function(value=value)
if type(value) is dict or type(value) is list:
@@ -266,9 +270,9 @@ def _get_attribute_value(subject, attribute):
else:
r_val = getattr(subject, attribute, None)
if r_val is None:
- return ''
+ return ""
if isinstance(r_val, list) or isinstance(r_val, dict):
- return r_val if len(r_val) > 0 else ''
+ return r_val if len(r_val) > 0 else ""
return r_val
@@ -284,19 +288,25 @@ def note_box(entity, limit):
else:
entity = entity[:-1]
- message = "Note: Only one %s is displayed. Use -n/--last flag for more results." \
+ message = (
+ "Note: Only one %s is displayed. Use -n/--last flag for more results."
% entity
+ )
else:
- message = "Note: Only first %s %s are displayed. Use -n/--last flag for more results."\
+ message = (
+ "Note: Only first %s %s are displayed. Use -n/--last flag for more results."
% (limit, entity)
+ )
# adding default padding
message_length = len(message) + 3
m = MultiColumnTable()
if m.table_width > message_length:
- note = PrettyTable([""], right_padding_width=(m.table_width - message_length))
+ note = PrettyTable(
+ [""], right_padding_width=(m.table_width - message_length)
+ )
else:
note = PrettyTable([""])
note.header = False
note.add_row([message])
- sys.stderr.write((str(note) + '\n'))
+ sys.stderr.write((str(note) + "\n"))
return
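The column sizing in MultiColumnTable.format subtracts a fixed-width id column from the terminal width and splits the remainder evenly across the other attributes; a standalone sketch with a hardcoded terminal width (MIN_ID_COL_WIDTH is defined outside this hunk, so its value of 26 here is an assumption):

import math

MIN_ID_COL_WIDTH = 26  # assumed; the real constant lives elsewhere in table.py

cols = 120  # pretend get_terminal_size_columns() returned this
attributes = ["id", "name", "pack", "description"]
entry_ids = ["5f1e2a9b8c0d4e6f7a8b9c0d"]  # 24-char Mongo-style ids

# The id column must never be broken up: size it to the longest id value
# (or the minimum), then divide what is left evenly among all attributes.
first_col_width = max(max(len(i) for i in entry_ids), MIN_ID_COL_WIDTH)
cols = cols - first_col_width
col_width = int(math.floor(cols / len(attributes)))
print(first_col_width, col_width)  # 26 23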
diff --git a/st2client/st2client/models/__init__.py b/st2client/st2client/models/__init__.py
index 2862f59d28..8e27a77050 100644
--- a/st2client/st2client/models/__init__.py
+++ b/st2client/st2client/models/__init__.py
@@ -15,19 +15,19 @@
from __future__ import absolute_import
-from st2client.models.core import * # noqa
-from st2client.models.auth import * # noqa
-from st2client.models.action import * # noqa
+from st2client.models.core import * # noqa
+from st2client.models.auth import * # noqa
+from st2client.models.action import * # noqa
from st2client.models.action_alias import * # noqa
from st2client.models.aliasexecution import * # noqa
from st2client.models.config import * # noqa
from st2client.models.inquiry import * # noqa
-from st2client.models.keyvalue import * # noqa
-from st2client.models.pack import * # noqa
-from st2client.models.policy import * # noqa
-from st2client.models.reactor import * # noqa
-from st2client.models.trace import * # noqa
-from st2client.models.webhook import * # noqa
-from st2client.models.timer import * # noqa
-from st2client.models.service_registry import * # noqa
-from st2client.models.rbac import * # noqa
+from st2client.models.keyvalue import * # noqa
+from st2client.models.pack import * # noqa
+from st2client.models.policy import * # noqa
+from st2client.models.reactor import * # noqa
+from st2client.models.trace import * # noqa
+from st2client.models.webhook import * # noqa
+from st2client.models.timer import * # noqa
+from st2client.models.service_registry import * # noqa
+from st2client.models.rbac import * # noqa
diff --git a/st2client/st2client/models/action.py b/st2client/st2client/models/action.py
index 10692d3dc4..d31b694f80 100644
--- a/st2client/st2client/models/action.py
+++ b/st2client/st2client/models/action.py
@@ -24,27 +24,33 @@
class RunnerType(core.Resource):
- _alias = 'Runner'
- _display_name = 'Runner'
- _plural = 'RunnerTypes'
- _plural_display_name = 'Runners'
- _repr_attributes = ['name', 'enabled', 'description']
+ _alias = "Runner"
+ _display_name = "Runner"
+ _plural = "RunnerTypes"
+ _plural_display_name = "Runners"
+ _repr_attributes = ["name", "enabled", "description"]
class Action(core.Resource):
- _plural = 'Actions'
- _repr_attributes = ['name', 'pack', 'enabled', 'runner_type']
- _url_path = 'actions'
+ _plural = "Actions"
+ _repr_attributes = ["name", "pack", "enabled", "runner_type"]
+ _url_path = "actions"
class Execution(core.Resource):
- _alias = 'Execution'
- _display_name = 'Action Execution'
- _url_path = 'executions'
- _plural = 'ActionExecutions'
- _plural_display_name = 'Action executions'
- _repr_attributes = ['status', 'action', 'start_timestamp', 'end_timestamp', 'parameters',
- 'delay']
+ _alias = "Execution"
+ _display_name = "Action Execution"
+ _url_path = "executions"
+ _plural = "ActionExecutions"
+ _plural_display_name = "Action executions"
+ _repr_attributes = [
+ "status",
+ "action",
+ "start_timestamp",
+ "end_timestamp",
+ "parameters",
+ "delay",
+ ]
# NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for
diff --git a/st2client/st2client/models/action_alias.py b/st2client/st2client/models/action_alias.py
index 42162eae3b..1c1a696cff 100644
--- a/st2client/st2client/models/action_alias.py
+++ b/st2client/st2client/models/action_alias.py
@@ -17,25 +17,22 @@
from st2client.models import core
-__all__ = [
- 'ActionAlias',
- 'ActionAliasMatch'
-]
+__all__ = ["ActionAlias", "ActionAliasMatch"]
class ActionAlias(core.Resource):
- _alias = 'Action-Alias'
- _display_name = 'Action Alias'
- _plural = 'ActionAliases'
- _plural_display_name = 'Action Aliases'
- _url_path = 'actionalias'
- _repr_attributes = ['name', 'pack', 'action_ref']
+ _alias = "Action-Alias"
+ _display_name = "Action Alias"
+ _plural = "ActionAliases"
+ _plural_display_name = "Action Aliases"
+ _url_path = "actionalias"
+ _repr_attributes = ["name", "pack", "action_ref"]
class ActionAliasMatch(core.Resource):
- _alias = 'Action-Alias-Match'
- _display_name = 'ActionAlias Match'
- _plural = 'ActionAliasMatches'
- _plural_display_name = 'Action Alias Matches'
- _url_path = 'actionalias'
- _repr_attributes = ['command']
+ _alias = "Action-Alias-Match"
+ _display_name = "ActionAlias Match"
+ _plural = "ActionAliasMatches"
+ _plural_display_name = "Action Alias Matches"
+ _url_path = "actionalias"
+ _repr_attributes = ["command"]
diff --git a/st2client/st2client/models/aliasexecution.py b/st2client/st2client/models/aliasexecution.py
index 12cfc67cf5..a2d7e62a57 100644
--- a/st2client/st2client/models/aliasexecution.py
+++ b/st2client/st2client/models/aliasexecution.py
@@ -17,16 +17,21 @@
from st2client.models import core
-__all__ = [
- 'ActionAliasExecution'
-]
+__all__ = ["ActionAliasExecution"]
class ActionAliasExecution(core.Resource):
- _alias = 'Action-Alias-Execution'
- _display_name = 'ActionAlias Execution'
- _plural = 'ActionAliasExecutions'
- _plural_display_name = 'Runners'
- _url_path = 'aliasexecution'
- _repr_attributes = ['name', 'format', 'command', 'user', 'source_channel',
- 'notification_channel', 'notification_route']
+ _alias = "Action-Alias-Execution"
+ _display_name = "ActionAlias Execution"
+ _plural = "ActionAliasExecutions"
+ _plural_display_name = "Runners"
+ _url_path = "aliasexecution"
+ _repr_attributes = [
+ "name",
+ "format",
+ "command",
+ "user",
+ "source_channel",
+ "notification_channel",
+ "notification_route",
+ ]
diff --git a/st2client/st2client/models/auth.py b/st2client/st2client/models/auth.py
index 9fa626a19a..7c909ea172 100644
--- a/st2client/st2client/models/auth.py
+++ b/st2client/st2client/models/auth.py
@@ -24,14 +24,14 @@
class Token(core.Resource):
- _display_name = 'Access Token'
- _plural = 'Tokens'
- _plural_display_name = 'Access Tokens'
- _repr_attributes = ['user', 'expiry', 'metadata']
+ _display_name = "Access Token"
+ _plural = "Tokens"
+ _plural_display_name = "Access Tokens"
+ _repr_attributes = ["user", "expiry", "metadata"]
class ApiKey(core.Resource):
- _display_name = 'API Key'
- _plural = 'ApiKeys'
- _plural_display_name = 'API Keys'
- _repr_attributes = ['id', 'user', 'metadata']
+ _display_name = "API Key"
+ _plural = "ApiKeys"
+ _plural_display_name = "API Keys"
+ _repr_attributes = ["id", "user", "metadata"]
diff --git a/st2client/st2client/models/config.py b/st2client/st2client/models/config.py
index 247b4fcaf9..f9054751ed 100644
--- a/st2client/st2client/models/config.py
+++ b/st2client/st2client/models/config.py
@@ -19,14 +19,14 @@
class Config(core.Resource):
- _display_name = 'Config'
- _plural = 'Configs'
- _plural_display_name = 'Configs'
+ _display_name = "Config"
+ _plural = "Configs"
+ _plural_display_name = "Configs"
class ConfigSchema(core.Resource):
- _display_name = 'Config Schema'
- _plural = 'ConfigSchema'
- _plural_display_name = 'Config Schemas'
- _url_path = 'config_schemas'
- _repr_attributes = ['id', 'pack', 'attributes']
+ _display_name = "Config Schema"
+ _plural = "ConfigSchema"
+ _plural_display_name = "Config Schemas"
+ _url_path = "config_schemas"
+ _repr_attributes = ["id", "pack", "attributes"]
diff --git a/st2client/st2client/models/core.py b/st2client/st2client/models/core.py
index 255c91534f..d2a9b694f1 100644
--- a/st2client/st2client/models/core.py
+++ b/st2client/st2client/models/core.py
@@ -34,12 +34,13 @@
def add_auth_token_to_kwargs_from_env(func):
@wraps(func)
def decorate(*args, **kwargs):
- if not kwargs.get('token') and os.environ.get('ST2_AUTH_TOKEN', None):
- kwargs['token'] = os.environ.get('ST2_AUTH_TOKEN')
- if not kwargs.get('api_key') and os.environ.get('ST2_API_KEY', None):
- kwargs['api_key'] = os.environ.get('ST2_API_KEY')
+ if not kwargs.get("token") and os.environ.get("ST2_AUTH_TOKEN", None):
+ kwargs["token"] = os.environ.get("ST2_AUTH_TOKEN")
+ if not kwargs.get("api_key") and os.environ.get("ST2_API_KEY", None):
+ kwargs["api_key"] = os.environ.get("ST2_API_KEY")
return func(*args, **kwargs)
+
return decorate
@@ -81,8 +82,11 @@ def to_dict(self, exclude_attributes=None):
exclude_attributes = exclude_attributes or []
attributes = list(self.__dict__.keys())
- attributes = [attr for attr in attributes if not attr.startswith('__') and
- attr not in exclude_attributes]
+ attributes = [
+ attr
+ for attr in attributes
+ if not attr.startswith("__") and attr not in exclude_attributes
+ ]
result = {}
for attribute in attributes:
@@ -102,15 +106,15 @@ def get_display_name(cls):
@classmethod
def get_plural_name(cls):
if not cls._plural:
- raise Exception('The %s class is missing class attributes '
- 'in its definition.' % cls.__name__)
+ raise Exception(
+ "The %s class is missing class attributes "
+ "in its definition." % cls.__name__
+ )
return cls._plural
@classmethod
def get_plural_display_name(cls):
- return (cls._plural_display_name
- if cls._plural_display_name
- else cls._plural)
+ return cls._plural_display_name if cls._plural_display_name else cls._plural
@classmethod
def get_url_path_name(cls):
@@ -120,9 +124,9 @@ def get_url_path_name(cls):
return cls.get_plural_name().lower()
def serialize(self):
- return dict((k, v)
- for k, v in six.iteritems(self.__dict__)
- if not k.startswith('_'))
+ return dict(
+ (k, v) for k, v in six.iteritems(self.__dict__) if not k.startswith("_")
+ )
@classmethod
def deserialize(cls, doc):
@@ -140,16 +144,15 @@ def __repr__(self):
attributes = []
for attribute in self._repr_attributes:
value = getattr(self, attribute, None)
- attributes.append('%s=%s' % (attribute, value))
+ attributes.append("%s=%s" % (attribute, value))
- attributes = ','.join(attributes)
+ attributes = ",".join(attributes)
class_name = self.__class__.__name__
- result = '<%s %s>' % (class_name, attributes)
+ result = "<%s %s>" % (class_name, attributes)
return result
class ResourceManager(object):
-
def __init__(self, resource, endpoint, cacert=None, debug=False):
self.resource = resource
self.debug = debug
@@ -159,46 +162,47 @@ def __init__(self, resource, endpoint, cacert=None, debug=False):
def handle_error(response):
try:
content = response.json()
- fault = content.get('faultstring', '') if content else ''
+ fault = content.get("faultstring", "") if content else ""
if fault:
- response.reason += '\nMESSAGE: %s' % fault
+ response.reason += "\nMESSAGE: %s" % fault
except Exception as e:
- response.reason += ('\nUnable to retrieve detailed message '
- 'from the HTTP response. %s\n' % six.text_type(e))
+ response.reason += (
+ "\nUnable to retrieve detailed message "
+ "from the HTTP response. %s\n" % six.text_type(e)
+ )
response.raise_for_status()
@add_auth_token_to_kwargs_from_env
def get_all(self, **kwargs):
# TODO: This is ugly, stop abusing kwargs
- url = '/%s' % self.resource.get_url_path_name()
- limit = kwargs.pop('limit', None)
- pack = kwargs.pop('pack', None)
- prefix = kwargs.pop('prefix', None)
- user = kwargs.pop('user', None)
+ url = "/%s" % self.resource.get_url_path_name()
+ limit = kwargs.pop("limit", None)
+ pack = kwargs.pop("pack", None)
+ prefix = kwargs.pop("prefix", None)
+ user = kwargs.pop("user", None)
- params = kwargs.pop('params', {})
+ params = kwargs.pop("params", {})
if limit:
- params['limit'] = limit
+ params["limit"] = limit
if pack:
- params['pack'] = pack
+ params["pack"] = pack
if prefix:
- params['prefix'] = prefix
+ params["prefix"] = prefix
if user:
- params['user'] = user
+ params["user"] = user
response = self.client.get(url=url, params=params, **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
- return [self.resource.deserialize(item)
- for item in response.json()]
+ return [self.resource.deserialize(item) for item in response.json()]
@add_auth_token_to_kwargs_from_env
def get_by_id(self, id, **kwargs):
- url = '/%s/%s' % (self.resource.get_url_path_name(), id)
+ url = "/%s/%s" % (self.resource.get_url_path_name(), id)
response = self.client.get(url, **kwargs)
if response.status_code == http_client.NOT_FOUND:
return None
@@ -214,14 +218,18 @@ def get_property(self, id_, property_name, self_deserialize=True, **kwargs):
property_name: Name of the property
self_deserialize: #Implies use the deserialize method implemented by this resource.
"""
- token = kwargs.pop('token', None)
- api_key = kwargs.pop('api_key', None)
+ token = kwargs.pop("token", None)
+ api_key = kwargs.pop("api_key", None)
if kwargs:
- url = '/%s/%s/%s/?%s' % (self.resource.get_url_path_name(), id_, property_name,
- urllib.parse.urlencode(kwargs))
+ url = "/%s/%s/%s/?%s" % (
+ self.resource.get_url_path_name(),
+ id_,
+ property_name,
+ urllib.parse.urlencode(kwargs),
+ )
else:
- url = '/%s/%s/%s/' % (self.resource.get_url_path_name(), id_, property_name)
+ url = "/%s/%s/%s/" % (self.resource.get_url_path_name(), id_, property_name)
if token:
response = self.client.get(url, token=token)
@@ -246,19 +254,21 @@ def get_by_ref_or_id(self, ref_or_id, **kwargs):
def _query_details(self, **kwargs):
if not kwargs:
- raise Exception('Query parameter is not provided.')
+ raise Exception("Query parameter is not provided.")
- token = kwargs.get('token', None)
- api_key = kwargs.get('api_key', None)
- params = kwargs.get('params', {})
+ token = kwargs.get("token", None)
+ api_key = kwargs.get("api_key", None)
+ params = kwargs.get("params", {})
for k, v in six.iteritems(kwargs):
# Note: That's a special case to support api_key and token kwargs
- if k not in ['token', 'api_key', 'params']:
+ if k not in ["token", "api_key", "params"]:
params[k] = v
- url = '/%s/?%s' % (self.resource.get_url_path_name(),
- urllib.parse.urlencode(params))
+ url = "/%s/?%s" % (
+ self.resource.get_url_path_name(),
+ urllib.parse.urlencode(params),
+ )
if token:
response = self.client.get(url, token=token)
@@ -284,8 +294,8 @@ def query(self, **kwargs):
@add_auth_token_to_kwargs_from_env
def query_with_count(self, **kwargs):
instances, response = self._query_details(**kwargs)
- if response and 'X-Total-Count' in response.headers:
- return (instances, int(response.headers['X-Total-Count']))
+ if response and "X-Total-Count" in response.headers:
+ return (instances, int(response.headers["X-Total-Count"]))
else:
return (instances, None)
@@ -296,13 +306,15 @@ def get_by_name(self, name, **kwargs):
return None
else:
if len(instances) > 1:
- raise Exception('More than one %s named "%s" are found.' %
- (self.resource.__name__.lower(), name))
+ raise Exception(
+ 'More than one %s named "%s" is found.'
+ % (self.resource.__name__.lower(), name)
+ )
return instances[0]
@add_auth_token_to_kwargs_from_env
def create(self, instance, **kwargs):
- url = '/%s' % self.resource.get_url_path_name()
+ url = "/%s" % self.resource.get_url_path_name()
response = self.client.post(url, instance.serialize(), **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
@@ -311,7 +323,7 @@ def create(self, instance, **kwargs):
@add_auth_token_to_kwargs_from_env
def update(self, instance, **kwargs):
- url = '/%s/%s' % (self.resource.get_url_path_name(), instance.id)
+ url = "/%s/%s" % (self.resource.get_url_path_name(), instance.id)
response = self.client.put(url, instance.serialize(), **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
@@ -320,12 +332,14 @@ def update(self, instance, **kwargs):
@add_auth_token_to_kwargs_from_env
def delete(self, instance, **kwargs):
- url = '/%s/%s' % (self.resource.get_url_path_name(), instance.id)
+ url = "/%s/%s" % (self.resource.get_url_path_name(), instance.id)
response = self.client.delete(url, **kwargs)
- if response.status_code not in [http_client.OK,
- http_client.NO_CONTENT,
- http_client.NOT_FOUND]:
+ if response.status_code not in [
+ http_client.OK,
+ http_client.NO_CONTENT,
+ http_client.NOT_FOUND,
+ ]:
self.handle_error(response)
return False
@@ -333,11 +347,13 @@ def delete(self, instance, **kwargs):
@add_auth_token_to_kwargs_from_env
def delete_by_id(self, instance_id, **kwargs):
- url = '/%s/%s' % (self.resource.get_url_path_name(), instance_id)
+ url = "/%s/%s" % (self.resource.get_url_path_name(), instance_id)
response = self.client.delete(url, **kwargs)
- if response.status_code not in [http_client.OK,
- http_client.NO_CONTENT,
- http_client.NOT_FOUND]:
+ if response.status_code not in [
+ http_client.OK,
+ http_client.NO_CONTENT,
+ http_client.NOT_FOUND,
+ ]:
self.handle_error(response)
return False
try:
@@ -357,18 +373,21 @@ def __init__(self, resource, endpoint, cacert=None, debug=False):
@add_auth_token_to_kwargs_from_env
def match(self, instance, **kwargs):
- url = '/%s/match' % self.resource.get_url_path_name()
+ url = "/%s/match" % self.resource.get_url_path_name()
response = self.client.post(url, instance.serialize(), **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
match = response.json()
- return (self.resource.deserialize(match['actionalias']), match['representation'])
+ return (
+ self.resource.deserialize(match["actionalias"]),
+ match["representation"],
+ )
class ActionAliasExecutionManager(ResourceManager):
@add_auth_token_to_kwargs_from_env
def match_and_execute(self, instance, **kwargs):
- url = '/%s/match_and_execute' % self.resource.get_url_path_name()
+ url = "/%s/match_and_execute" % self.resource.get_url_path_name()
response = self.client.post(url, instance.serialize(), **kwargs)
if response.status_code != http_client.OK:
@@ -380,7 +399,10 @@ def match_and_execute(self, instance, **kwargs):
class ActionResourceManager(ResourceManager):
@add_auth_token_to_kwargs_from_env
def get_entrypoint(self, ref_or_id, **kwargs):
- url = '/%s/views/entry_point/%s' % (self.resource.get_url_path_name(), ref_or_id)
+ url = "/%s/views/entry_point/%s" % (
+ self.resource.get_url_path_name(),
+ ref_or_id,
+ )
response = self.client.get(url, **kwargs)
if response.status_code != http_client.OK:
@@ -391,20 +413,30 @@ def get_entrypoint(self, ref_or_id, **kwargs):
class ExecutionResourceManager(ResourceManager):
@add_auth_token_to_kwargs_from_env
- def re_run(self, execution_id, parameters=None, tasks=None, no_reset=None, delay=0, **kwargs):
- url = '/%s/%s/re_run' % (self.resource.get_url_path_name(), execution_id)
+ def re_run(
+ self,
+ execution_id,
+ parameters=None,
+ tasks=None,
+ no_reset=None,
+ delay=0,
+ **kwargs,
+ ):
+ url = "/%s/%s/re_run" % (self.resource.get_url_path_name(), execution_id)
tasks = tasks or []
no_reset = no_reset or []
if list(set(no_reset) - set(tasks)):
- raise ValueError('List of tasks to reset does not match the tasks to rerun.')
+ raise ValueError(
+ "List of tasks to reset does not match the tasks to rerun."
+ )
data = {
- 'parameters': parameters or {},
- 'tasks': tasks,
- 'reset': list(set(tasks) - set(no_reset)),
- 'delay': delay
+ "parameters": parameters or {},
+ "tasks": tasks,
+ "reset": list(set(tasks) - set(no_reset)),
+ "delay": delay,
}
response = self.client.post(url, data, **kwargs)
@@ -416,10 +448,10 @@ def re_run(self, execution_id, parameters=None, tasks=None, no_reset=None, delay
@add_auth_token_to_kwargs_from_env
def get_output(self, execution_id, output_type=None, **kwargs):
- url = '/%s/%s/output' % (self.resource.get_url_path_name(), execution_id)
+ url = "/%s/%s/output" % (self.resource.get_url_path_name(), execution_id)
if output_type:
- url += '?' + urllib.parse.urlencode({'output_type': output_type})
+ url += "?" + urllib.parse.urlencode({"output_type": output_type})
response = self.client.get(url, **kwargs)
if response.status_code != http_client.OK:
@@ -429,8 +461,8 @@ def get_output(self, execution_id, output_type=None, **kwargs):
@add_auth_token_to_kwargs_from_env
def pause(self, execution_id, **kwargs):
- url = '/%s/%s' % (self.resource.get_url_path_name(), execution_id)
- data = {'status': 'pausing'}
+ url = "/%s/%s" % (self.resource.get_url_path_name(), execution_id)
+ data = {"status": "pausing"}
response = self.client.put(url, data, **kwargs)
@@ -441,8 +473,8 @@ def pause(self, execution_id, **kwargs):
@add_auth_token_to_kwargs_from_env
def resume(self, execution_id, **kwargs):
- url = '/%s/%s' % (self.resource.get_url_path_name(), execution_id)
- data = {'status': 'resuming'}
+ url = "/%s/%s" % (self.resource.get_url_path_name(), execution_id)
+ data = {"status": "resuming"}
response = self.client.put(url, data, **kwargs)
@@ -453,14 +485,14 @@ def resume(self, execution_id, **kwargs):
@add_auth_token_to_kwargs_from_env
def get_children(self, execution_id, **kwargs):
- url = '/%s/%s/children' % (self.resource.get_url_path_name(), execution_id)
+ url = "/%s/%s/children" % (self.resource.get_url_path_name(), execution_id)
- depth = kwargs.pop('depth', -1)
+ depth = kwargs.pop("depth", -1)
- params = kwargs.pop('params', {})
+ params = kwargs.pop("params", {})
if depth:
- params['depth'] = depth
+ params["depth"] = depth
response = self.client.get(url=url, params=params, **kwargs)
if response.status_code != http_client.OK:
@@ -469,19 +501,15 @@ def get_children(self, execution_id, **kwargs):
class InquiryResourceManager(ResourceManager):
-
@add_auth_token_to_kwargs_from_env
def respond(self, inquiry_id, inquiry_response, **kwargs):
"""
Update st2.inquiry.respond action
Update st2client respond command to use this?
"""
- url = '/%s/%s' % (self.resource.get_url_path_name(), inquiry_id)
+ url = "/%s/%s" % (self.resource.get_url_path_name(), inquiry_id)
- payload = {
- "id": inquiry_id,
- "response": inquiry_response
- }
+ payload = {"id": inquiry_id, "response": inquiry_response}
response = self.client.put(url, payload, **kwargs)
@@ -494,7 +522,10 @@ def respond(self, inquiry_id, inquiry_response, **kwargs):
class TriggerInstanceResourceManager(ResourceManager):
@add_auth_token_to_kwargs_from_env
def re_emit(self, trigger_instance_id, **kwargs):
- url = '/%s/%s/re_emit' % (self.resource.get_url_path_name(), trigger_instance_id)
+ url = "/%s/%s/re_emit" % (
+ self.resource.get_url_path_name(),
+ trigger_instance_id,
+ )
response = self.client.post(url, None, **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
@@ -508,11 +539,11 @@ class AsyncRequest(Resource):
class PackResourceManager(ResourceManager):
@add_auth_token_to_kwargs_from_env
def install(self, packs, force=False, skip_dependencies=False, **kwargs):
- url = '/%s/install' % (self.resource.get_url_path_name())
+ url = "/%s/install" % (self.resource.get_url_path_name())
payload = {
- 'packs': packs,
- 'force': force,
- 'skip_dependencies': skip_dependencies
+ "packs": packs,
+ "force": force,
+ "skip_dependencies": skip_dependencies,
}
response = self.client.post(url, payload, **kwargs)
if response.status_code != http_client.OK:
@@ -522,8 +553,8 @@ def install(self, packs, force=False, skip_dependencies=False, **kwargs):
@add_auth_token_to_kwargs_from_env
def remove(self, packs, **kwargs):
- url = '/%s/uninstall' % (self.resource.get_url_path_name())
- response = self.client.post(url, {'packs': packs}, **kwargs)
+ url = "/%s/uninstall" % (self.resource.get_url_path_name())
+ response = self.client.post(url, {"packs": packs}, **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
instance = AsyncRequest.deserialize(response.json())
@@ -531,11 +562,11 @@ def remove(self, packs, **kwargs):
@add_auth_token_to_kwargs_from_env
def search(self, args, ignore_errors=False, **kwargs):
- url = '/%s/index/search' % (self.resource.get_url_path_name())
- if 'query' in vars(args):
- payload = {'query': args.query}
+ url = "/%s/index/search" % (self.resource.get_url_path_name())
+ if "query" in vars(args):
+ payload = {"query": args.query}
else:
- payload = {'pack': args.pack}
+ payload = {"pack": args.pack}
response = self.client.post(url, payload, **kwargs)
@@ -552,12 +583,12 @@ def search(self, args, ignore_errors=False, **kwargs):
@add_auth_token_to_kwargs_from_env
def register(self, packs=None, types=None, **kwargs):
- url = '/%s/register' % (self.resource.get_url_path_name())
+ url = "/%s/register" % (self.resource.get_url_path_name())
payload = {}
if types:
- payload['types'] = types
+ payload["types"] = types
if packs:
- payload['packs'] = packs
+ payload["packs"] = packs
response = self.client.post(url, payload, **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
@@ -568,7 +599,7 @@ def register(self, packs=None, types=None, **kwargs):
class ConfigManager(ResourceManager):
@add_auth_token_to_kwargs_from_env
def update(self, instance, **kwargs):
- url = '/%s/%s' % (self.resource.get_url_path_name(), instance.pack)
+ url = "/%s/%s" % (self.resource.get_url_path_name(), instance.pack)
response = self.client.put(url, instance.values, **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
@@ -584,16 +615,13 @@ def __init__(self, resource, endpoint, cacert=None, debug=False):
@add_auth_token_to_kwargs_from_env
def post_generic_webhook(self, trigger, payload=None, trace_tag=None, **kwargs):
- url = '/webhooks/st2'
+ url = "/webhooks/st2"
headers = {}
- data = {
- 'trigger': trigger,
- 'payload': payload or {}
- }
+ data = {"trigger": trigger, "payload": payload or {}}
if trace_tag:
- headers['St2-Trace-Tag'] = trace_tag
+ headers["St2-Trace-Tag"] = trace_tag
response = self.client.post(url, data=data, headers=headers, **kwargs)
@@ -604,17 +632,20 @@ def post_generic_webhook(self, trigger, payload=None, trace_tag=None, **kwargs):
@add_auth_token_to_kwargs_from_env
def match(self, instance, **kwargs):
- url = '/%s/match' % self.resource.get_url_path_name()
+ url = "/%s/match" % self.resource.get_url_path_name()
response = self.client.post(url, instance.serialize(), **kwargs)
if response.status_code != http_client.OK:
self.handle_error(response)
match = response.json()
- return (self.resource.deserialize(match['actionalias']), match['representation'])
+ return (
+ self.resource.deserialize(match["actionalias"]),
+ match["representation"],
+ )
class StreamManager(object):
def __init__(self, endpoint, cacert=None, debug=False):
- self._url = httpclient.get_url_without_trailing_slash(endpoint) + '/stream'
+ self._url = httpclient.get_url_without_trailing_slash(endpoint) + "/stream"
self.debug = debug
self.cacert = cacert
@@ -631,25 +662,25 @@ def listen(self, events=None, **kwargs):
if events and isinstance(events, six.string_types):
events = [events]
- if 'token' in kwargs:
- query_params['x-auth-token'] = kwargs.get('token')
+ if "token" in kwargs:
+ query_params["x-auth-token"] = kwargs.get("token")
- if 'api_key' in kwargs:
- query_params['st2-api-key'] = kwargs.get('api_key')
+ if "api_key" in kwargs:
+ query_params["st2-api-key"] = kwargs.get("api_key")
- if 'end_event' in kwargs:
- query_params['end_event'] = kwargs.get('end_event')
+ if "end_event" in kwargs:
+ query_params["end_event"] = kwargs.get("end_event")
- if 'end_execution_id' in kwargs:
- query_params['end_execution_id'] = kwargs.get('end_execution_id')
+ if "end_execution_id" in kwargs:
+ query_params["end_execution_id"] = kwargs.get("end_execution_id")
if events:
- query_params['events'] = ','.join(events)
+ query_params["events"] = ",".join(events)
if self.cacert is not None:
- request_params['verify'] = self.cacert
+ request_params["verify"] = self.cacert
- query_string = '?' + urllib.parse.urlencode(query_params)
+ query_string = "?" + urllib.parse.urlencode(query_params)
url = url + query_string
response = requests.get(url, stream=True, **request_params)
@@ -667,36 +698,38 @@ class WorkflowManager(object):
def __init__(self, endpoint, cacert, debug):
self.debug = debug
self.cacert = cacert
- self.endpoint = endpoint + '/workflows'
- self.client = httpclient.HTTPClient(root=self.endpoint, cacert=cacert, debug=debug)
+ self.endpoint = endpoint + "/workflows"
+ self.client = httpclient.HTTPClient(
+ root=self.endpoint, cacert=cacert, debug=debug
+ )
@staticmethod
def handle_error(response):
try:
content = response.json()
- fault = content.get('faultstring', '') if content else ''
+ fault = content.get("faultstring", "") if content else ""
if fault:
- response.reason += '\nMESSAGE: %s' % fault
+ response.reason += "\nMESSAGE: %s" % fault
except Exception as e:
response.reason += (
- '\nUnable to retrieve detailed message '
- 'from the HTTP response. %s\n' % six.text_type(e)
+ "\nUnable to retrieve detailed message "
+ "from the HTTP response. %s\n" % six.text_type(e)
)
response.raise_for_status()
@add_auth_token_to_kwargs_from_env
def inspect(self, definition, **kwargs):
- url = '/inspect'
+ url = "/inspect"
if not isinstance(definition, six.string_types):
- raise TypeError('Workflow definition is not type of string.')
+ raise TypeError("Workflow definition is not type of string.")
- if 'headers' not in kwargs:
- kwargs['headers'] = {}
+ if "headers" not in kwargs:
+ kwargs["headers"] = {}
- kwargs['headers']['content-type'] = 'text/plain'
+ kwargs["headers"]["content-type"] = "text/plain"
response = self.client.post_raw(url, definition, **kwargs)
@@ -709,7 +742,7 @@ def inspect(self, definition, **kwargs):
class ServiceRegistryGroupsManager(ResourceManager):
@add_auth_token_to_kwargs_from_env
def list(self, **kwargs):
- url = '/service_registry/groups'
+ url = "/service_registry/groups"
headers = {}
response = self.client.get(url, headers=headers, **kwargs)
@@ -717,21 +750,20 @@ def list(self, **kwargs):
if response.status_code != http_client.OK:
self.handle_error(response)
- groups = response.json()['groups']
+ groups = response.json()["groups"]
result = []
for group in groups:
- item = self.resource.deserialize({'group_id': group})
+ item = self.resource.deserialize({"group_id": group})
result.append(item)
return result
class ServiceRegistryMembersManager(ResourceManager):
-
@add_auth_token_to_kwargs_from_env
def list(self, group_id, **kwargs):
- url = '/service_registry/groups/%s/members' % (group_id)
+ url = "/service_registry/groups/%s/members" % (group_id)
headers = {}
response = self.client.get(url, headers=headers, **kwargs)
@@ -739,14 +771,14 @@ def list(self, group_id, **kwargs):
if response.status_code != http_client.OK:
self.handle_error(response)
- members = response.json()['members']
+ members = response.json()["members"]
result = []
for member in members:
data = {
- 'group_id': group_id,
- 'member_id': member['member_id'],
- 'capabilities': member['capabilities']
+ "group_id": group_id,
+ "member_id": member["member_id"],
+ "capabilities": member["capabilities"],
}
item = self.resource.deserialize(data)
result.append(item)
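The add_auth_token_to_kwargs_from_env decorator reformatted at the top of this file is a plain kwargs-injection wrapper: when the caller passes no credentials, they are pulled from the environment. A self-contained sketch of the same idea using only the standard library (the function names here are illustrative, not st2client's):

    import os
    from functools import wraps

    def add_token_from_env(func):
        @wraps(func)
        def decorate(*args, **kwargs):
            # Fall back to the environment only when no token was passed in.
            if not kwargs.get("token") and os.environ.get("ST2_AUTH_TOKEN"):
                kwargs["token"] = os.environ["ST2_AUTH_TOKEN"]
            return func(*args, **kwargs)

        return decorate

    @add_token_from_env
    def get_all(**kwargs):
        return kwargs.get("token")

    os.environ["ST2_AUTH_TOKEN"] = "abc123"
    print(get_all())  # -> abc123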
diff --git a/st2client/st2client/models/inquiry.py b/st2client/st2client/models/inquiry.py
index 5d1a1076f5..93161ee68f 100644
--- a/st2client/st2client/models/inquiry.py
+++ b/st2client/st2client/models/inquiry.py
@@ -24,15 +24,8 @@
class Inquiry(core.Resource):
- _display_name = 'Inquiry'
- _plural = 'Inquiries'
- _plural_display_name = 'Inquiries'
- _url_path = 'inquiries'
- _repr_attributes = [
- 'id',
- 'schema',
- 'roles',
- 'users',
- 'route',
- 'ttl'
- ]
+ _display_name = "Inquiry"
+ _plural = "Inquiries"
+ _plural_display_name = "Inquiries"
+ _url_path = "inquiries"
+ _repr_attributes = ["id", "schema", "roles", "users", "route", "ttl"]
diff --git a/st2client/st2client/models/keyvalue.py b/st2client/st2client/models/keyvalue.py
index f7095a4b8f..5bcd1de8de 100644
--- a/st2client/st2client/models/keyvalue.py
+++ b/st2client/st2client/models/keyvalue.py
@@ -24,11 +24,11 @@
class KeyValuePair(core.Resource):
- _alias = 'Key'
- _display_name = 'Key Value Pair'
- _plural = 'Keys'
- _plural_display_name = 'Key Value Pairs'
- _repr_attributes = ['name', 'value']
+ _alias = "Key"
+ _display_name = "Key Value Pair"
+ _plural = "Keys"
+ _plural_display_name = "Key Value Pairs"
+ _repr_attributes = ["name", "value"]
# Note: This is a temporary hack until we refactor client and make it support non id PKs
def get_id(self):
diff --git a/st2client/st2client/models/pack.py b/st2client/st2client/models/pack.py
index 5d681266ad..7333c1a28e 100644
--- a/st2client/st2client/models/pack.py
+++ b/st2client/st2client/models/pack.py
@@ -19,8 +19,8 @@
class Pack(core.Resource):
- _display_name = 'Pack'
- _plural = 'Packs'
- _plural_display_name = 'Packs'
- _url_path = 'packs'
- _repr_attributes = ['name', 'description', 'version', 'author']
+ _display_name = "Pack"
+ _plural = "Packs"
+ _plural_display_name = "Packs"
+ _url_path = "packs"
+ _repr_attributes = ["name", "description", "version", "author"]
diff --git a/st2client/st2client/models/policy.py b/st2client/st2client/models/policy.py
index 851779d7fd..4b8bb0c813 100644
--- a/st2client/st2client/models/policy.py
+++ b/st2client/st2client/models/policy.py
@@ -24,13 +24,13 @@
class PolicyType(core.Resource):
- _alias = 'Policy-Type'
- _display_name = 'Policy type'
- _plural = 'PolicyTypes'
- _plural_display_name = 'Policy types'
- _repr_attributes = ['ref', 'enabled', 'description']
+ _alias = "Policy-Type"
+ _display_name = "Policy type"
+ _plural = "PolicyTypes"
+ _plural_display_name = "Policy types"
+ _repr_attributes = ["ref", "enabled", "description"]
class Policy(core.Resource):
- _plural = 'Policies'
- _repr_attributes = ['name', 'pack', 'enabled', 'policy_type', 'resource_ref']
+ _plural = "Policies"
+ _repr_attributes = ["name", "pack", "enabled", "policy_type", "resource_ref"]
diff --git a/st2client/st2client/models/rbac.py b/st2client/st2client/models/rbac.py
index 6df4aa4f94..94c765ddf3 100644
--- a/st2client/st2client/models/rbac.py
+++ b/st2client/st2client/models/rbac.py
@@ -17,25 +17,22 @@
from st2client.models import core
-__all__ = [
- 'Role',
- 'UserRoleAssignment'
-]
+__all__ = ["Role", "UserRoleAssignment"]
class Role(core.Resource):
- _alias = 'role'
- _display_name = 'Role'
- _plural = 'Roles'
- _plural_display_name = 'Roles'
- _repr_attributes = ['id', 'name', 'system']
- _url_path = 'rbac/roles'
+ _alias = "role"
+ _display_name = "Role"
+ _plural = "Roles"
+ _plural_display_name = "Roles"
+ _repr_attributes = ["id", "name", "system"]
+ _url_path = "rbac/roles"
class UserRoleAssignment(core.Resource):
- _alias = 'role-assignment'
- _display_name = 'Role Assignment'
- _plural = 'RoleAssignments'
- _plural_display_name = 'Role Assignments'
- _repr_attributes = ['id', 'role', 'user', 'is_remote']
- _url_path = 'rbac/role_assignments'
+ _alias = "role-assignment"
+ _display_name = "Role Assignment"
+ _plural = "RoleAssignments"
+ _plural_display_name = "Role Assignments"
+ _repr_attributes = ["id", "role", "user", "is_remote"]
+ _url_path = "rbac/role_assignments"
diff --git a/st2client/st2client/models/reactor.py b/st2client/st2client/models/reactor.py
index 140d1aaf50..ef4c054f69 100644
--- a/st2client/st2client/models/reactor.py
+++ b/st2client/st2client/models/reactor.py
@@ -24,43 +24,49 @@
class Sensor(core.Resource):
- _plural = 'Sensortypes'
- _repr_attributes = ['name', 'pack']
+ _plural = "Sensortypes"
+ _repr_attributes = ["name", "pack"]
class TriggerType(core.Resource):
- _alias = 'Trigger'
- _display_name = 'Trigger'
- _plural = 'Triggertypes'
- _plural_display_name = 'Triggers'
- _repr_attributes = ['name', 'pack']
+ _alias = "Trigger"
+ _display_name = "Trigger"
+ _plural = "Triggertypes"
+ _plural_display_name = "Triggers"
+ _repr_attributes = ["name", "pack"]
class TriggerInstance(core.Resource):
- _alias = 'Trigger-Instance'
- _display_name = 'TriggerInstance'
- _plural = 'Triggerinstances'
- _plural_display_name = 'TriggerInstances'
- _repr_attributes = ['id', 'trigger', 'occurrence_time', 'payload', 'status']
+ _alias = "Trigger-Instance"
+ _display_name = "TriggerInstance"
+ _plural = "Triggerinstances"
+ _plural_display_name = "TriggerInstances"
+ _repr_attributes = ["id", "trigger", "occurrence_time", "payload", "status"]
class Trigger(core.Resource):
- _alias = 'TriggerSpecification'
- _display_name = 'Trigger Specification'
- _plural = 'Triggers'
- _plural_display_name = 'Trigger Specifications'
- _repr_attributes = ['name', 'pack']
+ _alias = "TriggerSpecification"
+ _display_name = "Trigger Specification"
+ _plural = "Triggers"
+ _plural_display_name = "Trigger Specifications"
+ _repr_attributes = ["name", "pack"]
class Rule(core.Resource):
- _alias = 'Rule'
- _plural = 'Rules'
- _repr_attributes = ['name', 'pack', 'trigger', 'criteria', 'enabled']
+ _alias = "Rule"
+ _plural = "Rules"
+ _repr_attributes = ["name", "pack", "trigger", "criteria", "enabled"]
class RuleEnforcement(core.Resource):
- _alias = 'Rule-Enforcement'
- _plural = 'RuleEnforcements'
- _display_name = 'Rule Enforcement'
- _plural_display_name = 'Rule Enforcements'
- _repr_attributes = ['id', 'trigger_instance_id', 'execution_id', 'rule.ref', 'enforced_at']
+ _alias = "Rule-Enforcement"
+ _plural = "RuleEnforcements"
+ _display_name = "Rule Enforcement"
+ _plural_display_name = "Rule Enforcements"
+ _repr_attributes = [
+ "id",
+ "trigger_instance_id",
+ "execution_id",
+ "rule.ref",
+ "enforced_at",
+ ]
diff --git a/st2client/st2client/models/service_registry.py b/st2client/st2client/models/service_registry.py
index 3b3057a3c3..ca95cd73cb 100644
--- a/st2client/st2client/models/service_registry.py
+++ b/st2client/st2client/models/service_registry.py
@@ -17,32 +17,27 @@
from st2client.models import core
-__all__ = [
- 'ServiceRegistry',
-
- 'ServiceRegistryGroup',
- 'ServiceRegistryMember'
-]
+__all__ = ["ServiceRegistry", "ServiceRegistryGroup", "ServiceRegistryMember"]
class ServiceRegistry(core.Resource):
- _alias = 'service-registry'
- _display_name = 'Service Registry'
- _plural = 'Service Registry'
- _plural_display_name = 'Service Registry'
+ _alias = "service-registry"
+ _display_name = "Service Registry"
+ _plural = "Service Registry"
+ _plural_display_name = "Service Registry"
class ServiceRegistryGroup(core.Resource):
- _alias = 'group'
- _display_name = 'Group'
- _plural = 'Groups'
- _plural_display_name = 'Groups'
- _repr_attributes = ['group_id']
+ _alias = "group"
+ _display_name = "Group"
+ _plural = "Groups"
+ _plural_display_name = "Groups"
+ _repr_attributes = ["group_id"]
class ServiceRegistryMember(core.Resource):
- _alias = 'member'
- _display_name = 'Group Member'
- _plural = 'Group Members'
- _plural_display_name = 'Group Members'
- _repr_attributes = ['group_id', 'member_id']
+ _alias = "member"
+ _display_name = "Group Member"
+ _plural = "Group Members"
+ _plural_display_name = "Group Members"
+ _repr_attributes = ["group_id", "member_id"]
diff --git a/st2client/st2client/models/timer.py b/st2client/st2client/models/timer.py
index 4ba58547f3..fbfbd6cfcd 100644
--- a/st2client/st2client/models/timer.py
+++ b/st2client/st2client/models/timer.py
@@ -24,7 +24,7 @@
class Timer(core.Resource):
- _alias = 'Timer'
- _display_name = 'Timer'
- _plural = 'Timers'
- _plural_display_name = 'Timers'
+ _alias = "Timer"
+ _display_name = "Timer"
+ _plural = "Timers"
+ _plural_display_name = "Timers"
diff --git a/st2client/st2client/models/trace.py b/st2client/st2client/models/trace.py
index a03b4a8812..3b7bfe4449 100644
--- a/st2client/st2client/models/trace.py
+++ b/st2client/st2client/models/trace.py
@@ -19,8 +19,8 @@
class Trace(core.Resource):
- _alias = 'Trace'
- _display_name = 'Trace'
- _plural = 'Traces'
- _plural_display_name = 'Traces'
- _repr_attributes = ['id', 'trace_tag']
+ _alias = "Trace"
+ _display_name = "Trace"
+ _plural = "Traces"
+ _plural_display_name = "Traces"
+ _repr_attributes = ["id", "trace_tag"]
diff --git a/st2client/st2client/models/webhook.py b/st2client/st2client/models/webhook.py
index 83d939f061..161d1bdb4c 100644
--- a/st2client/st2client/models/webhook.py
+++ b/st2client/st2client/models/webhook.py
@@ -24,8 +24,8 @@
class Webhook(core.Resource):
- _alias = 'Webhook'
- _display_name = 'Webhook'
- _plural = 'Webhooks'
- _plural_display_name = 'Webhooks'
- _repr_attributes = ['parameters', 'type', 'pack', 'name']
+ _alias = "Webhook"
+ _display_name = "Webhook"
+ _plural = "Webhooks"
+ _plural_display_name = "Webhooks"
+ _repr_attributes = ["parameters", "type", "pack", "name"]
diff --git a/st2client/st2client/shell.py b/st2client/st2client/shell.py
index ac6108d796..7d3359c532 100755
--- a/st2client/st2client/shell.py
+++ b/st2client/st2client/shell.py
@@ -25,6 +25,7 @@
# Ignore CryptographyDeprecationWarning warnings which appear on older versions of Python 2.7
import warnings
from cryptography.utils import CryptographyDeprecationWarning
+
warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning)
import os
@@ -66,13 +67,13 @@
from st2client.commands.auth import LoginCommand
-__all__ = [
- 'Shell'
-]
+__all__ = ["Shell"]
LOGGER = logging.getLogger(__name__)
-CLI_DESCRIPTION = 'CLI for StackStorm event-driven automation platform. https://stackstorm.com'
+CLI_DESCRIPTION = (
+ "CLI for StackStorm event-driven automation platform. https://stackstorm.com"
+)
USAGE_STRING = """
Usage: %(prog)s [options] [options]
@@ -83,15 +84,19 @@
%(prog)s --debug run core.local cmd=date
""".strip()
-NON_UTF8_LOCALE = """
+NON_UTF8_LOCALE = (
+ """
Locale %s with encoding %s which is not UTF-8 is used. This means that some functionality which
relies on outputting unicode characters won't work.
You are encouraged to use UTF-8 locale by setting LC_ALL environment variable to en_US.UTF-8 or
similar.
-""".strip().replace('\n', ' ').replace(' ', ' ')
+""".strip()
+ .replace("\n", " ")
+ .replace(" ", " ")
+)
-PACKAGE_METADATA_FILE_PATH = '/opt/stackstorm/st2/package.meta'
+PACKAGE_METADATA_FILE_PATH = "/opt/stackstorm/st2/package.meta"
def get_stackstorm_version():
@@ -101,7 +106,7 @@ def get_stackstorm_version():
:rtype: ``str``
"""
- if 'dev' in __version__:
+ if "dev" in __version__:
version = __version__
if not os.path.isfile(PACKAGE_METADATA_FILE_PATH):
@@ -115,11 +120,11 @@ def get_stackstorm_version():
return version
try:
- git_revision = config.get('server', 'git_sha')
+ git_revision = config.get("server", "git_sha")
except Exception:
return version
- version = '%s (%s)' % (version, git_revision)
+ version = "%s (%s)" % (version, git_revision)
else:
version = __version__
@@ -143,214 +148,237 @@ def __init__(self):
# Set up general program options.
self.parser.add_argument(
- '--version',
- action='version',
- version='%(prog)s {version}, on Python {python_major}.{python_minor}.{python_patch}'
- .format(version=get_stackstorm_version(),
- python_major=sys.version_info.major,
- python_minor=sys.version_info.minor,
- python_patch=sys.version_info.micro))
+ "--version",
+ action="version",
+ version="%(prog)s {version}, on Python {python_major}.{python_minor}.{python_patch}".format(
+ version=get_stackstorm_version(),
+ python_major=sys.version_info.major,
+ python_minor=sys.version_info.minor,
+ python_patch=sys.version_info.micro,
+ ),
+ )
self.parser.add_argument(
- '--url',
- action='store',
- dest='base_url',
+ "--url",
+ action="store",
+ dest="base_url",
default=None,
- help='Base URL for the API servers. Assumes all servers use the '
- 'same base URL and default ports are used. Get ST2_BASE_URL '
- 'from the environment variables by default.'
+ help="Base URL for the API servers. Assumes all servers use the "
+ "same base URL and default ports are used. Get ST2_BASE_URL "
+ "from the environment variables by default.",
)
self.parser.add_argument(
- '--auth-url',
- action='store',
- dest='auth_url',
+ "--auth-url",
+ action="store",
+ dest="auth_url",
default=None,
- help='URL for the authentication service. Get ST2_AUTH_URL '
- 'from the environment variables by default.'
+ help="URL for the authentication service. Get ST2_AUTH_URL "
+ "from the environment variables by default.",
)
self.parser.add_argument(
- '--api-url',
- action='store',
- dest='api_url',
+ "--api-url",
+ action="store",
+ dest="api_url",
default=None,
- help='URL for the API server. Get ST2_API_URL '
- 'from the environment variables by default.'
+ help="URL for the API server. Get ST2_API_URL "
+ "from the environment variables by default.",
)
self.parser.add_argument(
- '--stream-url',
- action='store',
- dest='stream_url',
+ "--stream-url",
+ action="store",
+ dest="stream_url",
default=None,
- help='URL for the stream endpoint. Get ST2_STREAM_URL'
- 'from the environment variables by default.'
+ help="URL for the stream endpoint. Get ST2_STREAM_URL"
+ "from the environment variables by default.",
)
self.parser.add_argument(
- '--api-version',
- action='store',
- dest='api_version',
+ "--api-version",
+ action="store",
+ dest="api_version",
default=None,
- help='API version to use. Get ST2_API_VERSION '
- 'from the environment variables by default.'
+ help="API version to use. Get ST2_API_VERSION "
+ "from the environment variables by default.",
)
self.parser.add_argument(
- '--cacert',
- action='store',
- dest='cacert',
+ "--cacert",
+ action="store",
+ dest="cacert",
default=None,
- help='Path to the CA cert bundle for the SSL endpoints. '
- 'Get ST2_CACERT from the environment variables by default. '
- 'If this is not provided, then SSL cert will not be verified.'
+ help="Path to the CA cert bundle for the SSL endpoints. "
+ "Get ST2_CACERT from the environment variables by default. "
+ "If this is not provided, then SSL cert will not be verified.",
)
self.parser.add_argument(
- '--config-file',
- action='store',
- dest='config_file',
+ "--config-file",
+ action="store",
+ dest="config_file",
default=None,
- help='Path to the CLI config file'
+ help="Path to the CLI config file",
)
self.parser.add_argument(
- '--print-config',
- action='store_true',
- dest='print_config',
+ "--print-config",
+ action="store_true",
+ dest="print_config",
default=False,
- help='Parse the config file and print the values'
+ help="Parse the config file and print the values",
)
self.parser.add_argument(
- '--skip-config',
- action='store_true',
- dest='skip_config',
+ "--skip-config",
+ action="store_true",
+ dest="skip_config",
default=False,
- help='Don\'t parse and use the CLI config file'
+ help="Don't parse and use the CLI config file",
)
self.parser.add_argument(
- '--debug',
- action='store_true',
- dest='debug',
+ "--debug",
+ action="store_true",
+ dest="debug",
default=False,
- help='Enable debug mode'
+ help="Enable debug mode",
)
# Set up list of commands and subcommands.
- self.subparsers = self.parser.add_subparsers(dest='parser')
+ self.subparsers = self.parser.add_subparsers(dest="parser")
self.subparsers.required = True
self.commands = {}
- self.commands['run'] = action.ActionRunCommand(
- models.Action, self, self.subparsers, name='run', add_help=False)
+ self.commands["run"] = action.ActionRunCommand(
+ models.Action, self, self.subparsers, name="run", add_help=False
+ )
- self.commands['action'] = action.ActionBranch(
- 'An activity that happens as a response to the external event.',
- self, self.subparsers)
+ self.commands["action"] = action.ActionBranch(
+ "An activity that happens as a response to the external event.",
+ self,
+ self.subparsers,
+ )
- self.commands['action-alias'] = action_alias.ActionAliasBranch(
- 'Action aliases.',
- self, self.subparsers)
+ self.commands["action-alias"] = action_alias.ActionAliasBranch(
+ "Action aliases.", self, self.subparsers
+ )
- self.commands['auth'] = auth.TokenCreateCommand(
- models.Token, self, self.subparsers, name='auth')
+ self.commands["auth"] = auth.TokenCreateCommand(
+ models.Token, self, self.subparsers, name="auth"
+ )
- self.commands['login'] = auth.LoginCommand(
- models.Token, self, self.subparsers, name='login')
+ self.commands["login"] = auth.LoginCommand(
+ models.Token, self, self.subparsers, name="login"
+ )
- self.commands['whoami'] = auth.WhoamiCommand(
- models.Token, self, self.subparsers, name='whoami')
+ self.commands["whoami"] = auth.WhoamiCommand(
+ models.Token, self, self.subparsers, name="whoami"
+ )
- self.commands['api-key'] = auth.ApiKeyBranch(
- 'API Keys.',
- self, self.subparsers)
+ self.commands["api-key"] = auth.ApiKeyBranch("API Keys.", self, self.subparsers)
- self.commands['execution'] = action.ActionExecutionBranch(
- 'An invocation of an action.',
- self, self.subparsers)
+ self.commands["execution"] = action.ActionExecutionBranch(
+ "An invocation of an action.", self, self.subparsers
+ )
- self.commands['inquiry'] = inquiry.InquiryBranch(
- 'Inquiries provide an opportunity to ask a question '
- 'and wait for a response in a workflow.',
- self, self.subparsers)
+ self.commands["inquiry"] = inquiry.InquiryBranch(
+ "Inquiries provide an opportunity to ask a question "
+ "and wait for a response in a workflow.",
+ self,
+ self.subparsers,
+ )
- self.commands['key'] = keyvalue.KeyValuePairBranch(
- 'Key value pair is used to store commonly used configuration '
- 'for reuse in sensors, actions, and rules.',
- self, self.subparsers)
+ self.commands["key"] = keyvalue.KeyValuePairBranch(
+ "Key value pair is used to store commonly used configuration "
+ "for reuse in sensors, actions, and rules.",
+ self,
+ self.subparsers,
+ )
- self.commands['pack'] = pack.PackBranch(
- 'A group of related integration resources: '
- 'actions, rules, and sensors.',
- self, self.subparsers)
+ self.commands["pack"] = pack.PackBranch(
+ "A group of related integration resources: " "actions, rules, and sensors.",
+ self,
+ self.subparsers,
+ )
- self.commands['policy'] = policy.PolicyBranch(
- 'Policy that is enforced on a resource.',
- self, self.subparsers)
+ self.commands["policy"] = policy.PolicyBranch(
+ "Policy that is enforced on a resource.", self, self.subparsers
+ )
- self.commands['policy-type'] = policy.PolicyTypeBranch(
- 'Type of policy that can be applied to resources.',
- self, self.subparsers)
+ self.commands["policy-type"] = policy.PolicyTypeBranch(
+ "Type of policy that can be applied to resources.", self, self.subparsers
+ )
- self.commands['rule'] = rule.RuleBranch(
+ self.commands["rule"] = rule.RuleBranch(
'A specification to invoke an "action" on a "trigger" selectively '
- 'based on some criteria.',
- self, self.subparsers)
+ "based on some criteria.",
+ self,
+ self.subparsers,
+ )
- self.commands['webhook'] = webhook.WebhookBranch(
- 'Webhooks.',
- self, self.subparsers)
+ self.commands["webhook"] = webhook.WebhookBranch(
+ "Webhooks.", self, self.subparsers
+ )
- self.commands['timer'] = timer.TimerBranch(
- 'Timers.',
- self, self.subparsers)
+ self.commands["timer"] = timer.TimerBranch("Timers.", self, self.subparsers)
- self.commands['runner'] = resource.ResourceBranch(
+ self.commands["runner"] = resource.ResourceBranch(
models.RunnerType,
- 'Runner is a type of handler for a specific class of actions.',
- self, self.subparsers, read_only=True, has_disable=True)
+ "Runner is a type of handler for a specific class of actions.",
+ self,
+ self.subparsers,
+ read_only=True,
+ has_disable=True,
+ )
- self.commands['sensor'] = sensor.SensorBranch(
- 'An adapter which allows you to integrate StackStorm with external system.',
- self, self.subparsers)
+ self.commands["sensor"] = sensor.SensorBranch(
+ "An adapter which allows you to integrate StackStorm with external system.",
+ self,
+ self.subparsers,
+ )
- self.commands['trace'] = trace.TraceBranch(
- 'A group of executions, rules and triggerinstances that are related.',
- self, self.subparsers)
+ self.commands["trace"] = trace.TraceBranch(
+ "A group of executions, rules and triggerinstances that are related.",
+ self,
+ self.subparsers,
+ )
- self.commands['trigger'] = trigger.TriggerTypeBranch(
- 'An external event that is mapped to a st2 input. It is the '
- 'st2 invocation point.',
- self, self.subparsers)
+ self.commands["trigger"] = trigger.TriggerTypeBranch(
+ "An external event that is mapped to a st2 input. It is the "
+ "st2 invocation point.",
+ self,
+ self.subparsers,
+ )
- self.commands['trigger-instance'] = triggerinstance.TriggerInstanceBranch(
- 'Actual instances of triggers received by st2.',
- self, self.subparsers)
+ self.commands["trigger-instance"] = triggerinstance.TriggerInstanceBranch(
+ "Actual instances of triggers received by st2.", self, self.subparsers
+ )
- self.commands['rule-enforcement'] = rule_enforcement.RuleEnforcementBranch(
- 'Models that represent enforcement of rules.',
- self, self.subparsers)
+ self.commands["rule-enforcement"] = rule_enforcement.RuleEnforcementBranch(
+ "Models that represent enforcement of rules.", self, self.subparsers
+ )
- self.commands['workflow'] = workflow.WorkflowBranch(
- 'Commands for workflow authoring related operations. '
- 'Only orquesta workflows are supported.',
- self, self.subparsers)
+ self.commands["workflow"] = workflow.WorkflowBranch(
+ "Commands for workflow authoring related operations. "
+ "Only orquesta workflows are supported.",
+ self,
+ self.subparsers,
+ )
# Service Registry
- self.commands['service-registry'] = service_registry.ServiceRegistryBranch(
- 'Service registry group and membership related commands.',
- self, self.subparsers)
+ self.commands["service-registry"] = service_registry.ServiceRegistryBranch(
+ "Service registry group and membership related commands.",
+ self,
+ self.subparsers,
+ )
# RBAC
- self.commands['role'] = rbac.RoleBranch(
- 'RBAC roles.',
- self, self.subparsers)
- self.commands['role-assignment'] = rbac.RoleAssignmentBranch(
- 'RBAC role assignments.',
- self, self.subparsers)
+ self.commands["role"] = rbac.RoleBranch("RBAC roles.", self, self.subparsers)
+ self.commands["role-assignment"] = rbac.RoleAssignmentBranch(
+ "RBAC role assignments.", self, self.subparsers
+ )
def run(self, argv):
debug = False
@@ -369,9 +397,9 @@ def run(self, argv):
# Provide autocomplete for shell
argcomplete.autocomplete(self.parser)
- if '--print-config' in argv:
+ if "--print-config" in argv:
# Hack because --print-config requires no command to be specified
- argv = argv + ['action', 'list']
+ argv = argv + ["action", "list"]
# Parse command line arguments.
args = self.parser.parse_args(args=argv)
@@ -389,7 +417,7 @@ def run(self, argv):
# Setup client and run the command
try:
- debug = getattr(args, 'debug', False)
+ debug = getattr(args, "debug", False)
if debug:
set_log_level_for_all_loggers(level=logging.DEBUG)
@@ -399,7 +427,7 @@ def run(self, argv):
# TODO: This is not so nice work-around for Python 3 because of a breaking change in
# Python 3 - https://bugs.python.org/issue16308
try:
- func = getattr(args, 'func')
+ func = getattr(args, "func")
except AttributeError:
parser.print_help()
sys.exit(2)
@@ -414,9 +442,9 @@ def run(self, argv):
return 2
except Exception as e:
# We allow exception to define custom exit codes
- exit_code = getattr(e, 'exit_code', 1)
+ exit_code = getattr(e, "exit_code", 1)
- print('ERROR: %s\n' % e)
+ print("ERROR: %s\n" % e)
if debug:
self._print_debug_info(args=args)
@@ -426,10 +454,10 @@ def _print_config(self, args):
config = self._parse_config_file(args=args)
for section, options in six.iteritems(config):
- print('[%s]' % (section))
+ print("[%s]" % (section))
for name, value in six.iteritems(options):
- print('%s = %s' % (name, value))
+ print("%s = %s" % (name, value))
def _check_locale_and_print_warning(self):
"""
@@ -440,23 +468,23 @@ def _check_locale_and_print_warning(self):
preferred_encoding = locale.getpreferredencoding()
except ValueError:
# Ignore unknown locale errors for now
- default_locale = 'unknown'
- preferred_encoding = 'unknown'
+ default_locale = "unknown"
+ preferred_encoding = "unknown"
- if preferred_encoding and preferred_encoding.lower() != 'utf-8':
- msg = NON_UTF8_LOCALE % (default_locale or 'unknown', preferred_encoding)
+ if preferred_encoding and preferred_encoding.lower() != "utf-8":
+ msg = NON_UTF8_LOCALE % (default_locale or "unknown", preferred_encoding)
LOGGER.warn(msg)
def setup_logging(argv):
- debug = '--debug' in argv
+ debug = "--debug" in argv
root = LOGGER
root.setLevel(logging.WARNING)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARNING)
- formatter = logging.Formatter('%(asctime)s %(levelname)s - %(message)s')
+ formatter = logging.Formatter("%(asctime)s %(levelname)s - %(message)s")
handler.setFormatter(formatter)
if not debug:
@@ -470,5 +498,5 @@ def main(argv=sys.argv[1:]):
return Shell().run(argv)
-if __name__ == '__main__':
+if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
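Nearly every hunk in this file is the same mechanical transformation: black normalizes quotes to double quotes, collapses a call onto one line when it fits within the line limit, and otherwise explodes it to one argument per line with a trailing comma. A toy before/after sketch, not taken from the patch itself:

    import argparse

    parser = argparse.ArgumentParser()

    # Hand-wrapped pre-black style:
    #   parser.add_argument('--debug',
    #                       action='store_true',
    #                       help='Enable debug mode')
    # black collapses the call when it fits on one line:
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    # ...and explodes it, one argument per line with a trailing comma, when it does not:
    parser.add_argument(
        "--recreate-virtualenvs-from-scratch-on-reload",
        action="store_true",
        help="A deliberately long example so the call no longer fits on one line",
    )

    print(parser.parse_args(["--debug"]).debug)  # -> True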
diff --git a/st2client/st2client/utils/color.py b/st2client/st2client/utils/color.py
index 8b18402136..f1106851e2 100644
--- a/st2client/st2client/utils/color.py
+++ b/st2client/st2client/utils/color.py
@@ -16,40 +16,36 @@
from __future__ import absolute_import
import os
-__all__ = [
- 'DisplayColors',
-
- 'format_status'
-]
+__all__ = ["DisplayColors", "format_status"]
TERMINAL_SUPPORTS_ANSI_CODES = [
- 'xterm',
- 'xterm-color',
- 'screen',
- 'vt100',
- 'vt100-color',
- 'xterm-256color'
+ "xterm",
+ "xterm-color",
+ "screen",
+ "vt100",
+ "vt100-color",
+ "xterm-256color",
]
-DISABLED = os.environ.get('ST2_COLORIZE', '')
+DISABLED = os.environ.get("ST2_COLORIZE", "")
class DisplayColors(object):
- RED = '\033[91m'
- PURPLE = '\033[35m'
- GREEN = '\033[92m'
- YELLOW = '\033[93m'
- BLUE = '\033[94m'
- BROWN = '\033[33m'
- ENDC = '\033[0m'
- BOLD = '\033[1m'
- UNDERLINE = '\033[4m'
+ RED = "\033[91m"
+ PURPLE = "\033[35m"
+ GREEN = "\033[92m"
+ YELLOW = "\033[93m"
+ BLUE = "\033[94m"
+ BROWN = "\033[33m"
+ ENDC = "\033[0m"
+ BOLD = "\033[1m"
+ UNDERLINE = "\033[4m"
@staticmethod
- def colorize(value, color=''):
+ def colorize(value, color=""):
# TODO: use list of supported terminals
- term = os.environ.get('TERM', None)
+ term = os.environ.get("TERM", None)
if term is None or term.lower() not in TERMINAL_SUPPORTS_ANSI_CODES:
# Terminal doesn't support colors
@@ -58,33 +54,33 @@ def colorize(value, color=''):
if DISABLED or not color:
return value
- return '%s%s%s' % (color, value, DisplayColors.ENDC)
+ return "%s%s%s" % (color, value, DisplayColors.ENDC)
# Lookup table
STATUS_LOOKUP = {
- 'succeeded': DisplayColors.GREEN,
- 'delayed': DisplayColors.BLUE,
- 'failed': DisplayColors.RED,
- 'timeout': DisplayColors.BROWN,
- 'running': DisplayColors.YELLOW
+ "succeeded": DisplayColors.GREEN,
+ "delayed": DisplayColors.BLUE,
+ "failed": DisplayColors.RED,
+ "timeout": DisplayColors.BROWN,
+ "running": DisplayColors.YELLOW,
}
def format_status(value):
# Support status values with elapsed info
- split = value.split('(', 1)
+ split = value.split("(", 1)
if len(split) == 2:
status = split[0].strip()
- remainder = '(' + split[1]
+ remainder = "(" + split[1]
else:
status = value
- remainder = ''
+ remainder = ""
color = STATUS_LOOKUP.get(status, DisplayColors.YELLOW)
result = DisplayColors.colorize(status, color)
if remainder:
- result = result + ' ' + remainder
+ result = result + " " + remainder
return result
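
For reference, colorize only wraps the value in ANSI escapes when TERM is in the supported list and ST2_COLORIZE has not disabled colors, and format_status splits an optional "(elapsed)" suffix off before the color lookup. A minimal sketch of the combined behavior (the _demo names and the shortened SUPPORTED_TERMS set are illustrative, not st2client API):

import os

GREEN = "\033[92m"
ENDC = "\033[0m"
SUPPORTED_TERMS = {"xterm", "xterm-256color", "screen"}  # abbreviated

def colorize_demo(value, color=""):
    # Emit escape codes only for a known ANSI-capable terminal
    term = os.environ.get("TERM", "")
    if term.lower() not in SUPPORTED_TERMS or not color:
        return value
    return "%s%s%s" % (color, value, ENDC)

# "succeeded (5s elapsed)" -> colored status + untouched remainder
status, _, remainder = "succeeded (5s elapsed)".partition("(")
print(colorize_demo(status.strip(), GREEN) + " (" + remainder)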
diff --git a/st2client/st2client/utils/date.py b/st2client/st2client/utils/date.py
index b19e27f3ec..3a76a44c81 100644
--- a/st2client/st2client/utils/date.py
+++ b/st2client/st2client/utils/date.py
@@ -20,10 +20,7 @@
from st2client.config import get_config
-__all__ = [
- 'parse',
- 'format_isodate'
-]
+__all__ = ["parse", "format_isodate"]
def add_utc_tz(dt):
@@ -39,7 +36,7 @@ def format_dt(dt):
"""
Format datetime object for human friendly representation.
"""
- value = dt.strftime('%a, %d %b %Y %H:%M:%S %Z')
+ value = dt.strftime("%a, %d %b %Y %H:%M:%S %Z")
return value
@@ -52,7 +49,7 @@ def format_isodate(value, timezone=None):
:rtype: ``str``
"""
if not value:
- return ''
+ return ""
# For some reason pylint thinks it returns a tuple but it returns a datetime object
dt = dateutil.parser.parse(str(value))
@@ -70,6 +67,6 @@ def format_isodate_for_user_timezone(value):
specific in the config.
"""
config = get_config()
- timezone = config.get('cli', {}).get('timezone', 'UTC')
+ timezone = config.get("cli", {}).get("timezone", "UTC")
result = format_isodate(value=value, timezone=timezone)
return result
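
The hunk above shows the output side of the date helpers: format_isodate parses the value, converts it to the requested timezone, and renders it with the human-friendly strftime pattern from format_dt, while format_isodate_for_user_timezone just pulls the timezone out of the cli config section (defaulting to UTC). A rough sketch, assuming python-dateutil is installed and skipping the config lookup:

import dateutil.parser
import dateutil.tz

def format_isodate_demo(value, timezone="UTC"):
    if not value:
        return ""
    dt = dateutil.parser.parse(str(value))
    if not dt.tzinfo:
        dt = dt.replace(tzinfo=dateutil.tz.tzutc())  # assumption: naive == UTC
    dt = dt.astimezone(dateutil.tz.gettz(timezone))
    return dt.strftime("%a, %d %b %Y %H:%M:%S %Z")

print(format_isodate_demo("2021-03-02T12:00:00Z", "Europe/London"))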
diff --git a/st2client/st2client/utils/httpclient.py b/st2client/st2client/utils/httpclient.py
index 089f6b88d6..6af6595ec5 100644
--- a/st2client/st2client/utils/httpclient.py
+++ b/st2client/st2client/utils/httpclient.py
@@ -27,38 +27,41 @@
def add_ssl_verify_to_kwargs(func):
def decorate(*args, **kwargs):
- if isinstance(args[0], HTTPClient) and 'https' in getattr(args[0], 'root', ''):
- cacert = getattr(args[0], 'cacert', None)
- kwargs['verify'] = cacert if cacert is not None else False
+ if isinstance(args[0], HTTPClient) and "https" in getattr(args[0], "root", ""):
+ cacert = getattr(args[0], "cacert", None)
+ kwargs["verify"] = cacert if cacert is not None else False
return func(*args, **kwargs)
+
return decorate
def add_auth_token_to_headers(func):
def decorate(*args, **kwargs):
- headers = kwargs.get('headers', dict())
+ headers = kwargs.get("headers", dict())
- token = kwargs.pop('token', None)
+ token = kwargs.pop("token", None)
if token:
- headers['X-Auth-Token'] = str(token)
- kwargs['headers'] = headers
+ headers["X-Auth-Token"] = str(token)
+ kwargs["headers"] = headers
- api_key = kwargs.pop('api_key', None)
+ api_key = kwargs.pop("api_key", None)
if api_key:
- headers['St2-Api-Key'] = str(api_key)
- kwargs['headers'] = headers
+ headers["St2-Api-Key"] = str(api_key)
+ kwargs["headers"] = headers
return func(*args, **kwargs)
+
return decorate
def add_json_content_type_to_headers(func):
def decorate(*args, **kwargs):
- headers = kwargs.get('headers', dict())
- content_type = headers.get('content-type', 'application/json')
- headers['content-type'] = content_type
- kwargs['headers'] = headers
+ headers = kwargs.get("headers", dict())
+ content_type = headers.get("content-type", "application/json")
+ headers["content-type"] = content_type
+ kwargs["headers"] = headers
return func(*args, **kwargs)
+
return decorate
@@ -71,12 +74,11 @@ def get_url_without_trailing_slash(value):
:rtype: ``str``
"""
- result = value[:-1] if value.endswith('/') else value
+ result = value[:-1] if value.endswith("/") else value
return result
class HTTPClient(object):
-
def __init__(self, root, cacert=None, debug=False):
self.root = get_url_without_trailing_slash(root)
self.cacert = cacert
@@ -136,30 +138,30 @@ def _response_hook(self, response):
print("# -------- begin %d response ----------" % (id(self)))
print(response.text)
print("# -------- end %d response ------------" % (id(self)))
- print('')
+ print("")
return response
def _get_curl_line_for_request(self, request):
- parts = ['curl']
+ parts = ["curl"]
# method
method = request.method.upper()
- if method in ['HEAD']:
- parts.extend(['--head'])
+ if method in ["HEAD"]:
+ parts.extend(["--head"])
else:
- parts.extend(['-X', pquote(method)])
+ parts.extend(["-X", pquote(method)])
# headers
for key, value in request.headers.items():
- parts.extend(['-H ', pquote('%s: %s' % (key, value))])
+ parts.extend(["-H ", pquote("%s: %s" % (key, value))])
# body
if request.body:
- parts.extend(['--data-binary', pquote(request.body)])
+ parts.extend(["--data-binary", pquote(request.body)])
# URL
parts.extend([pquote(request.url)])
- curl_line = ' '.join(parts)
+ curl_line = " ".join(parts)
return curl_line
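
The three decorators above share one pattern: pull an optional kwarg off the call, fold it into kwargs["headers"], and delegate. A self-contained sketch of that pattern (request_demo is a stand-in for the real HTTPClient methods):

import functools

def add_auth_token_to_headers_demo(func):
    @functools.wraps(func)
    def decorate(*args, **kwargs):
        headers = kwargs.get("headers", {})
        token = kwargs.pop("token", None)  # consumed here, not passed through
        if token:
            headers["X-Auth-Token"] = str(token)
        kwargs["headers"] = headers
        return func(*args, **kwargs)

    return decorate

@add_auth_token_to_headers_demo
def request_demo(url, headers=None):
    return headers

assert request_demo("/executions", token="abc") == {"X-Auth-Token": "abc"}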
diff --git a/st2client/st2client/utils/interactive.py b/st2client/st2client/utils/interactive.py
index 35065e5d94..7e6f81b29b 100644
--- a/st2client/st2client/utils/interactive.py
+++ b/st2client/st2client/utils/interactive.py
@@ -28,8 +28,8 @@
from six.moves import range
-POSITIVE_BOOLEAN = {'1', 'y', 'yes', 'true'}
-NEGATIVE_BOOLEAN = {'0', 'n', 'no', 'nope', 'nah', 'false'}
+POSITIVE_BOOLEAN = {"1", "y", "yes", "true"}
+NEGATIVE_BOOLEAN = {"0", "n", "no", "nope", "nah", "false"}
class ReaderNotImplemented(OperationFailureException):
@@ -58,10 +58,8 @@ class StringReader(object):
def __init__(self, name, spec, prefix=None, secret=False, **kw):
self.name = name
self.spec = spec
- self.prefix = prefix or ''
- self.options = {
- 'is_password': secret
- }
+ self.prefix = prefix or ""
+ self.options = {"is_password": secret}
self._construct_description()
self._construct_template()
@@ -84,7 +82,7 @@ def read(self):
message = self.template.format(self.prefix + self.name, **self.spec)
response = prompt(message, **self.options)
- result = self.spec.get('default', None)
+ result = self.spec.get("default", None)
if response:
result = self._transform_response(response)
@@ -92,20 +90,21 @@ def read(self):
return result
def _construct_description(self):
- if 'description' in self.spec:
+ if "description" in self.spec:
+
def get_bottom_toolbar_tokens(cli):
- return [(token.Token.Toolbar, self.spec['description'])]
+ return [(token.Token.Toolbar, self.spec["description"])]
- self.options['get_bottom_toolbar_tokens'] = get_bottom_toolbar_tokens
+ self.options["get_bottom_toolbar_tokens"] = get_bottom_toolbar_tokens
def _construct_template(self):
- self.template = u'{0}: '
+ self.template = "{0}: "
- if 'default' in self.spec:
- self.template = u'{0} [{default}]: '
+ if "default" in self.spec:
+ self.template = "{0} [{default}]: "
def _construct_validators(self):
- self.options['validator'] = MuxValidator([self.validate], self.spec)
+ self.options["validator"] = MuxValidator([self.validate], self.spec)
def _transform_response(self, response):
return response
@@ -114,25 +113,27 @@ def _transform_response(self, response):
class BooleanReader(StringReader):
@staticmethod
def condition(spec):
- return spec.get('type', None) == 'boolean'
+ return spec.get("type", None) == "boolean"
@staticmethod
def validate(input, spec):
- if not input and (not spec.get('required', None) or spec.get('default', None)):
+ if not input and (not spec.get("required", None) or spec.get("default", None)):
return
if input.lower() not in POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN:
- raise validation.ValidationError(len(input),
- 'Does not look like boolean. Pick from [%s]'
- % ', '.join(POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN))
+ raise validation.ValidationError(
+ len(input),
+ "Does not look like boolean. Pick from [%s]"
+ % ", ".join(POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN),
+ )
def _construct_template(self):
- self.template = u'{0} (boolean)'
+ self.template = "{0} (boolean)"
- if 'default' in self.spec:
- self.template += u' [{}]: '.format(self.spec.get('default') and 'y' or 'n')
+ if "default" in self.spec:
+ self.template += " [{}]: ".format(self.spec.get("default") and "y" or "n")
else:
- self.template += u': '
+ self.template += ": "
def _transform_response(self, response):
if response.lower() in POSITIVE_BOOLEAN:
@@ -141,14 +142,16 @@ def _transform_response(self, response):
return False
# Hopefully, it will never happen
- raise OperationFailureException('Response neither positive no negative. '
- 'Value have not been properly validated.')
+ raise OperationFailureException(
+ "Response neither positive no negative. "
+ "Value have not been properly validated."
+ )
class NumberReader(StringReader):
@staticmethod
def condition(spec):
- return spec.get('type', None) == 'number'
+ return spec.get("type", None) == "number"
@staticmethod
def validate(input, spec):
@@ -161,12 +164,12 @@ def validate(input, spec):
super(NumberReader, NumberReader).validate(input, spec)
def _construct_template(self):
- self.template = u'{0} (float)'
+ self.template = "{0} (float)"
- if 'default' in self.spec:
- self.template += u' [{default}]: '.format(default=self.spec.get('default'))
+ if "default" in self.spec:
+ self.template += " [{default}]: ".format(default=self.spec.get("default"))
else:
- self.template += u': '
+ self.template += ": "
def _transform_response(self, response):
return float(response)
@@ -175,7 +178,7 @@ def _transform_response(self, response):
class IntegerReader(StringReader):
@staticmethod
def condition(spec):
- return spec.get('type', None) == 'integer'
+ return spec.get("type", None) == "integer"
@staticmethod
def validate(input, spec):
@@ -188,12 +191,12 @@ def validate(input, spec):
super(IntegerReader, IntegerReader).validate(input, spec)
def _construct_template(self):
- self.template = u'{0} (integer)'
+ self.template = "{0} (integer)"
- if 'default' in self.spec:
- self.template += u' [{default}]: '.format(default=self.spec.get('default'))
+ if "default" in self.spec:
+ self.template += " [{default}]: ".format(default=self.spec.get("default"))
else:
- self.template += u': '
+ self.template += ": "
def _transform_response(self, response):
return int(response)
@@ -205,71 +208,71 @@ def __init__(self, *args, **kwargs):
@staticmethod
def condition(spec):
- return spec.get('secret', None)
+ return spec.get("secret", None)
def _construct_template(self):
- self.template = u'{0} (secret)'
+ self.template = "{0} (secret)"
- if 'default' in self.spec:
- self.template += u' [{default}]: '.format(default=self.spec.get('default'))
+ if "default" in self.spec:
+ self.template += " [{default}]: ".format(default=self.spec.get("default"))
else:
- self.template += u': '
+ self.template += ": "
class EnumReader(StringReader):
@staticmethod
def condition(spec):
- return spec.get('enum', None)
+ return spec.get("enum", None)
@staticmethod
def validate(input, spec):
- if not input and (not spec.get('required', None) or spec.get('default', None)):
+ if not input and (not spec.get("required", None) or spec.get("default", None)):
return
if not input.isdigit():
- raise validation.ValidationError(len(input), 'Not a number')
+ raise validation.ValidationError(len(input), "Not a number")
- enum = spec.get('enum')
+ enum = spec.get("enum")
try:
enum[int(input)]
except IndexError:
- raise validation.ValidationError(len(input), 'Out of bounds')
+ raise validation.ValidationError(len(input), "Out of bounds")
def _construct_template(self):
- self.template = u'{0}: '
+ self.template = "{0}: "
- enum = self.spec.get('enum')
+ enum = self.spec.get("enum")
for index, value in enumerate(enum):
- self.template += u'\n {} - {}'.format(index, value)
+ self.template += "\n {} - {}".format(index, value)
num_options = len(enum)
- more = ''
+ more = ""
if num_options > 3:
num_options = 3
- more = '...'
+ more = "..."
options = [str(i) for i in range(0, num_options)]
- self.template += u'\nChoose from {}{}'.format(', '.join(options), more)
+ self.template += "\nChoose from {}{}".format(", ".join(options), more)
- if 'default' in self.spec:
- self.template += u' [{}]: '.format(enum.index(self.spec.get('default')))
+ if "default" in self.spec:
+ self.template += " [{}]: ".format(enum.index(self.spec.get("default")))
else:
- self.template += u': '
+ self.template += ": "
def _transform_response(self, response):
- return self.spec.get('enum')[int(response)]
+ return self.spec.get("enum")[int(response)]
class ObjectReader(StringReader):
-
@staticmethod
def condition(spec):
- return spec.get('type', None) == 'object'
+ return spec.get("type", None) == "object"
def read(self):
- prefix = u'{}.'.format(self.name)
+ prefix = "{}.".format(self.name)
- result = InteractiveForm(self.spec.get('properties', {}),
- prefix=prefix, reraise=True).initiate_dialog()
+ result = InteractiveForm(
+ self.spec.get("properties", {}), prefix=prefix, reraise=True
+ ).initiate_dialog()
return result
@@ -277,25 +280,27 @@ def read(self):
class ArrayReader(StringReader):
@staticmethod
def condition(spec):
- return spec.get('type', None) == 'array'
+ return spec.get("type", None) == "array"
@staticmethod
def validate(input, spec):
- if not input and (not spec.get('required', None) or spec.get('default', None)):
+ if not input and (not spec.get("required", None) or spec.get("default", None)):
return
- for m in re.finditer(r'[^, ]+', input):
+ for m in re.finditer(r"[^, ]+", input):
index, item = m.start(), m.group()
try:
- StringReader.validate(item, spec.get('items', {}))
+ StringReader.validate(item, spec.get("items", {}))
except validation.ValidationError as e:
raise validation.ValidationError(index, six.text_type(e))
def read(self):
- item_type = self.spec.get('items', {}).get('type', 'string')
+ item_type = self.spec.get("items", {}).get("type", "string")
- if item_type not in ['string', 'integer', 'number', 'boolean']:
- message = 'Interactive mode does not support arrays of %s type yet' % item_type
+ if item_type not in ["string", "integer", "number", "boolean"]:
+ message = (
+ "Interactive mode does not support arrays of %s type yet" % item_type
+ )
raise ReaderNotImplemented(message)
result = super(ArrayReader, self).read()
@@ -303,37 +308,46 @@ def read(self):
return result
def _construct_template(self):
- self.template = u'{0} (comma-separated list)'
+ self.template = "{0} (comma-separated list)"
- if 'default' in self.spec:
- self.template += u' [{default}]: '.format(default=','.join(self.spec.get('default')))
+ if "default" in self.spec:
+ self.template += " [{default}]: ".format(
+ default=",".join(self.spec.get("default"))
+ )
else:
- self.template += u': '
+ self.template += ": "
def _transform_response(self, response):
- return [item.strip() for item in response.split(',')]
+ return [item.strip() for item in response.split(",")]
class ArrayObjectReader(StringReader):
@staticmethod
def condition(spec):
- return spec.get('type', None) == 'array' and spec.get('items', {}).get('type') == 'object'
+ return (
+ spec.get("type", None) == "array"
+ and spec.get("items", {}).get("type") == "object"
+ )
def read(self):
results = []
- properties = self.spec.get('items', {}).get('properties', {})
- message = '~~~ Would you like to add another item to "%s" array / list?' % self.name
+ properties = self.spec.get("items", {}).get("properties", {})
+ message = (
+ '~~~ Would you like to add another item to "%s" array / list?' % self.name
+ )
is_continue = True
index = 0
while is_continue:
- prefix = u'{name}[{index}].'.format(name=self.name, index=index)
- results.append(InteractiveForm(properties,
- prefix=prefix,
- reraise=True).initiate_dialog())
+ prefix = "{name}[{index}].".format(name=self.name, index=index)
+ results.append(
+ InteractiveForm(
+ properties, prefix=prefix, reraise=True
+ ).initiate_dialog()
+ )
index += 1
- if Question(message, {'default': 'y'}).read() != 'y':
+ if Question(message, {"default": "y"}).read() != "y":
is_continue = False
return results
@@ -341,53 +355,55 @@ def read(self):
class ArrayEnumReader(EnumReader):
def __init__(self, name, spec, prefix=None):
- self.items = spec.get('items', {})
+ self.items = spec.get("items", {})
super(ArrayEnumReader, self).__init__(name, spec, prefix)
@staticmethod
def condition(spec):
- return spec.get('type', None) == 'array' and 'enum' in spec.get('items', {})
+ return spec.get("type", None) == "array" and "enum" in spec.get("items", {})
@staticmethod
def validate(input, spec):
- if not input and (not spec.get('required', None) or spec.get('default', None)):
+ if not input and (not spec.get("required", None) or spec.get("default", None)):
return
- for m in re.finditer(r'[^, ]+', input):
+ for m in re.finditer(r"[^, ]+", input):
index, item = m.start(), m.group()
try:
- EnumReader.validate(item, spec.get('items', {}))
+ EnumReader.validate(item, spec.get("items", {}))
except validation.ValidationError as e:
raise validation.ValidationError(index, six.text_type(e))
def _construct_template(self):
- self.template = u'{0}: '
+ self.template = "{0}: "
- enum = self.items.get('enum')
+ enum = self.items.get("enum")
for index, value in enumerate(enum):
- self.template += u'\n {} - {}'.format(index, value)
+ self.template += "\n {} - {}".format(index, value)
num_options = len(enum)
- more = ''
+ more = ""
if num_options > 3:
num_options = 3
- more = '...'
+ more = "..."
options = [str(i) for i in range(0, num_options)]
- self.template += u'\nChoose from {}{}'.format(', '.join(options), more)
+ self.template += "\nChoose from {}{}".format(", ".join(options), more)
- if 'default' in self.spec:
- default_choises = [str(enum.index(item)) for item in self.spec.get('default')]
- self.template += u' [{}]: '.format(', '.join(default_choises))
+ if "default" in self.spec:
+ default_choices = [
+ str(enum.index(item)) for item in self.spec.get("default")
+ ]
+ self.template += " [{}]: ".format(", ".join(default_choices))
else:
- self.template += u': '
+ self.template += ": "
def _transform_response(self, response):
result = []
- for i in (item.strip() for item in response.split(',')):
+ for i in (item.strip() for item in response.split(",")):
if i:
- result.append(self.items.get('enum')[int(i)])
+ result.append(self.items.get("enum")[int(i)])
return result
@@ -403,7 +419,7 @@ class InteractiveForm(object):
ArrayObjectReader,
ArrayReader,
SecretStringReader,
- StringReader
+ StringReader,
]
def __init__(self, schema, prefix=None, reraise=False):
@@ -419,11 +435,11 @@ def initiate_dialog(self):
try:
result[field] = self._read_field(field)
except ReaderNotImplemented as e:
- print('%s. Skipping...' % six.text_type(e))
+ print("%s. Skipping..." % six.text_type(e))
except DialogInterrupted:
if self.reraise:
raise
- print('Dialog interrupted.')
+ print("Dialog interrupted.")
return result
@@ -438,7 +454,7 @@ def _read_field(self, field):
break
if not reader:
- raise ReaderNotImplemented('No reader for the field spec')
+ raise ReaderNotImplemented("No reader for the field spec")
try:
return reader.read()
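
InteractiveForm._read_field above dispatches on the field spec: it walks the reader list in declaration order and the first class whose condition(spec) returns true wins, which is why the catch-all StringReader sits last. A minimal sketch of that dispatch (the *_demo classes are stand-ins, not st2client classes):

class BooleanReaderDemo:
    @staticmethod
    def condition(spec):
        return spec.get("type") == "boolean"

class StringReaderDemo:
    @staticmethod
    def condition(spec):
        return True  # catch-all, so it must stay last

READERS_DEMO = [BooleanReaderDemo, StringReaderDemo]

def pick_reader_demo(spec):
    # First matching condition wins, mirroring _read_field
    for reader in READERS_DEMO:
        if reader.condition(spec):
            return reader

assert pick_reader_demo({"type": "boolean"}) is BooleanReaderDemo
assert pick_reader_demo({"type": "string"}) is StringReaderDemo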
diff --git a/st2client/st2client/utils/jsutil.py b/st2client/st2client/utils/jsutil.py
index 7aaf20dfe0..1d98ab8f46 100644
--- a/st2client/st2client/utils/jsutil.py
+++ b/st2client/st2client/utils/jsutil.py
@@ -48,7 +48,7 @@ def _get_value_simple(doc, key):
Returns the extracted value from the key specified (if found)
Returns None if the key can not be found
"""
- split_key = key.split('.')
+ split_key = key.split(".")
if not split_key:
return None
@@ -82,8 +82,9 @@ def get_value(doc, key):
raise ValueError("key is None or empty: '{}'".format(key))
if not isinstance(doc, dict):
- raise ValueError("doc is not an instance of dict: type={} value='{}'".format(type(doc),
- doc))
+ raise ValueError(
+ "doc is not an instance of dict: type={} value='{}'".format(type(doc), doc)
+ )
# jsonpath_rw can be very slow when processing expressions.
# In the case of a simple expression we've created a "fast path" that avoids
# the complexity introduced by running jsonpath_rw code.
@@ -113,12 +114,12 @@ def get_kvps(doc, keys):
value = get_value(doc, key)
if value is not None:
nested = new_doc
- while '.' in key:
- attr = key[:key.index('.')]
+ while "." in key:
+ attr = key[: key.index(".")]
if attr not in nested:
nested[attr] = {}
nested = nested[attr]
- key = key[key.index('.') + 1:]
+ key = key[key.index(".") + 1 :]
nested[key] = value
return new_doc
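
The while loop in get_kvps above unfolds a dotted key into nested dicts one segment at a time; a compressed but equivalent sketch (setdefault replaces the explicit "if attr not in nested" guard):

def nest_demo(key, value):
    new_doc = {}
    nested = new_doc
    while "." in key:
        attr = key[: key.index(".")]
        nested = nested.setdefault(attr, {})  # create intermediate dicts lazily
        key = key[key.index(".") + 1 :]
    nested[key] = value
    return new_doc

assert nest_demo("a.b.c", 1) == {"a": {"b": {"c": 1}}}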
diff --git a/st2client/st2client/utils/logging.py b/st2client/st2client/utils/logging.py
index dd8b8b9e44..8328a5c55e 100644
--- a/st2client/st2client/utils/logging.py
+++ b/st2client/st2client/utils/logging.py
@@ -18,9 +18,9 @@
import logging
__all__ = [
- 'LogLevelFilter',
- 'set_log_level_for_all_handlers',
- 'set_log_level_for_all_loggers'
+ "LogLevelFilter",
+ "set_log_level_for_all_handlers",
+ "set_log_level_for_all_loggers",
]
diff --git a/st2client/st2client/utils/misc.py b/st2client/st2client/utils/misc.py
index e8623b3070..62c7b1a61f 100644
--- a/st2client/st2client/utils/misc.py
+++ b/st2client/st2client/utils/misc.py
@@ -18,9 +18,7 @@
import six
-__all__ = [
- 'merge_dicts'
-]
+__all__ = ["merge_dicts"]
def merge_dicts(d1, d2):
diff --git a/st2client/st2client/utils/schema.py b/st2client/st2client/utils/schema.py
index 33142daa71..2cf7d5b231 100644
--- a/st2client/st2client/utils/schema.py
+++ b/st2client/st2client/utils/schema.py
@@ -17,36 +17,30 @@
TYPE_TABLE = {
- dict: 'object',
- list: 'array',
- int: 'integer',
- str: 'string',
- float: 'number',
- bool: 'boolean',
- type(None): 'null',
+ dict: "object",
+ list: "array",
+ int: "integer",
+ str: "string",
+ float: "number",
+ bool: "boolean",
+ type(None): "null",
}
if sys.version_info[0] < 3:
- TYPE_TABLE[unicode] = 'string' # noqa # pylint: disable=E0602
+ TYPE_TABLE[unicode] = "string" # noqa # pylint: disable=E0602
def _dict_to_schema(item):
schema = {}
for key, value in item.iteritems():
if isinstance(value, dict):
- schema[key] = {
- 'type': 'object',
- 'parameters': _dict_to_schema(value)
- }
+ schema[key] = {"type": "object", "parameters": _dict_to_schema(value)}
else:
- schema[key] = {
- 'type': TYPE_TABLE[type(value)]
- }
+ schema[key] = {"type": TYPE_TABLE[type(value)]}
return schema
def render_output_schema_from_output(output):
- """Given an action output produce a reasonable schema to match.
- """
+ """Given an action output produce a reasonable schema to match."""
return _dict_to_schema(output)
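
_dict_to_schema above recurses: nested dicts become "object" entries with their own "parameters", and leaves map their Python type through TYPE_TABLE. A Python 3 sketch (items() in place of the py2 iteritems(), with an abbreviated type table):

TYPE_TABLE_DEMO = {int: "integer", str: "string", bool: "boolean"}

def dict_to_schema_demo(item):
    schema = {}
    for key, value in item.items():
        if isinstance(value, dict):
            # Nested dict -> object with recursively derived parameters
            schema[key] = {"type": "object", "parameters": dict_to_schema_demo(value)}
        else:
            schema[key] = {"type": TYPE_TABLE_DEMO[type(value)]}
    return schema

assert dict_to_schema_demo({"n": 1, "meta": {"ok": True}}) == {
    "n": {"type": "integer"},
    "meta": {"type": "object", "parameters": {"ok": {"type": "boolean"}}},
}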
diff --git a/st2client/st2client/utils/strutil.py b/st2client/st2client/utils/strutil.py
index d6bc23d9cc..0bb970ff3e 100644
--- a/st2client/st2client/utils/strutil.py
+++ b/st2client/st2client/utils/strutil.py
@@ -24,9 +24,9 @@ def unescape(s):
This function unescapes those chars.
"""
if isinstance(s, six.string_types):
- s = s.replace('\\n', '\n')
- s = s.replace('\\r', '\r')
- s = s.replace('\\"', '\"')
+ s = s.replace("\\n", "\n")
+ s = s.replace("\\r", "\r")
+ s = s.replace('\\"', '"')
return s
@@ -39,14 +39,14 @@ def dedupe_newlines(s):
"""
if isinstance(s, six.string_types):
- s = s.replace('\n\n', '\n')
+ s = s.replace("\n\n", "\n")
return s
def strip_carriage_returns(s):
if isinstance(s, six.string_types):
- s = s.replace('\\r', '')
- s = s.replace('\r', '')
+ s = s.replace("\\r", "")
+ s = s.replace("\r", "")
return s
diff --git a/st2client/st2client/utils/terminal.py b/st2client/st2client/utils/terminal.py
index 555753fc95..6ce28a4d74 100644
--- a/st2client/st2client/utils/terminal.py
+++ b/st2client/st2client/utils/terminal.py
@@ -24,11 +24,7 @@
DEFAULT_TERMINAL_SIZE_COLUMNS = 150
-__all__ = [
- 'DEFAULT_TERMINAL_SIZE_COLUMNS',
-
- 'get_terminal_size_columns'
-]
+__all__ = ["DEFAULT_TERMINAL_SIZE_COLUMNS", "get_terminal_size_columns"]
def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS):
@@ -48,7 +44,7 @@ def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS):
# This way it's consistent with upstream implementation. In the past, our implementation
# checked those variables at the end as a fall back.
try:
- columns = os.environ['COLUMNS']
+ columns = os.environ["COLUMNS"]
return int(columns)
except (KeyError, ValueError):
pass
@@ -56,8 +52,9 @@ def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS):
def ioctl_GWINSZ(fd):
import fcntl
import termios
+
# Return a tuple (lines, columns)
- return struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234'))
+ return struct.unpack("hh", fcntl.ioctl(fd, termios.TIOCGWINSZ, "1234"))
# 2. try stdin, stdout, stderr
for fd in (0, 1, 2):
@@ -78,10 +75,12 @@ def ioctl_GWINSZ(fd):
# 4. try `stty size`
try:
- process = subprocess.Popen(['stty', 'size'],
- shell=False,
- stdout=subprocess.PIPE,
- stderr=open(os.devnull, 'w'))
+ process = subprocess.Popen(
+ ["stty", "size"],
+ shell=False,
+ stdout=subprocess.PIPE,
+ stderr=open(os.devnull, "w"),
+ )
result = process.communicate()
if process.returncode == 0:
return tuple(int(x) for x in result[0].split())[1]
@@ -101,23 +100,23 @@ def __exit__(self, type, value, traceback):
return self.close()
def add_stage(self, status, name):
- self._write('\t[{:^20}] {}'.format(format_status(status), name))
+ self._write("\t[{:^20}] {}".format(format_status(status), name))
def update_stage(self, status, name):
- self._write('\t[{:^20}] {}'.format(format_status(status), name), override=True)
+ self._write("\t[{:^20}] {}".format(format_status(status), name), override=True)
def finish_stage(self, status, name):
- self._write('\t[{:^20}] {}'.format(format_status(status), name), override=True)
+ self._write("\t[{:^20}] {}".format(format_status(status), name), override=True)
def close(self):
if self.dirty:
- self._write('\n')
+ self._write("\n")
def _write(self, string, override=False):
if override:
- sys.stdout.write('\r')
+ sys.stdout.write("\r")
else:
- sys.stdout.write('\n')
+ sys.stdout.write("\n")
sys.stdout.write(string)
sys.stdout.flush()
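
get_terminal_size_columns above tries, in order: the COLUMNS environment variable, a TIOCGWINSZ ioctl on fds 0-2, `stty size`, and finally the hard-coded default. A sketch of the first and last steps only (the ioctl and stty fallbacks are elided here):

import os

def terminal_columns_demo(default=150):
    # 1. An explicit COLUMNS env var wins, matching the upstream ordering
    try:
        return int(os.environ["COLUMNS"])
    except (KeyError, ValueError):
        pass
    # 2-4. ioctl(TIOCGWINSZ) / `stty size` fallbacks elided
    return default

print(terminal_columns_demo())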
diff --git a/st2client/st2client/utils/types.py b/st2client/st2client/utils/types.py
index 5c25990a6e..ad70f078b9 100644
--- a/st2client/st2client/utils/types.py
+++ b/st2client/st2client/utils/types.py
@@ -20,17 +20,14 @@
from __future__ import absolute_import
import collections
-__all__ = [
- 'OrderedSet'
-]
+__all__ = ["OrderedSet"]
class OrderedSet(collections.MutableSet):
-
def __init__(self, iterable=None):
self.end = end = []
- end += [None, end, end] # sentinel node for doubly linked list
- self.map = {} # key --> [key, prev, next]
+ end += [None, end, end] # sentinel node for doubly linked list
+ self.map = {} # key --> [key, prev, next]
if iterable is not None:
self |= iterable
@@ -68,15 +65,15 @@ def __reversed__(self):
def pop(self, last=True):
if not self:
- raise KeyError('set is empty')
+ raise KeyError("set is empty")
key = self.end[1][0] if last else self.end[2][0]
self.discard(key)
return key
def __repr__(self):
if not self:
- return '%s()' % (self.__class__.__name__,)
- return '%s(%r)' % (self.__class__.__name__, list(self))
+ return "%s()" % (self.__class__.__name__,)
+ return "%s(%r)" % (self.__class__.__name__, list(self))
def __eq__(self, other):
if isinstance(other, OrderedSet):
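
OrderedSet above stores each element as a [key, prev, next] cell in a circular doubly linked list, with self.end as the sentinel; that makes tail insertion O(1) while iteration preserves insertion order. A stripped-down sketch of the add path:

# Sentinel cell that links to itself: [key, prev, next]
end = [None, None, None]
end[1] = end  # prev
end[2] = end  # next
mapping = {}  # key -> cell, as in self.map

def add_demo(key):
    if key not in mapping:
        curr = end[1]  # current tail
        curr[2] = end[1] = mapping[key] = [key, curr, end]

for k in ("b", "a", "b"):
    add_demo(k)
print(list(mapping))  # ['b', 'a'] -- duplicate ignored, order kept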
diff --git a/st2client/tests/base.py b/st2client/tests/base.py
index 307f00f74b..80c14efef6 100644
--- a/st2client/tests/base.py
+++ b/st2client/tests/base.py
@@ -27,26 +27,22 @@
LOG = logging.getLogger(__name__)
-FAKE_ENDPOINT = 'http://127.0.0.1:8268'
+FAKE_ENDPOINT = "http://127.0.0.1:8268"
RESOURCES = [
{
"id": "123",
"name": "abc",
},
- {
- "id": "456",
- "name": "def"
- }
+ {"id": "456", "name": "def"},
]
class FakeResource(models.Resource):
- _plural = 'FakeResources'
+ _plural = "FakeResources"
class FakeResponse(object):
-
def __init__(self, text, status_code, reason, *args):
self.text = text
self.status_code = status_code
@@ -64,8 +60,7 @@ def raise_for_status(self):
class FakeClient(object):
def __init__(self):
self.managers = {
- 'FakeResource': models.ResourceManager(FakeResource,
- FAKE_ENDPOINT)
+ "FakeResource": models.ResourceManager(FakeResource, FAKE_ENDPOINT)
}
@@ -75,23 +70,32 @@ def __init__(self):
class BaseCLITestCase(unittest2.TestCase):
- capture_output = True # if True, stdout and stderr are saved to self.stdout and self.stderr
+ capture_output = (
+ True # if True, stdout and stderr are saved to self.stdout and self.stderr
+ )
stdout = six.moves.StringIO()
stderr = six.moves.StringIO()
- DEFAULT_SKIP_CONFIG = '1'
+ DEFAULT_SKIP_CONFIG = "1"
def setUp(self):
super(BaseCLITestCase, self).setUp()
# Setup environment
- for var in ['ST2_BASE_URL', 'ST2_AUTH_URL', 'ST2_API_URL', 'ST2_STREAM_URL',
- 'ST2_AUTH_TOKEN', 'ST2_CONFIG_FILE', 'ST2_API_KEY']:
+ for var in [
+ "ST2_BASE_URL",
+ "ST2_AUTH_URL",
+ "ST2_API_URL",
+ "ST2_STREAM_URL",
+ "ST2_AUTH_TOKEN",
+ "ST2_CONFIG_FILE",
+ "ST2_API_KEY",
+ ]:
if var in os.environ:
del os.environ[var]
- os.environ['ST2_CLI_SKIP_CONFIG'] = self.DEFAULT_SKIP_CONFIG
+ os.environ["ST2_CLI_SKIP_CONFIG"] = self.DEFAULT_SKIP_CONFIG
if self.capture_output:
# Make sure we reset it for each test class instance
@@ -134,5 +138,5 @@ def _reset_output_streams(self):
self.stderr.truncate()
# Verify it has been reset correctly
- self.assertEqual(self.stdout.getvalue(), '')
- self.assertEqual(self.stderr.getvalue(), '')
+ self.assertEqual(self.stdout.getvalue(), "")
+ self.assertEqual(self.stderr.getvalue(), "")
diff --git a/st2client/tests/fixtures/loader.py b/st2client/tests/fixtures/loader.py
index a471d8e710..049a82b7a6 100644
--- a/st2client/tests/fixtures/loader.py
+++ b/st2client/tests/fixtures/loader.py
@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except ImportError:
@@ -24,8 +25,8 @@
import yaml
-ALLOWED_EXTS = ['.json', '.yaml', '.yml', '.txt']
-PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load}
+ALLOWED_EXTS = [".json", ".yaml", ".yml", ".txt"]
+PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load}
def get_fixtures_base_path():
@@ -44,12 +45,14 @@ def load_content(file_path):
file_name, file_ext = os.path.splitext(file_path)
if file_ext not in ALLOWED_EXTS:
- raise Exception('Unsupported meta type %s, file %s. Allowed: %s' %
- (file_ext, file_path, ALLOWED_EXTS))
+ raise Exception(
+ "Unsupported meta type %s, file %s. Allowed: %s"
+ % (file_ext, file_path, ALLOWED_EXTS)
+ )
parser_func = PARSER_FUNCS.get(file_ext, None)
- with open(file_path, 'r') as fd:
+ with open(file_path, "r") as fd:
return parser_func(fd) if parser_func else fd.read()
@@ -75,7 +78,7 @@ def load_fixtures(fixtures_dict=None):
for fixture_type, fixtures in six.iteritems(fixtures_dict):
loaded_fixtures = {}
for fixture in fixtures:
- fixture_path = fixtures_base_path + '/' + fixture
+ fixture_path = fixtures_base_path + "/" + fixture
fixture_dict = load_content(fixture_path)
loaded_fixtures[fixture] = fixture_dict
all_fixtures[fixture_type] = loaded_fixtures
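
load_content above dispatches on file extension: .json/.yml/.yaml go through a parser from PARSER_FUNCS, and anything else in ALLOWED_EXTS is returned raw. A small sketch of that dispatch using in-memory streams:

import io
import json

PARSER_FUNCS_DEMO = {".json": json.load}  # abbreviated table

def load_content_demo(file_ext, fd):
    # Known extensions get parsed; everything else is read verbatim
    parser_func = PARSER_FUNCS_DEMO.get(file_ext, None)
    return parser_func(fd) if parser_func else fd.read()

assert load_content_demo(".json", io.StringIO('{"a": 1}')) == {"a": 1}
assert load_content_demo(".txt", io.StringIO("raw")) == "raw"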
diff --git a/st2client/tests/unit/test_action.py b/st2client/tests/unit/test_action.py
index e02c1ea1ca..1bb8be3810 100644
--- a/st2client/tests/unit/test_action.py
+++ b/st2client/tests/unit/test_action.py
@@ -34,7 +34,7 @@
"float": {"type": "number"},
"json": {"type": "object"},
"list": {"type": "array"},
- "str": {"type": "string"}
+ "str": {"type": "string"},
},
"name": "mock-runner1",
}
@@ -46,7 +46,7 @@
"parameters": {},
"enabled": True,
"entry_point": "",
- "pack": "mockety"
+ "pack": "mockety",
}
RUNNER2 = {
@@ -65,475 +65,583 @@
"float": {"type": "number"},
"json": {"type": "object"},
"list": {"type": "array"},
- "str": {"type": "string"}
+ "str": {"type": "string"},
},
"enabled": True,
"entry_point": "",
- "pack": "mockety"
+ "pack": "mockety",
}
LIVE_ACTION = {
- 'action': 'mockety.mock',
- 'status': 'complete',
- 'result': {'stdout': 'non-empty'}
+ "action": "mockety.mock",
+ "status": "complete",
+ "result": {"stdout": "non-empty"},
}
def get_by_name(name, **kwargs):
- if name == 'mock-runner1':
+ if name == "mock-runner1":
return models.RunnerType(**RUNNER1)
- if name == 'mock-runner2':
+ if name == "mock-runner2":
return models.RunnerType(**RUNNER2)
def get_by_ref(**kwargs):
- ref = kwargs.get('ref_or_id', None)
+ ref = kwargs.get("ref_or_id", None)
if not ref:
raise Exception('Actions must be referred to by "ref".')
- if ref == 'mockety.mock1':
+ if ref == "mockety.mock1":
return models.Action(**ACTION1)
- if ref == 'mockety.mock2':
+ if ref == "mockety.mock2":
return models.Action(**ACTION2)
class ActionCommandTestCase(base.BaseCLITestCase):
-
def __init__(self, *args, **kwargs):
super(ActionCommandTestCase, self).__init__(*args, **kwargs)
self.shell = shell.Shell()
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_runner_param_bool_conversion(self):
- self.shell.run(['run', 'mockety.mock1', 'bool=false'])
- expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'bool': False}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock1", "bool=false"])
+ expected = {
+ "action": "mockety.mock1",
+ "user": None,
+ "parameters": {"bool": False},
+ }
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_runner_param_integer_conversion(self):
- self.shell.run(['run', 'mockety.mock1', 'int=30'])
- expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'int': 30}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock1", "int=30"])
+ expected = {"action": "mockety.mock1", "user": None, "parameters": {"int": 30}}
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_runner_param_float_conversion(self):
- self.shell.run(['run', 'mockety.mock1', 'float=3.01'])
- expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'float': 3.01}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock1", "float=3.01"])
+ expected = {
+ "action": "mockety.mock1",
+ "user": None,
+ "parameters": {"float": 3.01},
+ }
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_runner_param_json_conversion(self):
- self.shell.run(['run', 'mockety.mock1', 'json={"a":1}'])
- expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'json': {'a': 1}}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock1", 'json={"a":1}'])
+ expected = {
+ "action": "mockety.mock1",
+ "user": None,
+ "parameters": {"json": {"a": 1}},
+ }
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_runner_param_array_conversion(self):
- self.shell.run(['run', 'mockety.mock1', 'list=one,two,three'])
+ self.shell.run(["run", "mockety.mock1", "list=one,two,three"])
expected = {
- 'action': 'mockety.mock1',
- 'user': None,
- 'parameters': {
- 'list': [
- 'one',
- 'two',
- 'three'
- ]
- }
+ "action": "mockety.mock1",
+ "user": None,
+ "parameters": {"list": ["one", "two", "three"]},
}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_runner_param_array_object_conversion(self):
self.shell.run(
[
- 'run',
- 'mockety.mock1',
- 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]'
+ "run",
+ "mockety.mock1",
+ 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]',
]
)
expected = {
- 'action': 'mockety.mock1',
- 'user': None,
- 'parameters': {
- 'list': [
- {
- 'foo': 1,
- 'ponies': 'rainbows'
- },
- {
- 'pluto': False,
- 'earth': True
- }
+ "action": "mockety.mock1",
+ "user": None,
+ "parameters": {
+ "list": [
+ {"foo": 1, "ponies": "rainbows"},
+ {"pluto": False, "earth": True},
]
- }
+ },
}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_bool_conversion(self):
- self.shell.run(['run', 'mockety.mock2', 'bool=false'])
- expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'bool': False}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock2", "bool=false"])
+ expected = {
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"bool": False},
+ }
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_integer_conversion(self):
- self.shell.run(['run', 'mockety.mock2', 'int=30'])
- expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'int': 30}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock2", "int=30"])
+ expected = {"action": "mockety.mock2", "user": None, "parameters": {"int": 30}}
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_float_conversion(self):
- self.shell.run(['run', 'mockety.mock2', 'float=3.01'])
- expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'float': 3.01}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock2", "float=3.01"])
+ expected = {
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"float": 3.01},
+ }
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_json_conversion(self):
- self.shell.run(['run', 'mockety.mock2', 'json={"a":1}'])
- expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'json': {'a': 1}}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock2", 'json={"a":1}'])
+ expected = {
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"json": {"a": 1}},
+ }
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_array_conversion(self):
- self.shell.run(['run', 'mockety.mock2', 'list=one,two,three'])
+ self.shell.run(["run", "mockety.mock2", "list=one,two,three"])
expected = {
- 'action': 'mockety.mock2',
- 'user': None,
- 'parameters': {
- 'list': [
- 'one',
- 'two',
- 'three'
- ]
- }
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"list": ["one", "two", "three"]},
}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_array_conversion_single_element_str(self):
- self.shell.run(['run', 'mockety.mock2', 'list=one'])
+ self.shell.run(["run", "mockety.mock2", "list=one"])
expected = {
- 'action': 'mockety.mock2',
- 'user': None,
- 'parameters': {
- 'list': [
- 'one'
- ]
- }
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"list": ["one"]},
}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_array_conversion_single_element_int(self):
- self.shell.run(['run', 'mockety.mock2', 'list=1'])
+ self.shell.run(["run", "mockety.mock2", "list=1"])
expected = {
- 'action': 'mockety.mock2',
- 'user': None,
- 'parameters': {
- 'list': [
- 1
- ]
- }
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"list": [1]},
}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_array_object_conversion(self):
self.shell.run(
[
- 'run',
- 'mockety.mock2',
- 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]'
+ "run",
+ "mockety.mock2",
+ 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]',
]
)
expected = {
- 'action': 'mockety.mock2',
- 'user': None,
- 'parameters': {
- 'list': [
- {
- 'foo': 1,
- 'ponies': 'rainbows'
- },
- {
- 'pluto': False,
- 'earth': True
- }
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {
+ "list": [
+ {"foo": 1, "ponies": "rainbows"},
+ {"pluto": False, "earth": True},
]
- }
+ },
}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_dict_conversion_flag(self):
- """Ensure that the automatic conversion to dict based on colons only occurs with the flag
- """
+ """Ensure that the automatic conversion to dict based on colons only occurs with the flag"""
self.shell.run(
- [
- 'run',
- 'mockety.mock2',
- 'list=key1:value1,key2:value2',
- '--auto-dict'
- ]
+ ["run", "mockety.mock2", "list=key1:value1,key2:value2", "--auto-dict"]
)
expected = {
- 'action': 'mockety.mock2',
- 'user': None,
- 'parameters': {
- 'list': [
- {
- 'key1': 'value1',
- 'key2': 'value2'
- }
- ]
- }
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"list": [{"key1": "value1", "key2": "value2"}]},
}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
- self.shell.run(
- [
- 'run',
- 'mockety.mock2',
- 'list=key1:value1,key2:value2'
- ]
- )
+ self.shell.run(["run", "mockety.mock2", "list=key1:value1,key2:value2"])
expected = {
- 'action': 'mockety.mock2',
- 'user': None,
- 'parameters': {
- 'list': [
- 'key1:value1',
- 'key2:value2'
- ]
- }
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"list": ["key1:value1", "key2:value2"]},
}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_param_value_with_equal_sign(self):
- self.shell.run(['run', 'mockety.mock2', 'key=foo=bar&ponies=unicorns'])
- expected = {'action': 'mockety.mock2', 'user': None,
- 'parameters': {'key': 'foo=bar&ponies=unicorns'}}
- httpclient.HTTPClient.post.assert_called_with('/executions', expected)
+ self.shell.run(["run", "mockety.mock2", "key=foo=bar&ponies=unicorns"])
+ expected = {
+ "action": "mockety.mock2",
+ "user": None,
+ "parameters": {"key": "foo=bar&ponies=unicorns"},
+ }
+ httpclient.HTTPClient.post.assert_called_with("/executions", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'delete',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "delete",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_cancel_single_execution(self):
- self.shell.run(['execution', 'cancel', '123'])
- httpclient.HTTPClient.delete.assert_called_with('/executions/123')
+ self.shell.run(["execution", "cancel", "123"])
+ httpclient.HTTPClient.delete.assert_called_with("/executions/123")
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'delete',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "delete",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_cancel_multiple_executions(self):
- self.shell.run(['execution', 'cancel', '123', '456', '789'])
- calls = [mock.call('/executions/123'),
- mock.call('/executions/456'),
- mock.call('/executions/789')]
+ self.shell.run(["execution", "cancel", "123", "456", "789"])
+ calls = [
+ mock.call("/executions/123"),
+ mock.call("/executions/456"),
+ mock.call("/executions/789"),
+ ]
httpclient.HTTPClient.delete.assert_has_calls(calls)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_pause_single_execution(self):
- self.shell.run(['execution', 'pause', '123'])
- expected = {'status': 'pausing'}
- httpclient.HTTPClient.put.assert_called_with('/executions/123', expected)
+ self.shell.run(["execution", "pause", "123"])
+ expected = {"status": "pausing"}
+ httpclient.HTTPClient.put.assert_called_with("/executions/123", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_pause_multiple_executions(self):
- self.shell.run(['execution', 'pause', '123', '456', '789'])
- expected = {'status': 'pausing'}
- calls = [mock.call('/executions/123', expected),
- mock.call('/executions/456', expected),
- mock.call('/executions/789', expected)]
+ self.shell.run(["execution", "pause", "123", "456", "789"])
+ expected = {"status": "pausing"}
+ calls = [
+ mock.call("/executions/123", expected),
+ mock.call("/executions/456", expected),
+ mock.call("/executions/789", expected),
+ ]
httpclient.HTTPClient.put.assert_has_calls(calls)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_resume_single_execution(self):
- self.shell.run(['execution', 'resume', '123'])
- expected = {'status': 'resuming'}
- httpclient.HTTPClient.put.assert_called_with('/executions/123', expected)
+ self.shell.run(["execution", "resume", "123"])
+ expected = {"status": "resuming"}
+ httpclient.HTTPClient.put.assert_called_with("/executions/123", expected)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(side_effect=get_by_name))
+ models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK")
+ ),
+ )
def test_resume_multiple_executions(self):
- self.shell.run(['execution', 'resume', '123', '456', '789'])
- expected = {'status': 'resuming'}
- calls = [mock.call('/executions/123', expected),
- mock.call('/executions/456', expected),
- mock.call('/executions/789', expected)]
+ self.shell.run(["execution", "resume", "123", "456", "789"])
+ expected = {"status": "resuming"}
+ calls = [
+ mock.call("/executions/123", expected),
+ mock.call("/executions/456", expected),
+ mock.call("/executions/789", expected),
+ ]
httpclient.HTTPClient.put.assert_has_calls(calls)
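
The cancel/pause/resume tests above all follow one pattern: patch an `httpclient.HTTPClient` method at the class level with a `MagicMock` that returns a canned `FakeResponse`, drive the CLI, then assert on the recorded calls. A minimal, self-contained sketch of that pattern — the `FakeResponse` and `HTTPClient` classes below are hypothetical stand-ins for the real helpers in the st2client test `base` and `httpclient` modules:

import json
import unittest
from unittest import mock


class FakeResponse(object):
    # Hypothetical stand-in for the FakeResponse helper in the test base module.
    def __init__(self, text, status_code, reason):
        self.text = text
        self.status_code = status_code
        self.reason = reason


class HTTPClient(object):
    # Stand-in for st2client.utils.httpclient.HTTPClient; only the shape matters.
    def put(self, url, data):
        raise RuntimeError("patched out in tests")


class PauseCommandTest(unittest.TestCase):
    @mock.patch.object(
        HTTPClient,
        "put",
        mock.MagicMock(
            return_value=FakeResponse(json.dumps({"status": "pausing"}), 200, "OK")
        ),
    )
    def test_pause_multiple(self):
        client = HTTPClient()
        for exec_id in ["123", "456"]:
            client.put("/executions/%s" % exec_id, {"status": "pausing"})

        # assert_has_calls checks for an ordered subsequence of recorded calls,
        # which is why the multi-execution tests build an explicit calls list.
        calls = [
            mock.call("/executions/123", {"status": "pausing"}),
            mock.call("/executions/456", {"status": "pausing"}),
        ]
        HTTPClient.put.assert_has_calls(calls)


if __name__ == "__main__":
    unittest.main()

Because the class attribute is replaced with a plain `MagicMock` (not a function), instance access does not bind `self`, so the recorded calls contain only the URL and payload — exactly what the assertions above expect.
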
diff --git a/st2client/tests/unit/test_action_alias.py b/st2client/tests/unit/test_action_alias.py
index a360fd5139..753b4e71a8 100644
--- a/st2client/tests/unit/test_action_alias.py
+++ b/st2client/tests/unit/test_action_alias.py
@@ -29,9 +29,7 @@
"execution": {
"id": "mock-id",
},
- "actionalias": {
- "ref": "mock-ref"
- }
+ "actionalias": {"ref": "mock-ref"},
}
]
}
@@ -43,20 +41,26 @@ def __init__(self, *args, **kwargs):
self.shell = shell.Shell()
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_MATCH_AND_EXECUTE_RESULT),
- 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_MATCH_AND_EXECUTE_RESULT), 200, "OK"
+ )
+ ),
+ )
def test_match_and_execute(self):
- ret = self.shell.run(['action-alias', 'execute', "run whoami on localhost"])
+ ret = self.shell.run(["action-alias", "execute", "run whoami on localhost"])
self.assertEqual(ret, 0)
expected_args = {
- 'command': 'run whoami on localhost',
- 'user': '',
- 'source_channel': 'cli'
+ "command": "run whoami on localhost",
+ "user": "",
+ "source_channel": "cli",
}
- httpclient.HTTPClient.post.assert_called_with('/aliasexecution/match_and_execute',
- expected_args)
+ httpclient.HTTPClient.post.assert_called_with(
+ "/aliasexecution/match_and_execute", expected_args
+ )
mock_stdout = self.stdout.getvalue()
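
`test_match_and_execute` also checks the CLI exit code and the captured output; `self.stdout` is a buffer swapped in by the base test case. A rough sketch of the same stdout-capture idea using only the standard library — `run_cli` is a hypothetical entry point standing in for `shell.Shell().run`:

import io
import unittest
from contextlib import redirect_stdout


def run_cli(argv):
    # Hypothetical entry point; prints a result line and returns an exit code.
    print("Matched alias for: %s" % argv[-1])
    return 0


class MatchAndExecuteTest(unittest.TestCase):
    def test_exit_code_and_output(self):
        buf = io.StringIO()
        with redirect_stdout(buf):
            ret = run_cli(["action-alias", "execute", "run whoami on localhost"])
        self.assertEqual(ret, 0)
        self.assertIn("run whoami on localhost", buf.getvalue())
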
diff --git a/st2client/tests/unit/test_app.py b/st2client/tests/unit/test_app.py
index eb1a67242e..217d3875ad 100644
--- a/st2client/tests/unit/test_app.py
+++ b/st2client/tests/unit/test_app.py
@@ -26,33 +26,33 @@
class BaseCLIAppTestCase(unittest2.TestCase):
- @mock.patch('os.path.isfile', mock.Mock())
+ @mock.patch("os.path.isfile", mock.Mock())
def test_cli_config_file_path(self):
app = BaseCLIApp()
args = mock.Mock()
# 1. Absolute path
- args.config_file = '/tmp/full/abs/path/config.ini'
+ args.config_file = "/tmp/full/abs/path/config.ini"
result = app._get_config_file_path(args=args)
self.assertEqual(result, args.config_file)
- args.config_file = '/home/user/st2/config.ini'
+ args.config_file = "/home/user/st2/config.ini"
result = app._get_config_file_path(args=args)
self.assertEqual(result, args.config_file)
# 2. Path relative to user home directory, should get expanded
- args.config_file = '~/.st2/config.ini'
+ args.config_file = "~/.st2/config.ini"
result = app._get_config_file_path(args=args)
- expected = os.path.join(os.path.expanduser('~' + USER), '.st2/config.ini')
+ expected = os.path.join(os.path.expanduser("~" + USER), ".st2/config.ini")
self.assertEqual(result, expected)
# 3. Relative path (should get converted to absolute one)
- args.config_file = 'config.ini'
+ args.config_file = "config.ini"
result = app._get_config_file_path(args=args)
- expected = os.path.join(os.getcwd(), 'config.ini')
+ expected = os.path.join(os.getcwd(), "config.ini")
self.assertEqual(result, expected)
- args.config_file = '.st2/config.ini'
+ args.config_file = ".st2/config.ini"
result = app._get_config_file_path(args=args)
- expected = os.path.join(os.getcwd(), '.st2/config.ini')
+ expected = os.path.join(os.getcwd(), ".st2/config.ini")
self.assertEqual(result, expected)
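
The assertions above pin down three resolution rules for `--config`: absolute paths pass through, `~` is expanded, and relative paths are anchored at the current working directory. A sketch of logic consistent with those assertions (the real `_get_config_file_path` may differ in detail):

import os


def get_config_file_path(config_file):
    # Absolute paths pass through, "~" is expanded, everything else is
    # anchored at the current working directory.
    path = os.path.expanduser(config_file)
    if not os.path.isabs(path):
        path = os.path.join(os.getcwd(), path)
    return path


assert get_config_file_path("/tmp/full/abs/path/config.ini") == "/tmp/full/abs/path/config.ini"
assert get_config_file_path("config.ini") == os.path.join(os.getcwd(), "config.ini")
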
diff --git a/st2client/tests/unit/test_auth.py b/st2client/tests/unit/test_auth.py
index cd838712cb..e59b31dfaf 100644
--- a/st2client/tests/unit/test_auth.py
+++ b/st2client/tests/unit/test_auth.py
@@ -29,24 +29,27 @@
from st2client import shell
from st2client.models.core import add_auth_token_to_kwargs_from_env
from st2client.commands.resource import add_auth_token_to_kwargs_from_cli
-from st2client.utils.httpclient import add_auth_token_to_headers, add_json_content_type_to_headers
+from st2client.utils.httpclient import (
+ add_auth_token_to_headers,
+ add_json_content_type_to_headers,
+)
LOG = logging.getLogger(__name__)
if six.PY3:
RULE = {
- 'name': 'drule',
- 'description': 'i am THE rule.',
- 'pack': 'cli',
- 'id': uuid.uuid4().hex
+ "name": "drule",
+ "description": "i am THE rule.",
+ "pack": "cli",
+ "id": uuid.uuid4().hex,
}
else:
RULE = {
- 'id': uuid.uuid4().hex,
- 'description': 'i am THE rule.',
- 'name': 'drule',
- 'pack': 'cli',
+ "id": uuid.uuid4().hex,
+ "description": "i am THE rule.",
+ "name": "drule",
+ "pack": "cli",
}
@@ -59,9 +62,9 @@ class TestLoginBase(base.BaseCLITestCase):
on duplicate code in each test class
"""
- DOTST2_PATH = os.path.expanduser('~/.st2/')
- CONFIG_FILE_NAME = 'st2.conf'
- PARENT_DIR = 'testconfig'
+ DOTST2_PATH = os.path.expanduser("~/.st2/")
+ CONFIG_FILE_NAME = "st2.conf"
+ PARENT_DIR = "testconfig"
TMP_DIR = tempfile.mkdtemp()
CONFIG_CONTENTS = """
[credentials]
@@ -73,11 +76,11 @@ def __init__(self, *args, **kwargs):
super(TestLoginBase, self).__init__(*args, **kwargs)
# We're overriding the default behavior for CLI test cases here
- self.DEFAULT_SKIP_CONFIG = '0'
+ self.DEFAULT_SKIP_CONFIG = "0"
self.parser = argparse.ArgumentParser()
- self.parser.add_argument('-t', '--token', dest='token')
- self.parser.add_argument('--api-key', dest='api_key')
+ self.parser.add_argument("-t", "--token", dest="token")
+ self.parser.add_argument("--api-key", dest="api_key")
self.shell = shell.Shell()
self.CONFIG_DIR = os.path.join(self.TMP_DIR, self.PARENT_DIR)
@@ -94,9 +97,9 @@ def setUp(self):
if os.path.isfile(self.CONFIG_FILE):
os.remove(self.CONFIG_FILE)
- with open(self.CONFIG_FILE, 'w') as cfg:
- for line in self.CONFIG_CONTENTS.split('\n'):
- cfg.write('%s\n' % line.strip())
+ with open(self.CONFIG_FILE, "w") as cfg:
+ for line in self.CONFIG_CONTENTS.split("\n"):
+ cfg.write("%s\n" % line.strip())
os.chmod(self.CONFIG_FILE, 0o660)
@@ -107,7 +110,7 @@ def tearDown(self):
os.remove(self.CONFIG_FILE)
# Clean up tokens
- for file in [f for f in os.listdir(self.DOTST2_PATH) if 'token-' in f]:
+ for file in [f for f in os.listdir(self.DOTST2_PATH) if "token-" in f]:
os.remove(self.DOTST2_PATH + file)
# Clean up config directory
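
`TestLoginBase` builds a throwaway config file under a temp directory in `setUp` and removes it again in `tearDown`. A condensed, POSIX-only sketch of that fixture pattern (the permission assertion assumes a Unix-style `chmod`):

import os
import tempfile
import unittest


class TempConfigTestCase(unittest.TestCase):
    CONFIG_CONTENTS = "[credentials]\nusername = olduser\n"

    def setUp(self):
        self.config_dir = tempfile.mkdtemp()
        self.config_file = os.path.join(self.config_dir, "st2.conf")
        with open(self.config_file, "w") as cfg:
            cfg.write(self.CONFIG_CONTENTS)
        os.chmod(self.config_file, 0o660)

    def tearDown(self):
        if os.path.isfile(self.config_file):
            os.remove(self.config_file)
        os.rmdir(self.config_dir)

    def test_file_permissions(self):
        # Mirrors the 0o660 permission check TestLoginIntPwdAndConfig performs.
        self.assertEqual(os.stat(self.config_file).st_mode & 0o777, 0o660)
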
@@ -116,181 +119,208 @@ def tearDown(self):
class TestLoginPasswordAndConfig(TestLoginBase):
- CONFIG_FILE_NAME = 'logintest.cfg'
+ CONFIG_FILE_NAME = "logintest.cfg"
TOKEN = {
- 'user': 'st2admin',
- 'token': '44583f15945b4095afbf57058535ca64',
- 'expiry': '2017-02-12T00:53:09.632783Z',
- 'id': '589e607532ed3535707f10eb',
- 'metadata': {}
+ "user": "st2admin",
+ "token": "44583f15945b4095afbf57058535ca64",
+ "expiry": "2017-02-12T00:53:09.632783Z",
+ "id": "589e607532ed3535707f10eb",
+ "metadata": {},
}
@mock.patch.object(
- requests, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK')))
+ requests,
+ "post",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")),
+ )
def runTest(self):
- '''Test 'st2 login' functionality by specifying a password and a configuration file
- '''
-
- expected_username = self.TOKEN['user']
- args = ['--config', self.CONFIG_FILE, 'login', expected_username, '--password',
- 'Password1!']
+ """Test 'st2 login' functionality by specifying a password and a configuration file"""
+
+ expected_username = self.TOKEN["user"]
+ args = [
+ "--config",
+ self.CONFIG_FILE,
+ "login",
+ expected_username,
+ "--password",
+ "Password1!",
+ ]
self.shell.run(args)
- with open(self.CONFIG_FILE, 'r') as config_file:
+ with open(self.CONFIG_FILE, "r") as config_file:
for line in config_file.readlines():
# Make sure certain values are not present
- self.assertNotIn('password', line)
- self.assertNotIn('olduser', line)
+ self.assertNotIn("password", line)
+ self.assertNotIn("olduser", line)
# Make sure configured username is what we expect
- if 'username' in line:
- self.assertEqual(line.split(' ')[2][:-1], expected_username)
+ if "username" in line:
+ self.assertEqual(line.split(" ")[2][:-1], expected_username)
# validate token was created
- self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username)))
+ self.assertTrue(
+ os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username))
+ )
class TestLoginIntPwdAndConfig(TestLoginBase):
- CONFIG_FILE_NAME = 'logintest.cfg'
+ CONFIG_FILE_NAME = "logintest.cfg"
TOKEN = {
- 'user': 'st2admin',
- 'token': '44583f15945b4095afbf57058535ca64',
- 'expiry': '2017-02-12T00:53:09.632783Z',
- 'id': '589e607532ed3535707f10eb',
- 'metadata': {}
+ "user": "st2admin",
+ "token": "44583f15945b4095afbf57058535ca64",
+ "expiry": "2017-02-12T00:53:09.632783Z",
+ "id": "589e607532ed3535707f10eb",
+ "metadata": {},
}
@mock.patch.object(
- requests, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK')))
+ requests,
+ "post",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")),
+ )
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK')))
- @mock.patch('st2client.commands.auth.getpass')
+ requests,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")),
+ )
+ @mock.patch("st2client.commands.auth.getpass")
def runTest(self, mock_gp):
- '''Test 'st2 login' functionality with interactive password entry
- '''
+ """Test 'st2 login' functionality with interactive password entry"""
- expected_username = self.TOKEN['user']
- args = ['--config', self.CONFIG_FILE, 'login', expected_username]
+ expected_username = self.TOKEN["user"]
+ args = ["--config", self.CONFIG_FILE, "login", expected_username]
- mock_gp.getpass.return_value = 'Password1!'
+ mock_gp.getpass.return_value = "Password1!"
self.shell.run(args)
expected_kwargs = {
- 'headers': {'content-type': 'application/json'},
- 'auth': ('st2admin', 'Password1!')
+ "headers": {"content-type": "application/json"},
+ "auth": ("st2admin", "Password1!"),
}
- requests.post.assert_called_with('http://127.0.0.1:9100/tokens', '{}', **expected_kwargs)
+ requests.post.assert_called_with(
+ "http://127.0.0.1:9100/tokens", "{}", **expected_kwargs
+ )
# Check file permissions
self.assertEqual(os.stat(self.CONFIG_FILE).st_mode & 0o777, 0o660)
- with open(self.CONFIG_FILE, 'r') as config_file:
+ with open(self.CONFIG_FILE, "r") as config_file:
for line in config_file.readlines():
# Make sure certain values are not present
- self.assertNotIn('password', line)
- self.assertNotIn('olduser', line)
+ self.assertNotIn("password", line)
+ self.assertNotIn("olduser", line)
# Make sure configured username is what we expect
- if 'username' in line:
- self.assertEqual(line.split(' ')[2][:-1], expected_username)
+ if "username" in line:
+ self.assertEqual(line.split(" ")[2][:-1], expected_username)
# validate token was created
- self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username)))
+ self.assertTrue(
+ os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username))
+ )
# Validate token is sent on subsequent requests to st2 API
- args = ['--config', self.CONFIG_FILE, 'pack', 'list']
+ args = ["--config", self.CONFIG_FILE, "pack", "list"]
self.shell.run(args)
expected_kwargs = {
- 'headers': {
- 'X-Auth-Token': self.TOKEN['token']
- },
- 'params': {
- 'include_attributes': 'ref,name,description,version,author'
- }
+ "headers": {"X-Auth-Token": self.TOKEN["token"]},
+ "params": {"include_attributes": "ref,name,description,version,author"},
}
- requests.get.assert_called_with('http://127.0.0.1:9101/v1/packs', **expected_kwargs)
+ requests.get.assert_called_with(
+ "http://127.0.0.1:9101/v1/packs", **expected_kwargs
+ )
class TestLoginWritePwdOkay(TestLoginBase):
- CONFIG_FILE_NAME = 'logintest.cfg'
+ CONFIG_FILE_NAME = "logintest.cfg"
TOKEN = {
- 'user': 'st2admin',
- 'token': '44583f15945b4095afbf57058535ca64',
- 'expiry': '2017-02-12T00:53:09.632783Z',
- 'id': '589e607532ed3535707f10eb',
- 'metadata': {}
+ "user": "st2admin",
+ "token": "44583f15945b4095afbf57058535ca64",
+ "expiry": "2017-02-12T00:53:09.632783Z",
+ "id": "589e607532ed3535707f10eb",
+ "metadata": {},
}
@mock.patch.object(
- requests, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK')))
- @mock.patch('st2client.commands.auth.getpass')
+ requests,
+ "post",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")),
+ )
+ @mock.patch("st2client.commands.auth.getpass")
def runTest(self, mock_gp):
- '''Test 'st2 login' functionality with --write-password flag set
- '''
-
- expected_username = self.TOKEN['user']
- args = ['--config', self.CONFIG_FILE, 'login', expected_username, '--password',
- 'Password1!', '--write-password']
+ """Test 'st2 login' functionality with --write-password flag set"""
+
+ expected_username = self.TOKEN["user"]
+ args = [
+ "--config",
+ self.CONFIG_FILE,
+ "login",
+ expected_username,
+ "--password",
+ "Password1!",
+ "--write-password",
+ ]
self.shell.run(args)
- with open(self.CONFIG_FILE, 'r') as config_file:
+ with open(self.CONFIG_FILE, "r") as config_file:
for line in config_file.readlines():
# Make sure certain values are not present
- self.assertNotIn('olduser', line)
+ self.assertNotIn("olduser", line)
# Make sure configured username is what we expect
- if 'username' in line:
- self.assertEqual(line.split(' ')[2][:-1], expected_username)
+ if "username" in line:
+ self.assertEqual(line.split(" ")[2][:-1], expected_username)
# validate token was created
- self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username)))
+ self.assertTrue(
+ os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username))
+ )
class TestLoginUncaughtException(TestLoginBase):
- CONFIG_FILE_NAME = 'logintest.cfg'
+ CONFIG_FILE_NAME = "logintest.cfg"
TOKEN = {
- 'user': 'st2admin',
- 'token': '44583f15945b4095afbf57058535ca64',
- 'expiry': '2017-02-12T00:53:09.632783Z',
- 'id': '589e607532ed3535707f10eb',
- 'metadata': {}
+ "user": "st2admin",
+ "token": "44583f15945b4095afbf57058535ca64",
+ "expiry": "2017-02-12T00:53:09.632783Z",
+ "id": "589e607532ed3535707f10eb",
+ "metadata": {},
}
@mock.patch.object(
- requests, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK')))
- @mock.patch('st2client.commands.auth.getpass')
+ requests,
+ "post",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")),
+ )
+ @mock.patch("st2client.commands.auth.getpass")
def runTest(self, mock_gp):
- '''Test 'st2 login' ability to detect unhandled exceptions
- '''
+ """Test 'st2 login' ability to detect unhandled exceptions"""
- expected_username = self.TOKEN['user']
- args = ['--config', self.CONFIG_FILE, 'login', expected_username]
+ expected_username = self.TOKEN["user"]
+ args = ["--config", self.CONFIG_FILE, "login", expected_username]
mock_gp.getpass = mock.MagicMock(side_effect=Exception)
self.shell.run(args)
retcode = self.shell.run(args)
- self.assertIn('Failed to log in as %s' % expected_username, self.stdout.getvalue())
- self.assertNotIn('Logged in as', self.stdout.getvalue())
+ self.assertIn(
+ "Failed to log in as %s" % expected_username, self.stdout.getvalue()
+ )
+ self.assertNotIn("Logged in as", self.stdout.getvalue())
self.assertEqual(retcode, 1)
@@ -301,26 +331,26 @@ class TestAuthToken(base.BaseCLITestCase):
def __init__(self, *args, **kwargs):
super(TestAuthToken, self).__init__(*args, **kwargs)
self.parser = argparse.ArgumentParser()
- self.parser.add_argument('-t', '--token', dest='token')
- self.parser.add_argument('--api-key', dest='api_key')
+ self.parser.add_argument("-t", "--token", dest="token")
+ self.parser.add_argument("--api-key", dest="api_key")
self.shell = shell.Shell()
def setUp(self):
super(TestAuthToken, self).setUp()
# Setup environment.
- os.environ['ST2_BASE_URL'] = 'http://127.0.0.1'
+ os.environ["ST2_BASE_URL"] = "http://127.0.0.1"
def tearDown(self):
super(TestAuthToken, self).tearDown()
# Clean up environment.
- if 'ST2_AUTH_TOKEN' in os.environ:
- del os.environ['ST2_AUTH_TOKEN']
- if 'ST2_API_KEY' in os.environ:
- del os.environ['ST2_API_KEY']
- if 'ST2_BASE_URL' in os.environ:
- del os.environ['ST2_BASE_URL']
+ if "ST2_AUTH_TOKEN" in os.environ:
+ del os.environ["ST2_AUTH_TOKEN"]
+ if "ST2_API_KEY" in os.environ:
+ del os.environ["ST2_API_KEY"]
+ if "ST2_BASE_URL" in os.environ:
+ del os.environ["ST2_BASE_URL"]
@add_auth_token_to_kwargs_from_cli
@add_auth_token_to_kwargs_from_env
@@ -329,27 +359,27 @@ def _mock_run(self, args, **kwargs):
def test_decorate_auth_token_by_cli(self):
token = uuid.uuid4().hex
- args = self.parser.parse_args(args=['-t', token])
- self.assertDictEqual(self._mock_run(args), {'token': token})
- args = self.parser.parse_args(args=['--token', token])
- self.assertDictEqual(self._mock_run(args), {'token': token})
+ args = self.parser.parse_args(args=["-t", token])
+ self.assertDictEqual(self._mock_run(args), {"token": token})
+ args = self.parser.parse_args(args=["--token", token])
+ self.assertDictEqual(self._mock_run(args), {"token": token})
def test_decorate_api_key_by_cli(self):
token = uuid.uuid4().hex
- args = self.parser.parse_args(args=['--api-key', token])
- self.assertDictEqual(self._mock_run(args), {'api_key': token})
+ args = self.parser.parse_args(args=["--api-key", token])
+ self.assertDictEqual(self._mock_run(args), {"api_key": token})
def test_decorate_auth_token_by_env(self):
token = uuid.uuid4().hex
- os.environ['ST2_AUTH_TOKEN'] = token
+ os.environ["ST2_AUTH_TOKEN"] = token
args = self.parser.parse_args(args=[])
- self.assertDictEqual(self._mock_run(args), {'token': token})
+ self.assertDictEqual(self._mock_run(args), {"token": token})
def test_decorate_api_key_by_env(self):
token = uuid.uuid4().hex
- os.environ['ST2_API_KEY'] = token
+ os.environ["ST2_API_KEY"] = token
args = self.parser.parse_args(args=[])
- self.assertDictEqual(self._mock_run(args), {'api_key': token})
+ self.assertDictEqual(self._mock_run(args), {"api_key": token})
def test_decorate_without_auth_token(self):
args = self.parser.parse_args(args=[])
@@ -362,187 +392,215 @@ def _mock_http(self, url, **kwargs):
def test_decorate_auth_token_to_http_headers(self):
token = uuid.uuid4().hex
- kwargs = self._mock_http('/', token=token)
- expected = {'content-type': 'application/json', 'X-Auth-Token': token}
- self.assertIn('headers', kwargs)
- self.assertDictEqual(kwargs['headers'], expected)
+ kwargs = self._mock_http("/", token=token)
+ expected = {"content-type": "application/json", "X-Auth-Token": token}
+ self.assertIn("headers", kwargs)
+ self.assertDictEqual(kwargs["headers"], expected)
def test_decorate_api_key_to_http_headers(self):
token = uuid.uuid4().hex
- kwargs = self._mock_http('/', api_key=token)
- expected = {'content-type': 'application/json', 'St2-Api-Key': token}
- self.assertIn('headers', kwargs)
- self.assertDictEqual(kwargs['headers'], expected)
+ kwargs = self._mock_http("/", api_key=token)
+ expected = {"content-type": "application/json", "St2-Api-Key": token}
+ self.assertIn("headers", kwargs)
+ self.assertDictEqual(kwargs["headers"], expected)
def test_decorate_without_auth_token_to_http_headers(self):
- kwargs = self._mock_http('/', auth=('stanley', 'stanley'))
- expected = {'content-type': 'application/json'}
- self.assertIn('auth', kwargs)
- self.assertEqual(kwargs['auth'], ('stanley', 'stanley'))
- self.assertIn('headers', kwargs)
- self.assertDictEqual(kwargs['headers'], expected)
+ kwargs = self._mock_http("/", auth=("stanley", "stanley"))
+ expected = {"content-type": "application/json"}
+ self.assertIn("auth", kwargs)
+ self.assertEqual(kwargs["auth"], ("stanley", "stanley"))
+ self.assertIn("headers", kwargs)
+ self.assertDictEqual(kwargs["headers"], expected)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK')))
+ requests,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")),
+ )
def test_decorate_resource_list(self):
- url = ('http://127.0.0.1:9101/v1/rules/'
- '?include_attributes=ref,pack,description,enabled&limit=50')
- url = url.replace(',', '%2C')
+ url = (
+ "http://127.0.0.1:9101/v1/rules/"
+ "?include_attributes=ref,pack,description,enabled&limit=50"
+ )
+ url = url.replace(",", "%2C")
# Test without token.
- self.shell.run(['rule', 'list'])
+ self.shell.run(["rule", "list"])
kwargs = {}
requests.get.assert_called_with(url, **kwargs)
# Test with token from cli.
token = uuid.uuid4().hex
- self.shell.run(['rule', 'list', '-t', token])
- kwargs = {'headers': {'X-Auth-Token': token}}
+ self.shell.run(["rule", "list", "-t", token])
+ kwargs = {"headers": {"X-Auth-Token": token}}
requests.get.assert_called_with(url, **kwargs)
# Test with token from env.
token = uuid.uuid4().hex
- os.environ['ST2_AUTH_TOKEN'] = token
- self.shell.run(['rule', 'list'])
- kwargs = {'headers': {'X-Auth-Token': token}}
+ os.environ["ST2_AUTH_TOKEN"] = token
+ self.shell.run(["rule", "list"])
+ kwargs = {"headers": {"X-Auth-Token": token}}
requests.get.assert_called_with(url, **kwargs)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK')))
+ requests,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")),
+ )
def test_decorate_resource_get(self):
- rule_ref = '%s.%s' % (RULE['pack'], RULE['name'])
- url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref
+ rule_ref = "%s.%s" % (RULE["pack"], RULE["name"])
+ url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref
# Test without token.
- self.shell.run(['rule', 'get', rule_ref])
+ self.shell.run(["rule", "get", rule_ref])
kwargs = {}
requests.get.assert_called_with(url, **kwargs)
# Test with token from cli.
token = uuid.uuid4().hex
- self.shell.run(['rule', 'get', rule_ref, '-t', token])
- kwargs = {'headers': {'X-Auth-Token': token}}
+ self.shell.run(["rule", "get", rule_ref, "-t", token])
+ kwargs = {"headers": {"X-Auth-Token": token}}
requests.get.assert_called_with(url, **kwargs)
# Test with token from env.
token = uuid.uuid4().hex
- os.environ['ST2_AUTH_TOKEN'] = token
- self.shell.run(['rule', 'get', rule_ref])
- kwargs = {'headers': {'X-Auth-Token': token}}
+ os.environ["ST2_AUTH_TOKEN"] = token
+ self.shell.run(["rule", "get", rule_ref])
+ kwargs = {"headers": {"X-Auth-Token": token}}
requests.get.assert_called_with(url, **kwargs)
@mock.patch.object(
- requests, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK')))
+ requests,
+ "post",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")),
+ )
def test_decorate_resource_post(self):
- url = 'http://127.0.0.1:9101/v1/rules'
- data = {'name': RULE['name'], 'description': RULE['description']}
+ url = "http://127.0.0.1:9101/v1/rules"
+ data = {"name": RULE["name"], "description": RULE["description"]}
- fd, path = tempfile.mkstemp(suffix='.json')
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(data, indent=4))
# Test without token.
- self.shell.run(['rule', 'create', path])
- kwargs = {'headers': {'content-type': 'application/json'}}
+ self.shell.run(["rule", "create", path])
+ kwargs = {"headers": {"content-type": "application/json"}}
requests.post.assert_called_with(url, json.dumps(data), **kwargs)
# Test with token from cli.
token = uuid.uuid4().hex
- self.shell.run(['rule', 'create', path, '-t', token])
- kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}}
+ self.shell.run(["rule", "create", path, "-t", token])
+ kwargs = {
+ "headers": {"content-type": "application/json", "X-Auth-Token": token}
+ }
requests.post.assert_called_with(url, json.dumps(data), **kwargs)
# Test with token from env.
token = uuid.uuid4().hex
- os.environ['ST2_AUTH_TOKEN'] = token
- self.shell.run(['rule', 'create', path])
- kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}}
+ os.environ["ST2_AUTH_TOKEN"] = token
+ self.shell.run(["rule", "create", path])
+ kwargs = {
+ "headers": {"content-type": "application/json", "X-Auth-Token": token}
+ }
requests.post.assert_called_with(url, json.dumps(data), **kwargs)
finally:
os.close(fd)
os.unlink(path)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK')))
+ requests,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")),
+ )
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")),
+ )
def test_decorate_resource_put(self):
- rule_ref = '%s.%s' % (RULE['pack'], RULE['name'])
-
- get_url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref
- put_url = 'http://127.0.0.1:9101/v1/rules/%s' % RULE['id']
- data = {'name': RULE['name'], 'description': RULE['description'], 'pack': RULE['pack']}
+ rule_ref = "%s.%s" % (RULE["pack"], RULE["name"])
+
+ get_url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref
+ put_url = "http://127.0.0.1:9101/v1/rules/%s" % RULE["id"]
+ data = {
+ "name": RULE["name"],
+ "description": RULE["description"],
+ "pack": RULE["pack"],
+ }
- fd, path = tempfile.mkstemp(suffix='.json')
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(data, indent=4))
# Test without token.
- self.shell.run(['rule', 'update', rule_ref, path])
+ self.shell.run(["rule", "update", rule_ref, path])
kwargs = {}
requests.get.assert_called_with(get_url, **kwargs)
- kwargs = {'headers': {'content-type': 'application/json'}}
+ kwargs = {"headers": {"content-type": "application/json"}}
requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs)
# Test with token from cli.
token = uuid.uuid4().hex
- self.shell.run(['rule', 'update', rule_ref, path, '-t', token])
- kwargs = {'headers': {'X-Auth-Token': token}}
+ self.shell.run(["rule", "update", rule_ref, path, "-t", token])
+ kwargs = {"headers": {"X-Auth-Token": token}}
requests.get.assert_called_with(get_url, **kwargs)
- kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}}
+ kwargs = {
+ "headers": {"content-type": "application/json", "X-Auth-Token": token}
+ }
requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs)
# Test with token from env.
token = uuid.uuid4().hex
- os.environ['ST2_AUTH_TOKEN'] = token
- self.shell.run(['rule', 'update', rule_ref, path])
- kwargs = {'headers': {'X-Auth-Token': token}}
+ os.environ["ST2_AUTH_TOKEN"] = token
+ self.shell.run(["rule", "update", rule_ref, path])
+ kwargs = {"headers": {"X-Auth-Token": token}}
requests.get.assert_called_with(get_url, **kwargs)
# Note: We parse the payload because data might not be in the same
# order as the fixture
- kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}}
+ kwargs = {
+ "headers": {"content-type": "application/json", "X-Auth-Token": token}
+ }
requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs)
finally:
os.close(fd)
os.unlink(path)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK')))
+ requests,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")),
+ )
@mock.patch.object(
- requests, 'delete',
- mock.MagicMock(return_value=base.FakeResponse('', 204, 'OK')))
+ requests,
+ "delete",
+ mock.MagicMock(return_value=base.FakeResponse("", 204, "OK")),
+ )
def test_decorate_resource_delete(self):
- rule_ref = '%s.%s' % (RULE['pack'], RULE['name'])
- get_url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref
- del_url = 'http://127.0.0.1:9101/v1/rules/%s' % RULE['id']
+ rule_ref = "%s.%s" % (RULE["pack"], RULE["name"])
+ get_url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref
+ del_url = "http://127.0.0.1:9101/v1/rules/%s" % RULE["id"]
# Test without token.
- self.shell.run(['rule', 'delete', rule_ref])
+ self.shell.run(["rule", "delete", rule_ref])
kwargs = {}
requests.get.assert_called_with(get_url, **kwargs)
requests.delete.assert_called_with(del_url, **kwargs)
# Test with token from cli.
token = uuid.uuid4().hex
- self.shell.run(['rule', 'delete', rule_ref, '-t', token])
- kwargs = {'headers': {'X-Auth-Token': token}}
+ self.shell.run(["rule", "delete", rule_ref, "-t", token])
+ kwargs = {"headers": {"X-Auth-Token": token}}
requests.get.assert_called_with(get_url, **kwargs)
requests.delete.assert_called_with(del_url, **kwargs)
# Test with token from env.
token = uuid.uuid4().hex
- os.environ['ST2_AUTH_TOKEN'] = token
- self.shell.run(['rule', 'delete', rule_ref])
- kwargs = {'headers': {'X-Auth-Token': token}}
+ os.environ["ST2_AUTH_TOKEN"] = token
+ self.shell.run(["rule", "delete", rule_ref])
+ kwargs = {"headers": {"X-Auth-Token": token}}
requests.get.assert_called_with(get_url, **kwargs)
requests.delete.assert_called_with(del_url, **kwargs)
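
All of the CRUD tests above exercise the same fallback chain: an explicit `-t`/`--api-key` argument wins, otherwise the `ST2_AUTH_TOKEN`/`ST2_API_KEY` environment variables are consulted. A sketch of a decorator with that behavior, inferred from the assertions rather than copied from `st2client.models.core`:

import functools
import os


def add_auth_token_to_kwargs_from_env(func):
    # Behavior inferred from the tests: explicit kwargs win, otherwise
    # fall back to the ST2_* environment variables.
    @functools.wraps(func)
    def decorate(*args, **kwargs):
        if not kwargs.get("token") and os.environ.get("ST2_AUTH_TOKEN"):
            kwargs["token"] = os.environ["ST2_AUTH_TOKEN"]
        if not kwargs.get("api_key") and os.environ.get("ST2_API_KEY"):
            kwargs["api_key"] = os.environ["ST2_API_KEY"]
        return func(*args, **kwargs)

    return decorate


@add_auth_token_to_kwargs_from_env
def mock_run(args, **kwargs):
    return kwargs


os.environ["ST2_AUTH_TOKEN"] = "abc123"
assert mock_run(None) == {"token": "abc123"}
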
diff --git a/st2client/tests/unit/test_client.py b/st2client/tests/unit/test_client.py
index 2e9fd95095..2d0a380ab5 100644
--- a/st2client/tests/unit/test_client.py
+++ b/st2client/tests/unit/test_client.py
@@ -25,25 +25,25 @@
LOG = logging.getLogger(__name__)
-NONRESOURCES = ['workflows']
+NONRESOURCES = ["workflows"]
class TestClientEndpoints(unittest2.TestCase):
-
def tearDown(self):
for var in [
- 'ST2_BASE_URL',
- 'ST2_API_URL',
- 'ST2_STREAM_URL',
- 'ST2_DATASTORE_URL',
- 'ST2_AUTH_TOKEN'
+ "ST2_BASE_URL",
+ "ST2_API_URL",
+ "ST2_STREAM_URL",
+ "ST2_DATASTORE_URL",
+ "ST2_AUTH_TOKEN",
]:
if var in os.environ:
del os.environ[var]
def test_managers(self):
- property_names = [k for k, v in six.iteritems(Client.__dict__)
- if isinstance(v, property)]
+ property_names = [
+ k for k, v in six.iteritems(Client.__dict__) if isinstance(v, property)
+ ]
client = Client()
@@ -55,96 +55,109 @@ def test_managers(self):
self.assertIsInstance(manager, models.ResourceManager)
def test_default(self):
- base_url = 'http://127.0.0.1'
- api_url = 'http://127.0.0.1:9101/v1'
- stream_url = 'http://127.0.0.1:9102/v1'
+ base_url = "http://127.0.0.1"
+ api_url = "http://127.0.0.1:9101/v1"
+ stream_url = "http://127.0.0.1:9102/v1"
client = Client()
endpoints = client.endpoints
- self.assertEqual(endpoints['base'], base_url)
- self.assertEqual(endpoints['api'], api_url)
- self.assertEqual(endpoints['stream'], stream_url)
+ self.assertEqual(endpoints["base"], base_url)
+ self.assertEqual(endpoints["api"], api_url)
+ self.assertEqual(endpoints["stream"], stream_url)
def test_env(self):
- base_url = 'http://www.stackstorm.com'
- api_url = 'http://www.st2.com:9101/v1'
- stream_url = 'http://www.st2.com:9102/v1'
+ base_url = "http://www.stackstorm.com"
+ api_url = "http://www.st2.com:9101/v1"
+ stream_url = "http://www.st2.com:9102/v1"
- os.environ['ST2_BASE_URL'] = base_url
- os.environ['ST2_API_URL'] = api_url
- os.environ['ST2_STREAM_URL'] = stream_url
- self.assertEqual(os.environ.get('ST2_BASE_URL'), base_url)
- self.assertEqual(os.environ.get('ST2_API_URL'), api_url)
- self.assertEqual(os.environ.get('ST2_STREAM_URL'), stream_url)
+ os.environ["ST2_BASE_URL"] = base_url
+ os.environ["ST2_API_URL"] = api_url
+ os.environ["ST2_STREAM_URL"] = stream_url
+ self.assertEqual(os.environ.get("ST2_BASE_URL"), base_url)
+ self.assertEqual(os.environ.get("ST2_API_URL"), api_url)
+ self.assertEqual(os.environ.get("ST2_STREAM_URL"), stream_url)
client = Client()
endpoints = client.endpoints
- self.assertEqual(endpoints['base'], base_url)
- self.assertEqual(endpoints['api'], api_url)
- self.assertEqual(endpoints['stream'], stream_url)
+ self.assertEqual(endpoints["base"], base_url)
+ self.assertEqual(endpoints["api"], api_url)
+ self.assertEqual(endpoints["stream"], stream_url)
def test_env_base_only(self):
- base_url = 'http://www.stackstorm.com'
- api_url = 'http://www.stackstorm.com:9101/v1'
- stream_url = 'http://www.stackstorm.com:9102/v1'
+ base_url = "http://www.stackstorm.com"
+ api_url = "http://www.stackstorm.com:9101/v1"
+ stream_url = "http://www.stackstorm.com:9102/v1"
- os.environ['ST2_BASE_URL'] = base_url
- self.assertEqual(os.environ.get('ST2_BASE_URL'), base_url)
- self.assertEqual(os.environ.get('ST2_API_URL'), None)
- self.assertEqual(os.environ.get('ST2_STREAM_URL'), None)
+ os.environ["ST2_BASE_URL"] = base_url
+ self.assertEqual(os.environ.get("ST2_BASE_URL"), base_url)
+ self.assertEqual(os.environ.get("ST2_API_URL"), None)
+ self.assertEqual(os.environ.get("ST2_STREAM_URL"), None)
client = Client()
endpoints = client.endpoints
- self.assertEqual(endpoints['base'], base_url)
- self.assertEqual(endpoints['api'], api_url)
- self.assertEqual(endpoints['stream'], stream_url)
+ self.assertEqual(endpoints["base"], base_url)
+ self.assertEqual(endpoints["api"], api_url)
+ self.assertEqual(endpoints["stream"], stream_url)
def test_args(self):
- base_url = 'http://www.stackstorm.com'
- api_url = 'http://www.st2.com:9101/v1'
- stream_url = 'http://www.st2.com:9102/v1'
+ base_url = "http://www.stackstorm.com"
+ api_url = "http://www.st2.com:9101/v1"
+ stream_url = "http://www.st2.com:9102/v1"
client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url)
endpoints = client.endpoints
- self.assertEqual(endpoints['base'], base_url)
- self.assertEqual(endpoints['api'], api_url)
- self.assertEqual(endpoints['stream'], stream_url)
+ self.assertEqual(endpoints["base"], base_url)
+ self.assertEqual(endpoints["api"], api_url)
+ self.assertEqual(endpoints["stream"], stream_url)
def test_cacert_arg(self):
# Valid value, boolean True
- base_url = 'http://www.stackstorm.com'
- api_url = 'http://www.st2.com:9101/v1'
- stream_url = 'http://www.st2.com:9102/v1'
+ base_url = "http://www.stackstorm.com"
+ api_url = "http://www.st2.com:9101/v1"
+ stream_url = "http://www.st2.com:9102/v1"
- client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=True)
+ client = Client(
+ base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=True
+ )
self.assertEqual(client.cacert, True)
# Valid value, boolean False
- base_url = 'http://www.stackstorm.com'
- api_url = 'http://www.st2.com:9101/v1'
- stream_url = 'http://www.st2.com:9102/v1'
+ base_url = "http://www.stackstorm.com"
+ api_url = "http://www.st2.com:9101/v1"
+ stream_url = "http://www.st2.com:9102/v1"
- client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=False)
+ client = Client(
+ base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=False
+ )
self.assertEqual(client.cacert, False)
# Valid value, existing path to a CA bundle
cacert = os.path.abspath(__file__)
- client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=cacert)
+ client = Client(
+ base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=cacert
+ )
self.assertEqual(client.cacert, cacert)
# Invalid value, path to the bundle doesn't exist
cacert = os.path.abspath(__file__)
expected_msg = 'CA cert file "doesntexist" does not exist'
- self.assertRaisesRegexp(ValueError, expected_msg, Client, base_url=base_url,
- api_url=api_url, stream_url=stream_url, cacert='doesntexist')
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ Client,
+ base_url=base_url,
+ api_url=api_url,
+ stream_url=stream_url,
+ cacert="doesntexist",
+ )
def test_args_base_only(self):
- base_url = 'http://www.stackstorm.com'
- api_url = 'http://www.stackstorm.com:9101/v1'
- stream_url = 'http://www.stackstorm.com:9102/v1'
+ base_url = "http://www.stackstorm.com"
+ api_url = "http://www.stackstorm.com:9101/v1"
+ stream_url = "http://www.stackstorm.com:9102/v1"
client = Client(base_url=base_url)
endpoints = client.endpoints
- self.assertEqual(endpoints['base'], base_url)
- self.assertEqual(endpoints['api'], api_url)
- self.assertEqual(endpoints['stream'], stream_url)
+ self.assertEqual(endpoints["base"], base_url)
+ self.assertEqual(endpoints["api"], api_url)
+ self.assertEqual(endpoints["stream"], stream_url)
diff --git a/st2client/tests/unit/test_client_actions.py b/st2client/tests/unit/test_client_actions.py
index 82b12b788d..141e7c8ece 100644
--- a/st2client/tests/unit/test_client_actions.py
+++ b/st2client/tests/unit/test_client_actions.py
@@ -31,22 +31,17 @@
EXECUTION = {
"id": 12345,
- "action": {
- "ref": "mock.foobar"
- },
+ "action": {"ref": "mock.foobar"},
"status": "failed",
- "result": "non-empty"
+ "result": "non-empty",
}
ENTRYPOINT = (
"version: 1.0"
-
"description: A basic workflow that runs an arbitrary linux command."
-
"input:"
" - cmd"
" - timeout"
-
"tasks:"
" task1:"
" action: core.local cmd=<% ctx(cmd) %> timeout=<% ctx(timeout) %>"
@@ -55,51 +50,63 @@
" publish:"
" - stdout: <% result().stdout %>"
" - stderr: <% result().stderr %>"
-
"output:"
" - stdout: <% ctx(stdout) %>"
)
class TestActionResourceManager(unittest2.TestCase):
-
@classmethod
def setUpClass(cls):
super(TestActionResourceManager, cls).setUpClass()
cls.client = client.Client()
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, "OK")
+ ),
+ )
def test_get_action_entry_point_by_ref(self):
- actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION['action']['ref'])
+ actual_entrypoint = self.client.actions.get_entrypoint(
+ EXECUTION["action"]["ref"]
+ )
actual_entrypoint = json.loads(actual_entrypoint)
- endpoint = '/actions/views/entry_point/%s' % EXECUTION['action']['ref']
+ endpoint = "/actions/views/entry_point/%s" % EXECUTION["action"]["ref"]
httpclient.HTTPClient.get.assert_called_with(endpoint)
self.assertEqual(ENTRYPOINT, actual_entrypoint)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, "OK")
+ ),
+ )
def test_get_action_entry_point_by_id(self):
- actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION['id'])
+ actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION["id"])
actual_entrypoint = json.loads(actual_entrypoint)
- endpoint = '/actions/views/entry_point/%s' % EXECUTION['id']
+ endpoint = "/actions/views/entry_point/%s" % EXECUTION["id"]
httpclient.HTTPClient.get.assert_called_with(endpoint)
self.assertEqual(ENTRYPOINT, actual_entrypoint)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(
- json.dumps({}), 404, '404 Client Error: Not Found'
- )))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps({}), 404, "404 Client Error: Not Found"
+ )
+ ),
+ )
def test_get_non_existent_action_entry_point(self):
- with self.assertRaisesRegexp(Exception, '404 Client Error: Not Found'):
- self.client.actions.get_entrypoint('nonexistentpack.nonexistentaction')
+ with self.assertRaisesRegexp(Exception, "404 Client Error: Not Found"):
+ self.client.actions.get_entrypoint("nonexistentpack.nonexistentaction")
- endpoint = '/actions/views/entry_point/%s' % 'nonexistentpack.nonexistentaction'
+ endpoint = "/actions/views/entry_point/%s" % "nonexistentpack.nonexistentaction"
httpclient.HTTPClient.get.assert_called_with(endpoint)
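
`test_get_non_existent_action_entry_point` covers the error path: a mocked `get` returning a 404 response must surface as an exception matching `assertRaisesRegex`, while the request URL is still asserted afterwards. A self-contained sketch of that shape; the `HTTPClient` below is hypothetical:

import unittest
from unittest import mock


class HTTPClient(object):
    # Hypothetical minimal client: get() is patched out below; non-2xx
    # responses surface as exceptions carrying the reason string.
    def get(self, url):
        raise NotImplementedError

    def get_entrypoint(self, ref):
        response = self.get("/actions/views/entry_point/%s" % ref)
        if response.status_code != 200:
            raise Exception(response.reason)
        return response.text


class NotFoundTest(unittest.TestCase):
    @mock.patch.object(
        HTTPClient,
        "get",
        mock.MagicMock(
            return_value=mock.Mock(
                status_code=404, reason="404 Client Error: Not Found"
            )
        ),
    )
    def test_missing_entry_point(self):
        client = HTTPClient()
        with self.assertRaisesRegex(Exception, "404 Client Error: Not Found"):
            client.get_entrypoint("nonexistentpack.nonexistentaction")
        HTTPClient.get.assert_called_with(
            "/actions/views/entry_point/nonexistentpack.nonexistentaction"
        )
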
diff --git a/st2client/tests/unit/test_client_executions.py b/st2client/tests/unit/test_client_executions.py
index 0470347ee2..a9dc19e2c3 100644
--- a/st2client/tests/unit/test_client_executions.py
+++ b/st2client/tests/unit/test_client_executions.py
@@ -34,9 +34,7 @@
RUNNER = {
"enabled": True,
"name": "marathon",
- "runner_parameters": {
- "var1": {"type": "string"}
- }
+ "runner_parameters": {"var1": {"type": "string"}},
}
ACTION = {
@@ -46,185 +44,227 @@
"parameters": {},
"enabled": True,
"entry_point": "",
- "pack": "mocke"
+ "pack": "mocke",
}
EXECUTION = {
"id": 12345,
- "action": {
- "ref": "mock.foobar"
- },
+ "action": {"ref": "mock.foobar"},
"status": "failed",
- "result": "non-empty"
+ "result": "non-empty",
}
class TestExecutionResourceManager(unittest2.TestCase):
-
@classmethod
def setUpClass(cls):
super(TestExecutionResourceManager, cls).setUpClass()
cls.client = client.Client()
@mock.patch.object(
- models.ResourceManager, 'get_by_id',
- mock.MagicMock(return_value=models.Execution(**EXECUTION)))
+ models.ResourceManager,
+ "get_by_id",
+ mock.MagicMock(return_value=models.Execution(**EXECUTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(return_value=models.Action(**ACTION)))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(return_value=models.Action(**ACTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(return_value=models.RunnerType(**RUNNER)))
+ models.ResourceManager,
+ "get_by_name",
+ mock.MagicMock(return_value=models.RunnerType(**RUNNER)),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK")
+ ),
+ )
def test_rerun_with_no_params(self):
- self.client.executions.re_run(EXECUTION['id'], tasks=['foobar'])
+ self.client.executions.re_run(EXECUTION["id"], tasks=["foobar"])
- endpoint = '/executions/%s/re_run' % EXECUTION['id']
+ endpoint = "/executions/%s/re_run" % EXECUTION["id"]
- data = {
- 'tasks': ['foobar'],
- 'reset': ['foobar'],
- 'parameters': {},
- 'delay': 0
- }
+ data = {"tasks": ["foobar"], "reset": ["foobar"], "parameters": {}, "delay": 0}
httpclient.HTTPClient.post.assert_called_with(endpoint, data)
@mock.patch.object(
- models.ResourceManager, 'get_by_id',
- mock.MagicMock(return_value=models.Execution(**EXECUTION)))
+ models.ResourceManager,
+ "get_by_id",
+ mock.MagicMock(return_value=models.Execution(**EXECUTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(return_value=models.Action(**ACTION)))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(return_value=models.Action(**ACTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(return_value=models.RunnerType(**RUNNER)))
+ models.ResourceManager,
+ "get_by_name",
+ mock.MagicMock(return_value=models.RunnerType(**RUNNER)),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK")
+ ),
+ )
def test_rerun_with_params(self):
- params = {
- 'var1': 'testing...'
- }
+ params = {"var1": "testing..."}
self.client.executions.re_run(
- EXECUTION['id'],
- tasks=['foobar'],
- parameters=params
+ EXECUTION["id"], tasks=["foobar"], parameters=params
)
- endpoint = '/executions/%s/re_run' % EXECUTION['id']
+ endpoint = "/executions/%s/re_run" % EXECUTION["id"]
data = {
- 'tasks': ['foobar'],
- 'reset': ['foobar'],
- 'parameters': params,
- 'delay': 0
+ "tasks": ["foobar"],
+ "reset": ["foobar"],
+ "parameters": params,
+ "delay": 0,
}
httpclient.HTTPClient.post.assert_called_with(endpoint, data)
@mock.patch.object(
- models.ResourceManager, 'get_by_id',
- mock.MagicMock(return_value=models.Execution(**EXECUTION)))
+ models.ResourceManager,
+ "get_by_id",
+ mock.MagicMock(return_value=models.Execution(**EXECUTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(return_value=models.Action(**ACTION)))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(return_value=models.Action(**ACTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(return_value=models.RunnerType(**RUNNER)))
+ models.ResourceManager,
+ "get_by_name",
+ mock.MagicMock(return_value=models.RunnerType(**RUNNER)),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK")
+ ),
+ )
def test_rerun_with_delay(self):
- self.client.executions.re_run(EXECUTION['id'], tasks=['foobar'], delay=100)
+ self.client.executions.re_run(EXECUTION["id"], tasks=["foobar"], delay=100)
- endpoint = '/executions/%s/re_run' % EXECUTION['id']
+ endpoint = "/executions/%s/re_run" % EXECUTION["id"]
data = {
- 'tasks': ['foobar'],
- 'reset': ['foobar'],
- 'parameters': {},
- 'delay': 100
+ "tasks": ["foobar"],
+ "reset": ["foobar"],
+ "parameters": {},
+ "delay": 100,
}
httpclient.HTTPClient.post.assert_called_with(endpoint, data)
@mock.patch.object(
- models.ResourceManager, 'get_by_id',
- mock.MagicMock(return_value=models.Execution(**EXECUTION)))
+ models.ResourceManager,
+ "get_by_id",
+ mock.MagicMock(return_value=models.Execution(**EXECUTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(return_value=models.Action(**ACTION)))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(return_value=models.Action(**ACTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(return_value=models.RunnerType(**RUNNER)))
+ models.ResourceManager,
+ "get_by_name",
+ mock.MagicMock(return_value=models.RunnerType(**RUNNER)),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK")
+ ),
+ )
def test_pause(self):
- self.client.executions.pause(EXECUTION['id'])
+ self.client.executions.pause(EXECUTION["id"])
- endpoint = '/executions/%s' % EXECUTION['id']
+ endpoint = "/executions/%s" % EXECUTION["id"]
- data = {
- 'status': 'pausing'
- }
+ data = {"status": "pausing"}
httpclient.HTTPClient.put.assert_called_with(endpoint, data)
@mock.patch.object(
- models.ResourceManager, 'get_by_id',
- mock.MagicMock(return_value=models.Execution(**EXECUTION)))
+ models.ResourceManager,
+ "get_by_id",
+ mock.MagicMock(return_value=models.Execution(**EXECUTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(return_value=models.Action(**ACTION)))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(return_value=models.Action(**ACTION)),
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(return_value=models.RunnerType(**RUNNER)))
+ models.ResourceManager,
+ "get_by_name",
+ mock.MagicMock(return_value=models.RunnerType(**RUNNER)),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK")
+ ),
+ )
def test_resume(self):
- self.client.executions.resume(EXECUTION['id'])
+ self.client.executions.resume(EXECUTION["id"])
- endpoint = '/executions/%s' % EXECUTION['id']
+ endpoint = "/executions/%s" % EXECUTION["id"]
- data = {
- 'status': 'resuming'
- }
+ data = {"status": "resuming"}
httpclient.HTTPClient.put.assert_called_with(endpoint, data)
@mock.patch.object(
- models.core.Resource, 'get_url_path_name',
- mock.MagicMock(return_value='executions'))
+ models.core.Resource,
+ "get_url_path_name",
+ mock.MagicMock(return_value="executions"),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, "OK")
+ ),
+ )
def test_get_children(self):
- self.client.executions.get_children(EXECUTION['id'])
+ self.client.executions.get_children(EXECUTION["id"])
- endpoint = '/executions/%s/children' % EXECUTION['id']
+ endpoint = "/executions/%s/children" % EXECUTION["id"]
- data = {
- 'depth': -1
- }
+ data = {"depth": -1}
httpclient.HTTPClient.get.assert_called_with(url=endpoint, params=data)
@mock.patch.object(
- models.ResourceManager, 'get_all',
- mock.MagicMock(return_value=[models.Execution(**EXECUTION)]))
- @mock.patch.object(warnings, 'warn')
- def test_st2client_liveactions_has_been_deprecated_and_emits_warning(self, mock_warn):
+ models.ResourceManager,
+ "get_all",
+ mock.MagicMock(return_value=[models.Execution(**EXECUTION)]),
+ )
+ @mock.patch.object(warnings, "warn")
+ def test_st2client_liveactions_has_been_deprecated_and_emits_warning(
+ self, mock_warn
+ ):
self.assertEqual(mock_warn.call_args, None)
self.client.liveactions.get_all()
- expected_msg = 'st2client.liveactions has been renamed'
+ expected_msg = "st2client.liveactions has been renamed"
self.assertTrue(len(mock_warn.call_args_list) >= 1)
self.assertIn(expected_msg, mock_warn.call_args_list[0][0][0])
self.assertEqual(mock_warn.call_args_list[0][0][1], DeprecationWarning)
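
The deprecation test patches `warnings.warn` itself so it can inspect both the message and the warning category positionally. A minimal sketch of the same technique against a hypothetical deprecated alias:

import unittest
import warnings
from unittest import mock


class Client(object):
    @property
    def liveactions(self):
        # Hypothetical deprecated alias; warn, then delegate.
        warnings.warn(
            "st2client.liveactions has been renamed to st2client.executions",
            DeprecationWarning,
        )
        return self


class DeprecationTest(unittest.TestCase):
    @mock.patch.object(warnings, "warn")
    def test_emits_warning(self, mock_warn):
        _ = Client().liveactions
        # call_args_list[0][0] is the positional args tuple: (message, category).
        message, category = mock_warn.call_args_list[0][0]
        self.assertIn("st2client.liveactions has been renamed", message)
        self.assertEqual(category, DeprecationWarning)
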
diff --git a/st2client/tests/unit/test_command_actionrun.py b/st2client/tests/unit/test_command_actionrun.py
index 763ac649a6..1e312e0786 100644
--- a/st2client/tests/unit/test_command_actionrun.py
+++ b/st2client/tests/unit/test_command_actionrun.py
@@ -21,73 +21,79 @@
import mock
from st2client.commands.action import ActionRunCommand
-from st2client.models.action import (Action, RunnerType)
+from st2client.models.action import Action, RunnerType
class ActionRunCommandTest(unittest2.TestCase):
-
def test_get_params_types(self):
runner = RunnerType()
runner_params = {
- 'foo': {'immutable': True, 'required': True},
- 'bar': {'description': 'Some param.', 'type': 'string'}
+ "foo": {"immutable": True, "required": True},
+ "bar": {"description": "Some param.", "type": "string"},
}
runner.runner_parameters = runner_params
orig_runner_params = copy.deepcopy(runner.runner_parameters)
action = Action()
action.parameters = {
- 'foo': {'immutable': False}, # Should not be allowed by API.
- 'stuff': {'description': 'Some param.', 'type': 'string', 'required': True}
+ "foo": {"immutable": False}, # Should not be allowed by API.
+ "stuff": {"description": "Some param.", "type": "string", "required": True},
}
orig_action_params = copy.deepcopy(action.parameters)
params, rqd, opt, imm = ActionRunCommand._get_params_types(runner, action)
self.assertEqual(len(list(params.keys())), 3)
- self.assertIn('foo', imm, '"foo" param should be in immutable set.')
- self.assertNotIn('foo', rqd, '"foo" param should not be in required set.')
- self.assertNotIn('foo', opt, '"foo" param should not be in optional set.')
+ self.assertIn("foo", imm, '"foo" param should be in immutable set.')
+ self.assertNotIn("foo", rqd, '"foo" param should not be in required set.')
+ self.assertNotIn("foo", opt, '"foo" param should not be in optional set.')
- self.assertIn('bar', opt, '"bar" param should be in optional set.')
- self.assertNotIn('bar', rqd, '"bar" param should not be in required set.')
- self.assertNotIn('bar', imm, '"bar" param should not be in immutable set.')
+ self.assertIn("bar", opt, '"bar" param should be in optional set.')
+ self.assertNotIn("bar", rqd, '"bar" param should not be in required set.')
+ self.assertNotIn("bar", imm, '"bar" param should not be in immutable set.')
- self.assertIn('stuff', rqd, '"stuff" param should be in required set.')
- self.assertNotIn('stuff', opt, '"stuff" param should not be in optional set.')
- self.assertNotIn('stuff', imm, '"stuff" param should not be in immutable set.')
- self.assertEqual(runner.runner_parameters, orig_runner_params, 'Runner params modified.')
- self.assertEqual(action.parameters, orig_action_params, 'Action params modified.')
+ self.assertIn("stuff", rqd, '"stuff" param should be in required set.')
+ self.assertNotIn("stuff", opt, '"stuff" param should not be in optional set.')
+ self.assertNotIn("stuff", imm, '"stuff" param should not be in immutable set.')
+ self.assertEqual(
+ runner.runner_parameters, orig_runner_params, "Runner params modified."
+ )
+ self.assertEqual(
+ action.parameters, orig_action_params, "Action params modified."
+ )
def test_opt_in_dict_auto_convert(self):
- """Test ability for user to opt-in to dict convert functionality
- """
+ """Test ability for user to opt-in to dict convert functionality"""
runner = RunnerType()
runner.runner_parameters = {}
action = Action()
- action.ref = 'test.action'
+ action.ref = "test.action"
action.parameters = {
- 'param_array': {'type': 'array'},
+ "param_array": {"type": "array"},
}
subparser = mock.Mock()
- command = ActionRunCommand(action, self, subparser, name='test')
+ command = ActionRunCommand(action, self, subparser, name="test")
mockarg = mock.Mock()
mockarg.inherit_env = False
mockarg.parameters = [
- 'param_array=foo:bar,foo2:bar2',
+ "param_array=foo:bar,foo2:bar2",
]
mockarg.auto_dict = False
- param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg)
- self.assertEqual(param['param_array'], ['foo:bar', 'foo2:bar2'])
+ param = command._get_action_parameters_from_args(
+ action=action, runner=runner, args=mockarg
+ )
+ self.assertEqual(param["param_array"], ["foo:bar", "foo2:bar2"])
mockarg.auto_dict = True
- param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg)
- self.assertEqual(param['param_array'], [{'foo': 'bar', 'foo2': 'bar2'}])
+ param = command._get_action_parameters_from_args(
+ action=action, runner=runner, args=mockarg
+ )
+ self.assertEqual(param["param_array"], [{"foo": "bar", "foo2": "bar2"}])
# set auto_dict back to default
mockarg.auto_dict = False
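
`test_opt_in_dict_auto_convert` documents the opt-in behavior behind `auto_dict`: `param_array=foo:bar,foo2:bar2` stays a plain list of strings unless the flag is set, in which case the `key:value` pairs collapse into a single dict inside the list. A sketch of parsing consistent with those two assertions (the real parser handles many more cases):

def parse_array_param(value, auto_dict=False):
    # Without auto_dict, "foo:bar,foo2:bar2" stays a list of strings; with
    # it, the key:value pairs collapse into a single dict inside the list.
    items = value.split(",")
    if auto_dict and all(":" in item for item in items):
        return [dict(item.split(":", 1) for item in items)]
    return items


assert parse_array_param("foo:bar,foo2:bar2") == ["foo:bar", "foo2:bar2"]
assert parse_array_param("foo:bar,foo2:bar2", auto_dict=True) == [
    {"foo": "bar", "foo2": "bar2"}
]
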
@@ -104,60 +110,65 @@ def test_get_params_from_args(self):
runner.runner_parameters = {}
action = Action()
- action.ref = 'test.action'
+ action.ref = "test.action"
action.parameters = {
- 'param_string': {'type': 'string'},
- 'param_integer': {'type': 'integer'},
- 'param_number': {'type': 'number'},
- 'param_object': {'type': 'object'},
- 'param_boolean': {'type': 'boolean'},
- 'param_array': {'type': 'array'},
- 'param_array_of_dicts': {'type': 'array', 'properties': {
- 'foo': {'type': 'string'},
- 'bar': {'type': 'integer'},
- 'baz': {'type': 'number'},
- 'qux': {'type': 'object'},
- 'quux': {'type': 'boolean'}}
+ "param_string": {"type": "string"},
+ "param_integer": {"type": "integer"},
+ "param_number": {"type": "number"},
+ "param_object": {"type": "object"},
+ "param_boolean": {"type": "boolean"},
+ "param_array": {"type": "array"},
+ "param_array_of_dicts": {
+ "type": "array",
+ "properties": {
+ "foo": {"type": "string"},
+ "bar": {"type": "integer"},
+ "baz": {"type": "number"},
+ "qux": {"type": "object"},
+ "quux": {"type": "boolean"},
+ },
},
}
subparser = mock.Mock()
- command = ActionRunCommand(action, self, subparser, name='test')
+ command = ActionRunCommand(action, self, subparser, name="test")
mockarg = mock.Mock()
mockarg.inherit_env = False
mockarg.auto_dict = True
mockarg.parameters = [
- 'param_string=hoge',
- 'param_integer=123',
- 'param_number=1.23',
- 'param_object=hoge=1,fuga=2',
- 'param_boolean=False',
- 'param_array=foo,bar,baz',
- 'param_array_of_dicts=foo:HOGE,bar:1,baz:1.23,qux:foo=bar,quux:True',
- 'param_array_of_dicts=foo:FUGA,bar:2,baz:2.34,qux:bar=baz,quux:False'
+ "param_string=hoge",
+ "param_integer=123",
+ "param_number=1.23",
+ "param_object=hoge=1,fuga=2",
+ "param_boolean=False",
+ "param_array=foo,bar,baz",
+ "param_array_of_dicts=foo:HOGE,bar:1,baz:1.23,qux:foo=bar,quux:True",
+ "param_array_of_dicts=foo:FUGA,bar:2,baz:2.34,qux:bar=baz,quux:False",
]
- param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg)
+ param = command._get_action_parameters_from_args(
+ action=action, runner=runner, args=mockarg
+ )
self.assertIsInstance(param, dict)
- self.assertEqual(param['param_string'], 'hoge')
- self.assertEqual(param['param_integer'], 123)
- self.assertEqual(param['param_number'], 1.23)
- self.assertEqual(param['param_object'], {'hoge': '1', 'fuga': '2'})
- self.assertFalse(param['param_boolean'])
- self.assertEqual(param['param_array'], ['foo', 'bar', 'baz'])
+ self.assertEqual(param["param_string"], "hoge")
+ self.assertEqual(param["param_integer"], 123)
+ self.assertEqual(param["param_number"], 1.23)
+ self.assertEqual(param["param_object"], {"hoge": "1", "fuga": "2"})
+ self.assertFalse(param["param_boolean"])
+ self.assertEqual(param["param_array"], ["foo", "bar", "baz"])
# checking the result of parsing for array of objects
- self.assertIsInstance(param['param_array_of_dicts'], list)
- self.assertEqual(len(param['param_array_of_dicts']), 2)
- for param in param['param_array_of_dicts']:
+ self.assertIsInstance(param["param_array_of_dicts"], list)
+ self.assertEqual(len(param["param_array_of_dicts"]), 2)
+ for param in param["param_array_of_dicts"]:
self.assertIsInstance(param, dict)
- self.assertIsInstance(param['foo'], str)
- self.assertIsInstance(param['bar'], int)
- self.assertIsInstance(param['baz'], float)
- self.assertIsInstance(param['qux'], dict)
- self.assertIsInstance(param['quux'], bool)
+ self.assertIsInstance(param["foo"], str)
+ self.assertIsInstance(param["bar"], int)
+ self.assertIsInstance(param["baz"], float)
+ self.assertIsInstance(param["qux"], dict)
+ self.assertIsInstance(param["quux"], bool)
# set auto_dict back to default
mockarg.auto_dict = False
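
The assertions in this test pin down one cast per JSON-schema type. A compact sketch of that mapping, assuming values arrive as raw strings (illustrative only; the shipped code also handles files, defaults, and nested schemas):

def cast_parameter(value, type_name):
    if type_name == "integer":
        return int(value)
    if type_name == "number":
        return float(value)
    if type_name == "boolean":
        return value.lower() == "true"
    if type_name == "object":
        # "hoge=1,fuga=2" becomes {"hoge": "1", "fuga": "2"}
        return dict(pair.split("=", 1) for pair in value.split(","))
    if type_name == "array":
        return value.split(",")
    return value  # strings pass through unchanged

assert cast_parameter("123", "integer") == 123
assert cast_parameter("1.23", "number") == 1.23
assert cast_parameter("hoge=1,fuga=2", "object") == {"hoge": "1", "fuga": "2"}
assert cast_parameter("False", "boolean") is False
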
@@ -167,36 +178,38 @@ def test_get_params_from_args_read_content_from_file(self):
runner.runner_parameters = {}
action = Action()
- action.ref = 'test.action'
+ action.ref = "test.action"
action.parameters = {
- 'param_object': {'type': 'object'},
+ "param_object": {"type": "object"},
}
subparser = mock.Mock()
- command = ActionRunCommand(action, self, subparser, name='test')
+ command = ActionRunCommand(action, self, subparser, name="test")
# 1. File doesn't exist
mockarg = mock.Mock()
mockarg.inherit_env = False
mockarg.auto_dict = True
- mockarg.parameters = [
- '@param_object=doesnt-exist.json'
- ]
+ mockarg.parameters = ["@param_object=doesnt-exist.json"]
- self.assertRaisesRegex(ValueError, "doesn't exist",
- command._get_action_parameters_from_args, action=action,
- runner=runner, args=mockarg)
+ self.assertRaisesRegex(
+ ValueError,
+ "doesn't exist",
+ command._get_action_parameters_from_args,
+ action=action,
+ runner=runner,
+ args=mockarg,
+ )
# 2. Valid file path (we simply read this file)
mockarg = mock.Mock()
mockarg.inherit_env = False
mockarg.auto_dict = True
- mockarg.parameters = [
- '@param_string=%s' % (__file__)
- ]
+ mockarg.parameters = ["@param_string=%s" % (__file__)]
- params = command._get_action_parameters_from_args(action=action,
- runner=runner, args=mockarg)
+ params = command._get_action_parameters_from_args(
+ action=action, runner=runner, args=mockarg
+ )
self.assertTrue(isinstance(params["param_string"], six.text_type))
self.assertTrue(params["param_string"].startswith("# Copyright"))
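
The @-prefix convention exercised here: prefixing the parameter name tells the client to read the value from the named file, and a missing path surfaces as a ValueError whose message matches the regex above. A hedged restatement, not the actual implementation:

import os

def load_file_parameter(token):
    name, _, path = token.lstrip("@").partition("=")
    if not os.path.isfile(path):
        raise ValueError("File \"%s\" doesn't exist" % path)
    with open(path, "r") as fp:
        return name, fp.read()

name, value = load_file_parameter("@param_string=%s" % __file__)
assert name == "param_string" and isinstance(value, str)
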
@@ -212,37 +225,39 @@ def test_get_params_from_args_with_multiple_declarations(self):
runner.runner_parameters = {}
action = Action()
- action.ref = 'test.action'
+ action.ref = "test.action"
action.parameters = {
- 'param_string': {'type': 'string'},
- 'param_array': {'type': 'array'},
- 'param_array_of_dicts': {'type': 'array'},
+ "param_string": {"type": "string"},
+ "param_array": {"type": "array"},
+ "param_array_of_dicts": {"type": "array"},
}
subparser = mock.Mock()
- command = ActionRunCommand(action, self, subparser, name='test')
+ command = ActionRunCommand(action, self, subparser, name="test")
mockarg = mock.Mock()
mockarg.inherit_env = False
mockarg.auto_dict = True
mockarg.parameters = [
- 'param_string=hoge', # This value will be overwritten with the next declaration.
- 'param_string=fuga',
- 'param_array=foo',
- 'param_array=bar',
- 'param_array_of_dicts=foo:1,bar:2',
- 'param_array_of_dicts=hoge:A,fuga:B'
+ "param_string=hoge", # This value will be overwritten with the next declaration.
+ "param_string=fuga",
+ "param_array=foo",
+ "param_array=bar",
+ "param_array_of_dicts=foo:1,bar:2",
+ "param_array_of_dicts=hoge:A,fuga:B",
]
- param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg)
+ param = command._get_action_parameters_from_args(
+ action=action, runner=runner, args=mockarg
+ )
# checks to accept multiple declaration only if the array type
- self.assertEqual(param['param_string'], 'fuga')
- self.assertEqual(param['param_array'], ['foo', 'bar'])
- self.assertEqual(param['param_array_of_dicts'], [
- {'foo': '1', 'bar': '2'},
- {'hoge': 'A', 'fuga': 'B'}
- ])
+ self.assertEqual(param["param_string"], "fuga")
+ self.assertEqual(param["param_array"], ["foo", "bar"])
+ self.assertEqual(
+ param["param_array_of_dicts"],
+ [{"foo": "1", "bar": "2"}, {"hoge": "A", "fuga": "B"}],
+ )
# set auto_dict back to default
mockarg.auto_dict = False
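
The comment in the hunk above states the rule; restated as code for clarity (a sketch, assuming dispatch on the declared type): repeating a scalar parameter keeps the last value, while repeating an array parameter appends.

def merge_declaration(params, name, value, is_array):
    if is_array:
        params.setdefault(name, []).append(value)
    else:
        params[name] = value
    return params

params = {}
merge_declaration(params, "param_string", "hoge", is_array=False)
merge_declaration(params, "param_string", "fuga", is_array=False)
merge_declaration(params, "param_array", "foo", is_array=True)
merge_declaration(params, "param_array", "bar", is_array=True)
assert params == {"param_string": "fuga", "param_array": ["foo", "bar"]}
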
diff --git a/st2client/tests/unit/test_commands.py b/st2client/tests/unit/test_commands.py
index 0748a4aeec..de84f7883f 100644
--- a/st2client/tests/unit/test_commands.py
+++ b/st2client/tests/unit/test_commands.py
@@ -32,97 +32,117 @@
from st2client.commands import resource
from st2client.commands.resource import ResourceViewCommand
-__all__ = [
- 'TestResourceCommand',
- 'ResourceViewCommandTestCase'
-]
+__all__ = ["TestResourceCommand", "ResourceViewCommandTestCase"]
LOG = logging.getLogger(__name__)
class TestResourceCommand(unittest2.TestCase):
-
def __init__(self, *args, **kwargs):
super(TestResourceCommand, self).__init__(*args, **kwargs)
self.parser = argparse.ArgumentParser()
self.subparsers = self.parser.add_subparsers()
self.branch = resource.ResourceBranch(
- base.FakeResource, 'Test Command', base.FakeApp(), self.subparsers)
+ base.FakeResource, "Test Command", base.FakeApp(), self.subparsers
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK")
+ ),
+ )
def test_command_list(self):
- args = self.parser.parse_args(['fakeresource', 'list'])
- self.assertEqual(args.func, self.branch.commands['list'].run_and_print)
- instances = self.branch.commands['list'].run(args)
+ args = self.parser.parse_args(["fakeresource", "list"])
+ self.assertEqual(args.func, self.branch.commands["list"].run_and_print)
+ instances = self.branch.commands["list"].run(args)
actual = [instance.serialize() for instance in instances]
expected = json.loads(json.dumps(base.RESOURCES))
self.assertListEqual(actual, expected)
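
Every test in this class patches httpclient.HTTPClient with a canned base.FakeResponse. That fixture is defined elsewhere in the test package; judging only by how it is constructed here (body, status code, reason, optional headers), a stand-in might look like the following — an assumption about its shape, not the fixture's actual code:

import json

class FakeResponse(object):
    def __init__(self, text, status_code, reason, headers=None):
        self.text = text
        self.status_code = status_code
        self.reason = reason
        self.headers = headers or {}

    def json(self):
        return json.loads(self.text)

resp = FakeResponse(json.dumps({"id": "123"}), 200, "OK")
assert resp.json() == {"id": "123"}
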
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_command_list_failed(self):
- args = self.parser.parse_args(['fakeresource', 'list'])
- self.assertRaises(Exception, self.branch.commands['list'].run, args)
+ args = self.parser.parse_args(["fakeresource", "list"])
+ self.assertRaises(Exception, self.branch.commands["list"].run, args)
@mock.patch.object(
- models.ResourceManager, 'get_by_name',
- mock.MagicMock(return_value=None))
+ models.ResourceManager, "get_by_name", mock.MagicMock(return_value=None)
+ )
@mock.patch.object(
- models.ResourceManager, 'get_by_id',
- mock.MagicMock(return_value=base.FakeResource(**base.RESOURCES[0])))
+ models.ResourceManager,
+ "get_by_id",
+ mock.MagicMock(return_value=base.FakeResource(**base.RESOURCES[0])),
+ )
def test_command_get_by_id(self):
- args = self.parser.parse_args(['fakeresource', 'get', '123'])
- self.assertEqual(args.func, self.branch.commands['get'].run_and_print)
- instance = self.branch.commands['get'].run(args)
+ args = self.parser.parse_args(["fakeresource", "get", "123"])
+ self.assertEqual(args.func, self.branch.commands["get"].run_and_print)
+ instance = self.branch.commands["get"].run(args)
actual = instance.serialize()
expected = json.loads(json.dumps(base.RESOURCES[0]))
self.assertEqual(actual, expected)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK")
+ ),
+ )
def test_command_get(self):
- args = self.parser.parse_args(['fakeresource', 'get', 'abc'])
- self.assertEqual(args.func, self.branch.commands['get'].run_and_print)
- instance = self.branch.commands['get'].run(args)
+ args = self.parser.parse_args(["fakeresource", "get", "abc"])
+ self.assertEqual(args.func, self.branch.commands["get"].run_and_print)
+ instance = self.branch.commands["get"].run(args)
actual = instance.serialize()
expected = json.loads(json.dumps(base.RESOURCES[0]))
self.assertEqual(actual, expected)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")),
+ )
def test_command_get_404(self):
- args = self.parser.parse_args(['fakeresource', 'get', 'cba'])
- self.assertEqual(args.func, self.branch.commands['get'].run_and_print)
- self.assertRaises(resource.ResourceNotFoundError,
- self.branch.commands['get'].run,
- args)
+ args = self.parser.parse_args(["fakeresource", "get", "cba"])
+ self.assertEqual(args.func, self.branch.commands["get"].run_and_print)
+ self.assertRaises(
+ resource.ResourceNotFoundError, self.branch.commands["get"].run, args
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_command_get_failed(self):
- args = self.parser.parse_args(['fakeresource', 'get', 'cba'])
- self.assertRaises(Exception, self.branch.commands['get'].run, args)
+ args = self.parser.parse_args(["fakeresource", "get", "cba"])
+ self.assertRaises(Exception, self.branch.commands["get"].run, args)
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK")
+ ),
+ )
def test_command_create(self):
- instance = base.FakeResource(name='abc')
- fd, path = tempfile.mkstemp(suffix='.json')
+ instance = base.FakeResource(name="abc")
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(instance.serialize(), indent=4))
- args = self.parser.parse_args(['fakeresource', 'create', path])
- self.assertEqual(args.func,
- self.branch.commands['create'].run_and_print)
- instance = self.branch.commands['create'].run(args)
+ args = self.parser.parse_args(["fakeresource", "create", path])
+ self.assertEqual(args.func, self.branch.commands["create"].run_and_print)
+ instance = self.branch.commands["create"].run(args)
actual = instance.serialize()
expected = json.loads(json.dumps(base.RESOURCES[0]))
self.assertEqual(actual, expected)
@@ -131,40 +151,49 @@ def test_command_create(self):
os.unlink(path)
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_command_create_failed(self):
- instance = base.FakeResource(name='abc')
- fd, path = tempfile.mkstemp(suffix='.json')
+ instance = base.FakeResource(name="abc")
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(instance.serialize(), indent=4))
- args = self.parser.parse_args(['fakeresource', 'create', path])
- self.assertRaises(Exception,
- self.branch.commands['create'].run,
- args)
+ args = self.parser.parse_args(["fakeresource", "create", path])
+ self.assertRaises(Exception, self.branch.commands["create"].run, args)
finally:
os.close(fd)
os.unlink(path)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK',
- {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps([base.RESOURCES[0]]), 200, "OK", {}
+ )
+ ),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK")
+ ),
+ )
def test_command_update(self):
- instance = base.FakeResource(id='123', name='abc')
- fd, path = tempfile.mkstemp(suffix='.json')
+ instance = base.FakeResource(id="123", name="abc")
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(instance.serialize(), indent=4))
- args = self.parser.parse_args(
- ['fakeresource', 'update', '123', path])
- self.assertEqual(args.func,
- self.branch.commands['update'].run_and_print)
- instance = self.branch.commands['update'].run(args)
+ args = self.parser.parse_args(["fakeresource", "update", "123", path])
+ self.assertEqual(args.func, self.branch.commands["update"].run_and_print)
+ instance = self.branch.commands["update"].run(args)
actual = instance.serialize()
expected = json.loads(json.dumps(base.RESOURCES[0]))
self.assertEqual(actual, expected)
@@ -173,122 +202,142 @@ def test_command_update(self):
os.unlink(path)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK")
+ ),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_command_update_failed(self):
- instance = base.FakeResource(id='123', name='abc')
- fd, path = tempfile.mkstemp(suffix='.json')
+ instance = base.FakeResource(id="123", name="abc")
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(instance.serialize(), indent=4))
- args = self.parser.parse_args(
- ['fakeresource', 'update', '123', path])
- self.assertRaises(Exception,
- self.branch.commands['update'].run,
- args)
+ args = self.parser.parse_args(["fakeresource", "update", "123", path])
+ self.assertRaises(Exception, self.branch.commands["update"].run, args)
finally:
os.close(fd)
os.unlink(path)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK")
+ ),
+ )
def test_command_update_id_mismatch(self):
- instance = base.FakeResource(id='789', name='abc')
- fd, path = tempfile.mkstemp(suffix='.json')
+ instance = base.FakeResource(id="789", name="abc")
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(instance.serialize(), indent=4))
- args = self.parser.parse_args(
- ['fakeresource', 'update', '123', path])
- self.assertRaises(Exception,
- self.branch.commands['update'].run,
- args)
+ args = self.parser.parse_args(["fakeresource", "update", "123", path])
+ self.assertRaises(Exception, self.branch.commands["update"].run, args)
finally:
os.close(fd)
os.unlink(path)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK',
- {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps([base.RESOURCES[0]]), 200, "OK", {}
+ )
+ ),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'delete',
- mock.MagicMock(return_value=base.FakeResponse('', 204, 'NO CONTENT')))
+ httpclient.HTTPClient,
+ "delete",
+ mock.MagicMock(return_value=base.FakeResponse("", 204, "NO CONTENT")),
+ )
def test_command_delete(self):
- args = self.parser.parse_args(['fakeresource', 'delete', 'abc'])
- self.assertEqual(args.func,
- self.branch.commands['delete'].run_and_print)
- self.branch.commands['delete'].run(args)
+ args = self.parser.parse_args(["fakeresource", "delete", "abc"])
+ self.assertEqual(args.func, self.branch.commands["delete"].run_and_print)
+ self.branch.commands["delete"].run(args)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")),
+ )
def test_command_delete_404(self):
- args = self.parser.parse_args(['fakeresource', 'delete', 'cba'])
- self.assertEqual(args.func,
- self.branch.commands['delete'].run_and_print)
- self.assertRaises(resource.ResourceNotFoundError,
- self.branch.commands['delete'].run,
- args)
+ args = self.parser.parse_args(["fakeresource", "delete", "cba"])
+ self.assertEqual(args.func, self.branch.commands["delete"].run_and_print)
+ self.assertRaises(
+ resource.ResourceNotFoundError, self.branch.commands["delete"].run, args
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK")
+ ),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'delete',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "delete",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_command_delete_failed(self):
- args = self.parser.parse_args(['fakeresource', 'delete', 'cba'])
- self.assertRaises(Exception, self.branch.commands['delete'].run, args)
+ args = self.parser.parse_args(["fakeresource", "delete", "cba"])
+ self.assertRaises(Exception, self.branch.commands["delete"].run, args)
class ResourceViewCommandTestCase(unittest2.TestCase):
-
def setUp(self):
ResourceViewCommand.display_attributes = []
def test_get_include_attributes(self):
- cls = namedtuple('Args', 'attr')
+ cls = namedtuple("Args", "attr")
args = cls(attr=[])
result = ResourceViewCommand._get_include_attributes(args=args)
self.assertEqual(result, [])
- args = cls(attr=['result'])
+ args = cls(attr=["result"])
result = ResourceViewCommand._get_include_attributes(args=args)
- self.assertEqual(result, ['result'])
+ self.assertEqual(result, ["result"])
- args = cls(attr=['result', 'trigger_instance'])
+ args = cls(attr=["result", "trigger_instance"])
result = ResourceViewCommand._get_include_attributes(args=args)
- self.assertEqual(result, ['result', 'trigger_instance'])
+ self.assertEqual(result, ["result", "trigger_instance"])
- args = cls(attr=['result.stdout'])
+ args = cls(attr=["result.stdout"])
result = ResourceViewCommand._get_include_attributes(args=args)
- self.assertEqual(result, ['result.stdout'])
+ self.assertEqual(result, ["result.stdout"])
- args = cls(attr=['result.stdout', 'result.stderr'])
+ args = cls(attr=["result.stdout", "result.stderr"])
result = ResourceViewCommand._get_include_attributes(args=args)
- self.assertEqual(result, ['result.stdout', 'result.stderr'])
+ self.assertEqual(result, ["result.stdout", "result.stderr"])
- args = cls(attr=['result.stdout', 'trigger_instance.id'])
+ args = cls(attr=["result.stdout", "trigger_instance.id"])
result = ResourceViewCommand._get_include_attributes(args=args)
- self.assertEqual(result, ['result.stdout', 'trigger_instance.id'])
+ self.assertEqual(result, ["result.stdout", "trigger_instance.id"])
- ResourceViewCommand.display_attributes = ['id', 'status']
+ ResourceViewCommand.display_attributes = ["id", "status"]
args = cls(attr=[])
result = ResourceViewCommand._get_include_attributes(args=args)
- self.assertEqual(set(result), set(['id', 'status']))
+ self.assertEqual(set(result), set(["id", "status"]))
- args = cls(attr=['trigger_instance'])
+ args = cls(attr=["trigger_instance"])
result = ResourceViewCommand._get_include_attributes(args=args)
- self.assertEqual(set(result), set(['trigger_instance']))
+ self.assertEqual(set(result), set(["trigger_instance"]))
- args = cls(attr=['all'])
+ args = cls(attr=["all"])
result = ResourceViewCommand._get_include_attributes(args=args)
self.assertEqual(result, None)
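
Taken together, these assertions specify _get_include_attributes completely enough to restate it (a sketch; the real method may differ in detail):

def get_include_attributes(attr, display_attributes):
    if attr == ["all"]:
        return None  # "all" disables attribute filtering entirely
    if attr:
        return attr  # explicit attributes win, dotted paths included
    return list(display_attributes)  # fall back to the class-level default

assert get_include_attributes([], []) == []
assert get_include_attributes(["result.stdout"], []) == ["result.stdout"]
assert get_include_attributes([], ["id", "status"]) == ["id", "status"]
assert get_include_attributes(["all"], ["id", "status"]) is None
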
@@ -303,20 +352,19 @@ class CommandsHelpStringTestCase(BaseCLITestCase):
# TODO: Automatically iterate all the available commands
COMMANDS = [
# action
- ['action', 'list'],
- ['action', 'get'],
- ['action', 'create'],
- ['action', 'update'],
- ['action', 'delete'],
- ['action', 'enable'],
- ['action', 'disable'],
- ['action', 'execute'],
-
+ ["action", "list"],
+ ["action", "get"],
+ ["action", "create"],
+ ["action", "update"],
+ ["action", "delete"],
+ ["action", "enable"],
+ ["action", "disable"],
+ ["action", "execute"],
# execution
- ['execution', 'cancel'],
- ['execution', 'pause'],
- ['execution', 'resume'],
- ['execution', 'tail']
+ ["execution", "cancel"],
+ ["execution", "pause"],
+ ["execution", "resume"],
+ ["execution", "tail"],
]
def test_help_command_line_arg_works_for_supported_commands(self):
@@ -324,7 +372,7 @@ def test_help_command_line_arg_works_for_supported_commands(self):
for command in self.COMMANDS:
            # First test longhand notation
- argv = command + ['--help']
+ argv = command + ["--help"]
try:
result = shell.run(argv)
@@ -335,16 +383,16 @@ def test_help_command_line_arg_works_for_supported_commands(self):
stdout = self.stdout.getvalue()
- self.assertIn('usage:', stdout)
- self.assertIn(' '.join(command), stdout)
+ self.assertIn("usage:", stdout)
+ self.assertIn(" ".join(command), stdout)
# self.assertIn('positional arguments:', stdout)
- self.assertIn('optional arguments:', stdout)
+ self.assertIn("optional arguments:", stdout)
# Reset stdout and stderr after each iteration
self._reset_output_streams()
# Then shorthand notation
- argv = command + ['-h']
+ argv = command + ["-h"]
try:
result = shell.run(argv)
@@ -355,14 +403,14 @@ def test_help_command_line_arg_works_for_supported_commands(self):
stdout = self.stdout.getvalue()
- self.assertIn('usage:', stdout)
- self.assertIn(' '.join(command), stdout)
+ self.assertIn("usage:", stdout)
+ self.assertIn(" ".join(command), stdout)
# self.assertIn('positional arguments:', stdout)
- self.assertIn('optional arguments:', stdout)
+ self.assertIn("optional arguments:", stdout)
# Verify that the actual help usage string was triggered and not the invalid
# "too few arguments" which would indicate command doesn't actually correctly handle
# --help flag
- self.assertNotIn('too few arguments', stdout)
+ self.assertNotIn("too few arguments", stdout)
self._reset_output_streams()
diff --git a/st2client/tests/unit/test_config_parser.py b/st2client/tests/unit/test_config_parser.py
index 35a125ebeb..9cea63ee5a 100644
--- a/st2client/tests/unit/test_config_parser.py
+++ b/st2client/tests/unit/test_config_parser.py
@@ -26,80 +26,77 @@
from st2client.config_parser import CONFIG_DEFAULT_VALUES
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, '../fixtures/st2rc.full.ini')
-CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, '../fixtures/st2rc.partial.ini')
-CONFIG_FILE_PATH_UNICODE = os.path.join(BASE_DIR, '../fixtures/test_unicode.ini')
+CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, "../fixtures/st2rc.full.ini")
+CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, "../fixtures/st2rc.partial.ini")
+CONFIG_FILE_PATH_UNICODE = os.path.join(BASE_DIR, "../fixtures/test_unicode.ini")
class CLIConfigParserTestCase(unittest2.TestCase):
def test_constructor(self):
- parser = CLIConfigParser(config_file_path='doesnotexist', validate_config_exists=False)
+ parser = CLIConfigParser(
+ config_file_path="doesnotexist", validate_config_exists=False
+ )
self.assertTrue(parser)
- self.assertRaises(ValueError, CLIConfigParser, config_file_path='doestnotexist',
- validate_config_exists=True)
+ self.assertRaises(
+ ValueError,
+ CLIConfigParser,
+ config_file_path="doestnotexist",
+ validate_config_exists=True,
+ )
def test_parse(self):
# File doesn't exist
- parser = CLIConfigParser(config_file_path='doesnotexist', validate_config_exists=False)
+ parser = CLIConfigParser(
+ config_file_path="doesnotexist", validate_config_exists=False
+ )
result = parser.parse()
self.assertEqual(CONFIG_DEFAULT_VALUES, result)
# File exists - all the options specified
expected = {
- 'general': {
- 'base_url': 'http://127.0.0.1',
- 'api_version': 'v1',
- 'cacert': 'cacartpath',
- 'silence_ssl_warnings': False,
- 'silence_schema_output': True
+ "general": {
+ "base_url": "http://127.0.0.1",
+ "api_version": "v1",
+ "cacert": "cacartpath",
+ "silence_ssl_warnings": False,
+ "silence_schema_output": True,
},
- 'cli': {
- 'debug': True,
- 'cache_token': False,
- 'timezone': 'UTC'
- },
- 'credentials': {
- 'username': 'test1',
- 'password': 'test1',
- 'api_key': None
- },
- 'api': {
- 'url': 'http://127.0.0.1:9101/v1'
- },
- 'auth': {
- 'url': 'http://127.0.0.1:9100/'
- },
- 'stream': {
- 'url': 'http://127.0.0.1:9102/v1/stream'
- }
+ "cli": {"debug": True, "cache_token": False, "timezone": "UTC"},
+ "credentials": {"username": "test1", "password": "test1", "api_key": None},
+ "api": {"url": "http://127.0.0.1:9101/v1"},
+ "auth": {"url": "http://127.0.0.1:9100/"},
+ "stream": {"url": "http://127.0.0.1:9102/v1/stream"},
}
- parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_FULL,
- validate_config_exists=False)
+ parser = CLIConfigParser(
+ config_file_path=CONFIG_FILE_PATH_FULL, validate_config_exists=False
+ )
result = parser.parse()
self.assertEqual(expected, result)
# File exists - missing options, test defaults
- parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_PARTIAL,
- validate_config_exists=False)
+ parser = CLIConfigParser(
+ config_file_path=CONFIG_FILE_PATH_PARTIAL, validate_config_exists=False
+ )
result = parser.parse()
- self.assertTrue(result['cli']['cache_token'], True)
+        self.assertEqual(result["cli"]["cache_token"], True)
def test_get_config_for_unicode_char(self):
- parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_UNICODE,
- validate_config_exists=False)
+ parser = CLIConfigParser(
+ config_file_path=CONFIG_FILE_PATH_UNICODE, validate_config_exists=False
+ )
config = parser.parse()
if six.PY3:
- self.assertEqual(config['credentials']['password'], '密码')
+ self.assertEqual(config["credentials"]["password"], "密码")
else:
- self.assertEqual(config['credentials']['password'], u'\u5bc6\u7801')
+ self.assertEqual(config["credentials"]["password"], "\u5bc6\u7801")
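
The expected dict in test_parse doubles as documentation of the st2 CLI config layout. A standalone round-trip sketch using only the CLIConfigParser API exercised in this file (section and option names mirror that dict; the temp-file handling is incidental):

import os
import tempfile

from st2client.config_parser import CLIConfigParser

INI = b"""
[general]
base_url = http://127.0.0.1
api_version = v1

[cli]
debug = True
cache_token = False
timezone = UTC
"""

with tempfile.NamedTemporaryFile(suffix=".ini", delete=False) as fp:
    fp.write(INI)

parser = CLIConfigParser(config_file_path=fp.name, validate_config_exists=True)
config = parser.parse()
assert config["general"]["base_url"] == "http://127.0.0.1"
assert config["cli"]["timezone"] == "UTC"
os.unlink(fp.name)
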
class CLIConfigPermissionsTestCase(unittest2.TestCase):
def setUp(self):
- self.TEMP_FILE_PATH = os.path.join('st2config', '.st2', 'config')
+ self.TEMP_FILE_PATH = os.path.join("st2config", ".st2", "config")
self.TEMP_CONFIG_DIR = os.path.dirname(self.TEMP_FILE_PATH)
if os.path.exists(self.TEMP_FILE_PATH):
@@ -135,7 +132,9 @@ def test_correct_permissions_emit_no_warnings(self):
self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o660)
- parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True)
+ parser = CLIConfigParser(
+ config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True
+ )
parser.LOG = mock.Mock()
result = parser.parse() # noqa F841
@@ -159,7 +158,9 @@ def test_weird_but_correct_permissions_emit_no_warnings(self):
self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o640)
- parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True)
+ parser = CLIConfigParser(
+ config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True
+ )
parser.LOG = mock.Mock()
result = parser.parse() # noqa F841
@@ -175,7 +176,9 @@ def test_weird_but_correct_permissions_emit_no_warnings(self):
self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o600)
- parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True)
+ parser = CLIConfigParser(
+ config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True
+ )
parser.LOG = mock.Mock()
result = parser.parse() # noqa F841
@@ -200,7 +203,9 @@ def test_warn_on_bad_config_permissions(self):
self.assertNotEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o770)
- parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True)
+ parser = CLIConfigParser(
+ config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True
+ )
parser.LOG = mock.Mock()
result = parser.parse() # noqa F841
@@ -209,17 +214,20 @@ def test_warn_on_bad_config_permissions(self):
self.assertEqual(
"The SGID bit is not set on the StackStorm configuration directory.",
- parser.LOG.info.call_args_list[0][0][0])
+ parser.LOG.info.call_args_list[0][0][0],
+ )
self.assertEqual(parser.LOG.warn.call_count, 2)
self.assertEqual(
"The StackStorm configuration directory permissions are insecure "
"(too permissive): others have access.",
- parser.LOG.warn.call_args_list[0][0][0])
+ parser.LOG.warn.call_args_list[0][0][0],
+ )
self.assertEqual(
"The StackStorm configuration file permissions are insecure: others have access.",
- parser.LOG.warn.call_args_list[1][0][0])
+ parser.LOG.warn.call_args_list[1][0][0],
+ )
# Make sure we left the file alone
self.assertTrue(os.path.exists(self.TEMP_FILE_PATH))
@@ -239,9 +247,11 @@ def test_disable_permissions_warnings(self):
self.assertNotEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o770)
- parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH,
- validate_config_exists=True,
- validate_config_permissions=False)
+ parser = CLIConfigParser(
+ config_file_path=self.TEMP_FILE_PATH,
+ validate_config_exists=True,
+ validate_config_permissions=False,
+ )
parser.LOG = mock.Mock()
result = parser.parse() # noqa F841
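
What these permission tests compare is the low nine mode bits of the config file and directory. A POSIX-only sketch of the "others have access" predicate the warnings describe (an assumed formulation, not the parser's code):

import os
import stat
import tempfile

def others_have_access(path):
    mode = os.stat(path).st_mode & 0o777
    return bool(mode & stat.S_IRWXO)  # any r/w/x bit for "others"

fd, path = tempfile.mkstemp()
os.close(fd)
os.chmod(path, 0o660)  # owner+group only: no warning expected
assert not others_have_access(path)
os.chmod(path, 0o664)  # others can read: would trigger the warning
assert others_have_access(path)
os.unlink(path)
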
diff --git a/st2client/tests/unit/test_execution_tail_command.py b/st2client/tests/unit/test_execution_tail_command.py
index 15500767f2..08957ddbf1 100644
--- a/st2client/tests/unit/test_execution_tail_command.py
+++ b/st2client/tests/unit/test_execution_tail_command.py
@@ -27,247 +27,180 @@
from st2client.commands.action import LIVEACTION_STATUS_TIMED_OUT
from st2client.shell import Shell
-__all__ = [
- 'ActionExecutionTailCommandTestCase'
-]
+__all__ = ["ActionExecutionTailCommandTestCase"]
# Mock objects
-MOCK_LIVEACTION_1_RUNNING = {
- 'id': 'idfoo1',
- 'status': LIVEACTION_STATUS_RUNNING
-}
+MOCK_LIVEACTION_1_RUNNING = {"id": "idfoo1", "status": LIVEACTION_STATUS_RUNNING}
-MOCK_LIVEACTION_1_SUCCEEDED = {
- 'id': 'idfoo1',
- 'status': LIVEACTION_STATUS_SUCCEEDED
-}
+MOCK_LIVEACTION_1_SUCCEEDED = {"id": "idfoo1", "status": LIVEACTION_STATUS_SUCCEEDED}
-MOCK_LIVEACTION_2_FAILED = {
- 'id': 'idfoo2',
- 'status': LIVEACTION_STATUS_FAILED
-}
+MOCK_LIVEACTION_2_FAILED = {"id": "idfoo2", "status": LIVEACTION_STATUS_FAILED}
# Mock liveaction objects for ActionChain workflow
-MOCK_LIVEACTION_3_RUNNING = {
- 'id': 'idfoo3',
- 'status': LIVEACTION_STATUS_RUNNING
-}
+MOCK_LIVEACTION_3_RUNNING = {"id": "idfoo3", "status": LIVEACTION_STATUS_RUNNING}
MOCK_LIVEACTION_3_CHILD_1_RUNNING = {
- 'id': 'idchild1',
- 'context': {
- 'parent': {
- 'execution_id': 'idfoo3'
- },
- 'chain': {
- 'name': 'task_1'
- }
- },
- 'status': LIVEACTION_STATUS_RUNNING
+ "id": "idchild1",
+ "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_1"}},
+ "status": LIVEACTION_STATUS_RUNNING,
}
MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED = {
- 'id': 'idchild1',
- 'context': {
- 'parent': {
- 'execution_id': 'idfoo3'
- },
- 'chain': {
- 'name': 'task_1'
- }
- },
- 'status': LIVEACTION_STATUS_SUCCEEDED
+ "id": "idchild1",
+ "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_1"}},
+ "status": LIVEACTION_STATUS_SUCCEEDED,
}
MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1 = {
- 'execution_id': 'idchild1',
- 'timestamp': '1505732598',
- 'output_type': 'stdout',
- 'data': 'line ac 4\n'
+ "execution_id": "idchild1",
+ "timestamp": "1505732598",
+ "output_type": "stdout",
+ "data": "line ac 4\n",
}
MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2 = {
- 'execution_id': 'idchild1',
- 'timestamp': '1505732598',
- 'output_type': 'stderr',
- 'data': 'line ac 5\n'
+ "execution_id": "idchild1",
+ "timestamp": "1505732598",
+ "output_type": "stderr",
+ "data": "line ac 5\n",
}
MOCK_LIVEACTION_3_CHILD_2_RUNNING = {
- 'id': 'idchild2',
- 'context': {
- 'parent': {
- 'execution_id': 'idfoo3'
- },
- 'chain': {
- 'name': 'task_2'
- }
- },
- 'status': LIVEACTION_STATUS_RUNNING
+ "id": "idchild2",
+ "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_2"}},
+ "status": LIVEACTION_STATUS_RUNNING,
}
MOCK_LIVEACTION_3_CHILD_2_FAILED = {
- 'id': 'idchild2',
- 'context': {
- 'parent': {
- 'execution_id': 'idfoo3'
- },
- 'chain': {
- 'name': 'task_2'
- }
- },
- 'status': LIVEACTION_STATUS_FAILED
+ "id": "idchild2",
+ "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_2"}},
+ "status": LIVEACTION_STATUS_FAILED,
}
MOCK_LIVEACTION_3_CHILD_2_OUTPUT_1 = {
- 'execution_id': 'idchild2',
- 'timestamp': '1505732598',
- 'output_type': 'stdout',
- 'data': 'line ac 100\n'
+ "execution_id": "idchild2",
+ "timestamp": "1505732598",
+ "output_type": "stdout",
+ "data": "line ac 100\n",
}
-MOCK_LIVEACTION_3_SUCCEDED = {
- 'id': 'idfoo3',
- 'status': LIVEACTION_STATUS_SUCCEEDED
-}
+MOCK_LIVEACTION_3_SUCCEDED = {"id": "idfoo3", "status": LIVEACTION_STATUS_SUCCEEDED}
# Mock objects for Orquesta workflow execution
-MOCK_LIVEACTION_4_RUNNING = {
- 'id': 'idfoo4',
- 'status': LIVEACTION_STATUS_RUNNING
-}
+MOCK_LIVEACTION_4_RUNNING = {"id": "idfoo4", "status": LIVEACTION_STATUS_RUNNING}
MOCK_LIVEACTION_4_CHILD_1_RUNNING = {
- 'id': 'idorquestachild1',
- 'context': {
- 'orquesta': {
- 'task_name': 'task_1'
- },
- 'parent': {
- 'execution_id': 'idfoo4'
- }
+ "id": "idorquestachild1",
+ "context": {
+ "orquesta": {"task_name": "task_1"},
+ "parent": {"execution_id": "idfoo4"},
},
- 'status': LIVEACTION_STATUS_RUNNING
+ "status": LIVEACTION_STATUS_RUNNING,
}
MOCK_LIVEACTION_4_CHILD_1_1_RUNNING = {
- 'id': 'idorquestachild1_1',
- 'context': {
- 'orquesta': {
- 'task_name': 'task_1'
- },
- 'parent': {
- 'execution_id': 'idorquestachild1'
- }
+ "id": "idorquestachild1_1",
+ "context": {
+ "orquesta": {"task_name": "task_1"},
+ "parent": {"execution_id": "idorquestachild1"},
},
- 'status': LIVEACTION_STATUS_RUNNING
+ "status": LIVEACTION_STATUS_RUNNING,
}
MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED = {
- 'id': 'idorquestachild1',
- 'context': {
- 'orquesta': {
- 'task_name': 'task_1',
+ "id": "idorquestachild1",
+ "context": {
+ "orquesta": {
+ "task_name": "task_1",
},
- 'parent': {
- 'execution_id': 'idfoo4'
- }
+ "parent": {"execution_id": "idfoo4"},
},
- 'status': LIVEACTION_STATUS_SUCCEEDED
+ "status": LIVEACTION_STATUS_SUCCEEDED,
}
MOCK_LIVEACTION_4_CHILD_1_1_SUCCEEDED = {
- 'id': 'idorquestachild1_1',
- 'context': {
- 'orquesta': {
- 'task_name': 'task_1',
+ "id": "idorquestachild1_1",
+ "context": {
+ "orquesta": {
+ "task_name": "task_1",
},
- 'parent': {
- 'execution_id': 'idorquestachild1'
- }
+ "parent": {"execution_id": "idorquestachild1"},
},
- 'status': LIVEACTION_STATUS_SUCCEEDED
+ "status": LIVEACTION_STATUS_SUCCEEDED,
}
MOCK_LIVEACTION_4_CHILD_1_OUTPUT_1 = {
- 'execution_id': 'idorquestachild1',
- 'timestamp': '1505732598',
- 'output_type': 'stdout',
- 'data': 'line orquesta 4\n'
+ "execution_id": "idorquestachild1",
+ "timestamp": "1505732598",
+ "output_type": "stdout",
+ "data": "line orquesta 4\n",
}
MOCK_LIVEACTION_4_CHILD_1_OUTPUT_2 = {
- 'execution_id': 'idorquestachild1',
- 'timestamp': '1505732598',
- 'output_type': 'stderr',
- 'data': 'line orquesta 5\n'
+ "execution_id": "idorquestachild1",
+ "timestamp": "1505732598",
+ "output_type": "stderr",
+ "data": "line orquesta 5\n",
}
MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_1 = {
- 'execution_id': 'idorquestachild1_1',
- 'timestamp': '1505732598',
- 'output_type': 'stdout',
- 'data': 'line orquesta 4\n'
+ "execution_id": "idorquestachild1_1",
+ "timestamp": "1505732598",
+ "output_type": "stdout",
+ "data": "line orquesta 4\n",
}
MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_2 = {
- 'execution_id': 'idorquestachild1_1',
- 'timestamp': '1505732598',
- 'output_type': 'stderr',
- 'data': 'line orquesta 5\n'
+ "execution_id": "idorquestachild1_1",
+ "timestamp": "1505732598",
+ "output_type": "stderr",
+ "data": "line orquesta 5\n",
}
MOCK_LIVEACTION_4_CHILD_2_RUNNING = {
- 'id': 'idorquestachild2',
- 'context': {
- 'orquesta': {
- 'task_name': 'task_2',
+ "id": "idorquestachild2",
+ "context": {
+ "orquesta": {
+ "task_name": "task_2",
},
- 'parent': {
- 'execution_id': 'idfoo4'
- }
+ "parent": {"execution_id": "idfoo4"},
},
- 'status': LIVEACTION_STATUS_RUNNING
+ "status": LIVEACTION_STATUS_RUNNING,
}
MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT = {
- 'id': 'idorquestachild2',
- 'context': {
- 'orquesta': {
- 'task_name': 'task_2',
+ "id": "idorquestachild2",
+ "context": {
+ "orquesta": {
+ "task_name": "task_2",
},
- 'parent': {
- 'execution_id': 'idfoo4'
- }
+ "parent": {"execution_id": "idfoo4"},
},
- 'status': LIVEACTION_STATUS_TIMED_OUT
+ "status": LIVEACTION_STATUS_TIMED_OUT,
}
MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1 = {
- 'execution_id': 'idorquestachild2',
- 'timestamp': '1505732598',
- 'output_type': 'stdout',
- 'data': 'line orquesta 100\n'
+ "execution_id": "idorquestachild2",
+ "timestamp": "1505732598",
+ "output_type": "stdout",
+ "data": "line orquesta 100\n",
}
-MOCK_LIVEACTION_4_SUCCEDED = {
- 'id': 'idfoo4',
- 'status': LIVEACTION_STATUS_SUCCEEDED
-}
+MOCK_LIVEACTION_4_SUCCEDED = {"id": "idfoo4", "status": LIVEACTION_STATUS_SUCCEEDED}
# Mock objects for simple actions
MOCK_OUTPUT_1 = {
- 'execution_id': 'idfoo3',
- 'timestamp': '1505732598',
- 'output_type': 'stdout',
- 'data': 'line 1\n'
+ "execution_id": "idfoo3",
+ "timestamp": "1505732598",
+ "output_type": "stdout",
+ "data": "line 1\n",
}
MOCK_OUTPUT_2 = {
- 'execution_id': 'idfoo3',
- 'timestamp': '1505732598',
- 'output_type': 'stderr',
- 'data': 'line 2\n'
+ "execution_id": "idfoo3",
+ "timestamp": "1505732598",
+ "output_type": "stderr",
+ "data": "line 2\n",
}
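
All the tail tests that follow share one pattern: patch st2client.client.StreamManager, hand the mock a canned event list, and assert on the rendered stdout. The exact fixture wiring is elided in this diff; one plausible shape, stated purely as an assumption:

import mock

def make_stream_manager_mock(mock_stream_manager, events):
    # The patched class returns an instance whose listen() yields the
    # canned events, as if the st2 API streamed them live.
    instance = mock.Mock()
    instance.listen = mock.Mock(return_value=iter(events))
    mock_stream_manager.return_value = instance
    return instance
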
@@ -279,42 +212,55 @@ def __init__(self, *args, **kwargs):
self.shell = Shell()
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_1_SUCCEEDED),
- 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_LIVEACTION_1_SUCCEEDED), 200, "OK"
+ )
+ ),
+ )
def test_tail_simple_execution_already_finished_succeeded(self):
- argv = ['execution', 'tail', 'idfoo1']
+ argv = ["execution", "tail", "idfoo1"]
self.assertEqual(self.shell.run(argv), 0)
stdout = self.stdout.getvalue()
stderr = self.stderr.getvalue()
- self.assertIn('Execution idfoo1 has completed (status=succeeded)', stdout)
- self.assertEqual(stderr, '')
+ self.assertIn("Execution idfoo1 has completed (status=succeeded)", stdout)
+ self.assertEqual(stderr, "")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_2_FAILED),
- 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_LIVEACTION_2_FAILED), 200, "OK"
+ )
+ ),
+ )
def test_tail_simple_execution_already_finished_failed(self):
- argv = ['execution', 'tail', 'idfoo2']
+ argv = ["execution", "tail", "idfoo2"]
self.assertEqual(self.shell.run(argv), 0)
stdout = self.stdout.getvalue()
stderr = self.stderr.getvalue()
- self.assertIn('Execution idfoo2 has completed (status=failed)', stdout)
- self.assertEqual(stderr, '')
+ self.assertIn("Execution idfoo2 has completed (status=failed)", stdout)
+ self.assertEqual(stderr, "")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_1_RUNNING),
- 200, 'OK')))
- @mock.patch('st2client.client.StreamManager', autospec=True)
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_LIVEACTION_1_RUNNING), 200, "OK"
+ )
+ ),
+ )
+ @mock.patch("st2client.client.StreamManager", autospec=True)
def test_tail_simple_execution_running_no_data_produced(self, mock_stream_manager):
- argv = ['execution', 'tail', 'idfoo1']
+ argv = ["execution", "tail", "idfoo1"]
- MOCK_EVENTS = [
- MOCK_LIVEACTION_1_SUCCEEDED
- ]
+ MOCK_EVENTS = [MOCK_LIVEACTION_1_SUCCEEDED]
mock_cls = mock.Mock()
mock_cls.listen = mock.Mock()
@@ -333,21 +279,26 @@ def test_tail_simple_execution_running_no_data_produced(self, mock_stream_manage
Execution idfoo1 has completed (status=succeeded).
"""
self.assertEqual(stdout, expected_result)
- self.assertEqual(stderr, '')
+ self.assertEqual(stderr, "")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_3_RUNNING),
- 200, 'OK')))
- @mock.patch('st2client.client.StreamManager', autospec=True)
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_LIVEACTION_3_RUNNING), 200, "OK"
+ )
+ ),
+ )
+ @mock.patch("st2client.client.StreamManager", autospec=True)
def test_tail_simple_execution_running_with_data(self, mock_stream_manager):
- argv = ['execution', 'tail', 'idfoo3']
+ argv = ["execution", "tail", "idfoo3"]
MOCK_EVENTS = [
MOCK_LIVEACTION_3_RUNNING,
MOCK_OUTPUT_1,
MOCK_OUTPUT_2,
- MOCK_LIVEACTION_3_SUCCEDED
+ MOCK_LIVEACTION_3_SUCCEDED,
]
mock_cls = mock.Mock()
@@ -372,41 +323,39 @@ def test_tail_simple_execution_running_with_data(self, mock_stream_manager):
Execution idfoo3 has completed (status=succeeded).
""".lstrip()
self.assertEqual(stdout, expected_result)
- self.assertEqual(stderr, '')
+ self.assertEqual(stderr, "")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_3_RUNNING),
- 200, 'OK')))
- @mock.patch('st2client.client.StreamManager', autospec=True)
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_LIVEACTION_3_RUNNING), 200, "OK"
+ )
+ ),
+ )
+ @mock.patch("st2client.client.StreamManager", autospec=True)
def test_tail_action_chain_workflow_execution(self, mock_stream_manager):
- argv = ['execution', 'tail', 'idfoo3']
+ argv = ["execution", "tail", "idfoo3"]
MOCK_EVENTS = [
# Workflow started running
MOCK_LIVEACTION_3_RUNNING,
-
# Child task 1 started running
MOCK_LIVEACTION_3_CHILD_1_RUNNING,
-
# Output produced by the child task
MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1,
MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2,
-
# Child task 1 finished
MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED,
-
# Child task 2 started running
MOCK_LIVEACTION_3_CHILD_2_RUNNING,
-
# Output produced by child task
MOCK_LIVEACTION_3_CHILD_2_OUTPUT_1,
-
# Child task 2 finished
MOCK_LIVEACTION_3_CHILD_2_FAILED,
-
# Parent workflow task finished
- MOCK_LIVEACTION_3_SUCCEDED
+ MOCK_LIVEACTION_3_SUCCEDED,
]
mock_cls = mock.Mock()
@@ -440,41 +389,39 @@ def test_tail_action_chain_workflow_execution(self, mock_stream_manager):
Execution idfoo3 has completed (status=succeeded).
""".lstrip()
self.assertEqual(stdout, expected_result)
- self.assertEqual(stderr, '')
+ self.assertEqual(stderr, "")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_RUNNING),
- 200, 'OK')))
- @mock.patch('st2client.client.StreamManager', autospec=True)
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_LIVEACTION_4_RUNNING), 200, "OK"
+ )
+ ),
+ )
+ @mock.patch("st2client.client.StreamManager", autospec=True)
def test_tail_orquesta_workflow_execution(self, mock_stream_manager):
- argv = ['execution', 'tail', 'idfoo4']
+ argv = ["execution", "tail", "idfoo4"]
MOCK_EVENTS = [
# Workflow started running
MOCK_LIVEACTION_4_RUNNING,
-
# Child task 1 started running
MOCK_LIVEACTION_4_CHILD_1_RUNNING,
-
# Output produced by the child task
MOCK_LIVEACTION_4_CHILD_1_OUTPUT_1,
MOCK_LIVEACTION_4_CHILD_1_OUTPUT_2,
-
# Child task 1 finished
MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED,
-
# Child task 2 started running
MOCK_LIVEACTION_4_CHILD_2_RUNNING,
-
# Output produced by child task
MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1,
-
# Child task 2 finished
MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT,
-
# Parent workflow task finished
- MOCK_LIVEACTION_4_SUCCEDED
+ MOCK_LIVEACTION_4_SUCCEDED,
]
mock_cls = mock.Mock()
@@ -508,64 +455,55 @@ def test_tail_orquesta_workflow_execution(self, mock_stream_manager):
Execution idfoo4 has completed (status=succeeded).
""".lstrip()
self.assertEqual(stdout, expected_result)
- self.assertEqual(stderr, '')
+ self.assertEqual(stderr, "")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_RUNNING),
- 200, 'OK')))
- @mock.patch('st2client.client.StreamManager', autospec=True)
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_LIVEACTION_4_RUNNING), 200, "OK"
+ )
+ ),
+ )
+ @mock.patch("st2client.client.StreamManager", autospec=True)
def test_tail_double_nested_orquesta_workflow_execution(self, mock_stream_manager):
- argv = ['execution', 'tail', 'idfoo4']
+ argv = ["execution", "tail", "idfoo4"]
MOCK_EVENTS = [
# Workflow started running
MOCK_LIVEACTION_4_RUNNING,
-
# Child task 1 started running (sub workflow)
MOCK_LIVEACTION_4_CHILD_1_RUNNING,
-
# Child task 1 started running
MOCK_LIVEACTION_4_CHILD_1_1_RUNNING,
-
# Output produced by the child task
MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_1,
MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_2,
-
# Another execution has started, this output should not be included
MOCK_LIVEACTION_3_RUNNING,
-
# Child task 1 started running
MOCK_LIVEACTION_3_CHILD_1_RUNNING,
-
# Output produced by the child task
MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1,
MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2,
-
# Child task 1 finished
MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED,
-
# Parent workflow task finished
MOCK_LIVEACTION_3_SUCCEDED,
# End another execution
-
# Child task 1 has finished
MOCK_LIVEACTION_4_CHILD_1_1_SUCCEEDED,
-
# Child task 1 finished (sub workflow)
MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED,
-
# Child task 2 started running
MOCK_LIVEACTION_4_CHILD_2_RUNNING,
-
# Output produced by child task
MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1,
-
# Child task 2 finished
MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT,
-
# Parent workflow task finished
- MOCK_LIVEACTION_4_SUCCEDED
+ MOCK_LIVEACTION_4_SUCCEDED,
]
mock_cls = mock.Mock()
@@ -604,32 +542,33 @@ def test_tail_double_nested_orquesta_workflow_execution(self, mock_stream_manage
""".lstrip()
self.assertEqual(stdout, expected_result)
- self.assertEqual(stderr, '')
+ self.assertEqual(stderr, "")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_CHILD_2_RUNNING),
- 200, 'OK')))
- @mock.patch('st2client.client.StreamManager', autospec=True)
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(MOCK_LIVEACTION_4_CHILD_2_RUNNING), 200, "OK"
+ )
+ ),
+ )
+ @mock.patch("st2client.client.StreamManager", autospec=True)
def test_tail_child_execution_directly(self, mock_stream_manager):
- argv = ['execution', 'tail', 'idfoo4']
+ argv = ["execution", "tail", "idfoo4"]
MOCK_EVENTS = [
# Child task 2 started running
MOCK_LIVEACTION_4_CHILD_2_RUNNING,
-
# Output produced by child task
MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1,
-
# Other executions should not interfere
# Child task 1 started running
MOCK_LIVEACTION_3_CHILD_1_RUNNING,
-
# Child task 1 finished (sub workflow)
MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED,
-
# Child task 2 finished
- MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT
+ MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT,
]
mock_cls = mock.Mock()
@@ -654,4 +593,4 @@ def test_tail_child_execution_directly(self, mock_stream_manager):
""".lstrip()
self.assertEqual(stdout, expected_result)
- self.assertEqual(stderr, '')
+ self.assertEqual(stderr, "")
diff --git a/st2client/tests/unit/test_formatters.py b/st2client/tests/unit/test_formatters.py
index b3733faba5..fe0370aea1 100644
--- a/st2client/tests/unit/test_formatters.py
+++ b/st2client/tests/unit/test_formatters.py
@@ -39,38 +39,43 @@
LOG = logging.getLogger(__name__)
FIXTURES_MANIFEST = {
- 'executions': ['execution.json',
- 'execution_result_has_carriage_return.json',
- 'execution_unicode.json',
- 'execution_double_backslash.json',
- 'execution_with_stack_trace.json',
- 'execution_with_schema.json'],
- 'results': ['execution_get_default.txt',
- 'execution_get_detail.txt',
- 'execution_get_result_by_key.txt',
- 'execution_result_has_carriage_return.txt',
- 'execution_result_has_carriage_return_py3.txt',
- 'execution_get_attributes.txt',
- 'execution_list_attr_start_timestamp.txt',
- 'execution_list_empty_response_start_timestamp_attr.txt',
- 'execution_unescape_newline.txt',
- 'execution_unicode.txt',
- 'execution_double_backslash.txt',
- 'execution_unicode_py3.txt',
- 'execution_get_has_schema.txt']
+ "executions": [
+ "execution.json",
+ "execution_result_has_carriage_return.json",
+ "execution_unicode.json",
+ "execution_double_backslash.json",
+ "execution_with_stack_trace.json",
+ "execution_with_schema.json",
+ ],
+ "results": [
+ "execution_get_default.txt",
+ "execution_get_detail.txt",
+ "execution_get_result_by_key.txt",
+ "execution_result_has_carriage_return.txt",
+ "execution_result_has_carriage_return_py3.txt",
+ "execution_get_attributes.txt",
+ "execution_list_attr_start_timestamp.txt",
+ "execution_list_empty_response_start_timestamp_attr.txt",
+ "execution_unescape_newline.txt",
+ "execution_unicode.txt",
+ "execution_double_backslash.txt",
+ "execution_unicode_py3.txt",
+ "execution_get_has_schema.txt",
+ ],
}
FIXTURES = loader.load_fixtures(fixtures_dict=FIXTURES_MANIFEST)
-EXECUTION = FIXTURES['executions']['execution.json']
-UNICODE = FIXTURES['executions']['execution_unicode.json']
-DOUBLE_BACKSLASH = FIXTURES['executions']['execution_double_backslash.json']
-OUTPUT_SCHEMA = FIXTURES['executions']['execution_with_schema.json']
-NEWLINE = FIXTURES['executions']['execution_with_stack_trace.json']
-HAS_CARRIAGE_RETURN = FIXTURES['executions']['execution_result_has_carriage_return.json']
+EXECUTION = FIXTURES["executions"]["execution.json"]
+UNICODE = FIXTURES["executions"]["execution_unicode.json"]
+DOUBLE_BACKSLASH = FIXTURES["executions"]["execution_double_backslash.json"]
+OUTPUT_SCHEMA = FIXTURES["executions"]["execution_with_schema.json"]
+NEWLINE = FIXTURES["executions"]["execution_with_stack_trace.json"]
+HAS_CARRIAGE_RETURN = FIXTURES["executions"][
+ "execution_result_has_carriage_return.json"
+]
class TestExecutionResultFormatter(unittest2.TestCase):
-
def __init__(self, *args, **kwargs):
super(TestExecutionResultFormatter, self).__init__(*args, **kwargs)
self.shell = shell.Shell()
@@ -88,212 +93,278 @@ def tearDown(self):
os.unlink(self.path)
def _redirect_console(self, path):
- sys.stdout = open(path, 'w')
- sys.stderr = open(path, 'w')
+ sys.stdout = open(path, "w")
+ sys.stderr = open(path, "w")
def _undo_console_redirect(self):
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
def test_console_redirect(self):
- message = 'Hello, World!'
+ message = "Hello, World!"
print(message)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
- content = fd.read().replace('\n', '')
+ with open(self.path, "r") as fd:
+ content = fd.read().replace("\n", "")
self.assertEqual(content, message)
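
The manual redirect above (reassign sys.stdout, restore from sys.__stdout__) predates contextlib.redirect_stdout; the scoped stdlib equivalent, shown for comparison only (these tests do not use it):

import contextlib
import io

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print("Hello, World!")
assert buf.getvalue().strip() == "Hello, World!"
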
def test_execution_get_default(self):
- argv = ['execution', 'get', EXECUTION['id']]
+ argv = ["execution", "get", EXECUTION["id"]]
content = self._get_execution(argv)
- self.assertEqual(content, FIXTURES['results']['execution_get_default.txt'])
+ self.assertEqual(content, FIXTURES["results"]["execution_get_default.txt"])
def test_execution_get_attributes(self):
- argv = ['execution', 'get', EXECUTION['id'], '--attr', 'status', 'end_timestamp']
+ argv = [
+ "execution",
+ "get",
+ EXECUTION["id"],
+ "--attr",
+ "status",
+ "end_timestamp",
+ ]
content = self._get_execution(argv)
- self.assertEqual(content, FIXTURES['results']['execution_get_attributes.txt'])
+ self.assertEqual(content, FIXTURES["results"]["execution_get_attributes.txt"])
def test_execution_get_default_in_json(self):
- argv = ['execution', 'get', EXECUTION['id'], '-j']
+ argv = ["execution", "get", EXECUTION["id"], "-j"]
content = self._get_execution(argv)
- self.assertEqual(json.loads(content),
- jsutil.get_kvps(EXECUTION, ['id', 'action.ref', 'context.user',
- 'start_timestamp', 'end_timestamp', 'status',
- 'parameters', 'result']))
+ self.assertEqual(
+ json.loads(content),
+ jsutil.get_kvps(
+ EXECUTION,
+ [
+ "id",
+ "action.ref",
+ "context.user",
+ "start_timestamp",
+ "end_timestamp",
+ "status",
+ "parameters",
+ "result",
+ ],
+ ),
+ )
def test_execution_get_detail(self):
- argv = ['execution', 'get', EXECUTION['id'], '-d']
+ argv = ["execution", "get", EXECUTION["id"], "-d"]
content = self._get_execution(argv)
- self.assertEqual(content, FIXTURES['results']['execution_get_detail.txt'])
+ self.assertEqual(content, FIXTURES["results"]["execution_get_detail.txt"])
def test_execution_with_schema(self):
- argv = ['execution', 'get', OUTPUT_SCHEMA['id']]
+ argv = ["execution", "get", OUTPUT_SCHEMA["id"]]
content = self._get_schema_execution(argv)
- self.assertEqual(content, FIXTURES['results']['execution_get_has_schema.txt'])
+ self.assertEqual(content, FIXTURES["results"]["execution_get_has_schema.txt"])
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(NEWLINE), 200, 'OK', {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(NEWLINE), 200, "OK", {})
+ ),
+ )
def test_execution_unescape_newline(self):
- """Ensure client renders newline characters
- """
+ """Ensure client renders newline characters"""
- argv = ['execution', 'get', NEWLINE['id']]
+ argv = ["execution", "get", NEWLINE["id"]]
self.assertEqual(self.shell.run(argv), 0)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
+ with open(self.path, "r") as fd:
content = fd.read()
- self.assertEqual(content, FIXTURES['results']['execution_unescape_newline.txt'])
+ self.assertEqual(content, FIXTURES["results"]["execution_unescape_newline.txt"])
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(UNICODE), 200, 'OK', {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(UNICODE), 200, "OK", {})
+ ),
+ )
def test_execution_unicode(self):
- """Ensure client renders unicode escape sequences
- """
+ """Ensure client renders unicode escape sequences"""
- argv = ['execution', 'get', UNICODE['id']]
+ argv = ["execution", "get", UNICODE["id"]]
self.assertEqual(self.shell.run(argv), 0)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
+ with open(self.path, "r") as fd:
content = fd.read()
if six.PY2:
- self.assertEqual(content, FIXTURES['results']['execution_unicode.txt'])
+ self.assertEqual(content, FIXTURES["results"]["execution_unicode.txt"])
else:
- content = content.replace(r'\xE2\x80\xA1', r'\u2021')
- self.assertEqual(content, FIXTURES['results']['execution_unicode_py3.txt'])
+ content = content.replace(r"\xE2\x80\xA1", r"\u2021")
+ self.assertEqual(content, FIXTURES["results"]["execution_unicode_py3.txt"])
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(DOUBLE_BACKSLASH), 200, 'OK', {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(DOUBLE_BACKSLASH), 200, "OK", {})
+ ),
+ )
def test_execution_double_backslash_not_unicode_escape_sequence(self):
- argv = ['execution', 'get', DOUBLE_BACKSLASH['id']]
+ argv = ["execution", "get", DOUBLE_BACKSLASH["id"]]
self.assertEqual(self.shell.run(argv), 0)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
+ with open(self.path, "r") as fd:
content = fd.read()
- self.assertEqual(content, FIXTURES['results']['execution_double_backslash.txt'])
+ self.assertEqual(content, FIXTURES["results"]["execution_double_backslash.txt"])
def test_execution_get_detail_in_json(self):
- argv = ['execution', 'get', EXECUTION['id'], '-d', '-j']
+ argv = ["execution", "get", EXECUTION["id"], "-d", "-j"]
content = self._get_execution(argv)
content_dict = json.loads(content)
# Sufficient to check if output contains all expected keys. The entire result will not
# match as content will contain characters which improve rendering.
for k in six.iterkeys(EXECUTION):
- if k in ['liveaction', 'callback']:
+ if k in ["liveaction", "callback"]:
continue
if k in content:
continue
- self.assertTrue(False, 'Missing key %s. %s != %s' % (k, EXECUTION, content_dict))
+ self.assertTrue(
+ False, "Missing key %s. %s != %s" % (k, EXECUTION, content_dict)
+ )
def test_execution_get_result_by_key(self):
- argv = ['execution', 'get', EXECUTION['id'], '-k', 'localhost.stdout']
+ argv = ["execution", "get", EXECUTION["id"], "-k", "localhost.stdout"]
content = self._get_execution(argv)
- self.assertEqual(content, FIXTURES['results']['execution_get_result_by_key.txt'])
+ self.assertEqual(
+ content, FIXTURES["results"]["execution_get_result_by_key.txt"]
+ )
def test_execution_get_result_by_key_in_json(self):
- argv = ['execution', 'get', EXECUTION['id'], '-k', 'localhost.stdout', '-j']
+ argv = ["execution", "get", EXECUTION["id"], "-k", "localhost.stdout", "-j"]
content = self._get_execution(argv)
- self.assertDictEqual(json.loads(content),
- jsutil.get_kvps(EXECUTION, ['result.localhost.stdout']))
+ self.assertDictEqual(
+ json.loads(content), jsutil.get_kvps(EXECUTION, ["result.localhost.stdout"])
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(HAS_CARRIAGE_RETURN), 200, 'OK',
- {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(HAS_CARRIAGE_RETURN), 200, "OK", {}
+ )
+ ),
+ )
def test_execution_get_detail_with_carriage_return(self):
- argv = ['execution', 'get', HAS_CARRIAGE_RETURN['id'], '-d']
+ argv = ["execution", "get", HAS_CARRIAGE_RETURN["id"], "-d"]
self.assertEqual(self.shell.run(argv), 0)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
+ with open(self.path, "r") as fd:
content = fd.read()
if six.PY2:
self.assertEqual(
- content, FIXTURES['results']['execution_result_has_carriage_return.txt'])
+ content, FIXTURES["results"]["execution_result_has_carriage_return.txt"]
+ )
else:
self.assertEqual(
content,
- FIXTURES['results']['execution_result_has_carriage_return_py3.txt'])
+ FIXTURES["results"]["execution_result_has_carriage_return_py3.txt"],
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, 'OK', {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, "OK", {})
+ ),
+ )
def test_execution_list_attribute_provided(self):
# Client shouldn't throw if "-a" flag is provided when listing executions
- argv = ['execution', 'list', '-a', 'start_timestamp']
+ argv = ["execution", "list", "-a", "start_timestamp"]
self.assertEqual(self.shell.run(argv), 0)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
+ with open(self.path, "r") as fd:
content = fd.read()
self.assertEqual(
- content, FIXTURES['results']['execution_list_attr_start_timestamp.txt'])
+ content, FIXTURES["results"]["execution_list_attr_start_timestamp.txt"]
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, 'OK', {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, "OK", {})),
+ )
def test_execution_list_attribute_provided_empty_response(self):
# Client shouldn't throw if "-a" flag is provided, but there are no executions
- argv = ['execution', 'list', '-a', 'start_timestamp']
+ argv = ["execution", "list", "-a", "start_timestamp"]
self.assertEqual(self.shell.run(argv), 0)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
+ with open(self.path, "r") as fd:
content = fd.read()
self.assertEqual(
- content, FIXTURES['results']['execution_list_empty_response_start_timestamp_attr.txt'])
+ content,
+ FIXTURES["results"][
+ "execution_list_empty_response_start_timestamp_attr.txt"
+ ],
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK', {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK", {})
+ ),
+ )
def _get_execution(self, argv):
self.assertEqual(self.shell.run(argv), 0)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
+ with open(self.path, "r") as fd:
content = fd.read()
return content
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(OUTPUT_SCHEMA), 200, 'OK', {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(OUTPUT_SCHEMA), 200, "OK", {})
+ ),
+ )
def _get_schema_execution(self, argv):
self.assertEqual(self.shell.run(argv), 0)
self._undo_console_redirect()
- with open(self.path, 'r') as fd:
+ with open(self.path, "r") as fd:
content = fd.read()
return content
def test_SinlgeRowTable_notebox_one(self):
- with mock.patch('sys.stderr', new=StringIO()) as fackety_fake:
- expected = "Note: Only one action execution is displayed. Use -n/--last flag for " \
+ with mock.patch("sys.stderr", new=StringIO()) as fackety_fake:
+ expected = (
+ "Note: Only one action execution is displayed. Use -n/--last flag for "
"more results."
+ )
print(self.table.note_box("action executions", 1))
- content = (fackety_fake.getvalue().split("|")[1].strip())
+ content = fackety_fake.getvalue().split("|")[1].strip()
self.assertEqual(content, expected)
def test_SinlgeRowTable_notebox_zero(self):
- with mock.patch('sys.stderr', new=BytesIO()) as fackety_fake:
- contents = (fackety_fake.getvalue())
- self.assertEqual(contents, b'')
+ with mock.patch("sys.stderr", new=BytesIO()) as fackety_fake:
+ contents = fackety_fake.getvalue()
+ self.assertEqual(contents, b"")
def test_SinlgeRowTable_notebox_default(self):
- with mock.patch('sys.stderr', new=StringIO()) as fackety_fake:
- expected = "Note: Only first 50 action executions are displayed. Use -n/--last flag " \
+ with mock.patch("sys.stderr", new=StringIO()) as fackety_fake:
+ expected = (
+ "Note: Only first 50 action executions are displayed. Use -n/--last flag "
"for more results."
+ )
print(self.table.note_box("action executions", 50))
- content = (fackety_fake.getvalue().split("|")[1].strip())
+ content = fackety_fake.getvalue().split("|")[1].strip()
self.assertEqual(content, expected)
- with mock.patch('sys.stderr', new=StringIO()) as fackety_fake:
- expected = "Note: Only first 15 action executions are displayed. Use -n/--last flag " \
+ with mock.patch("sys.stderr", new=StringIO()) as fackety_fake:
+ expected = (
+ "Note: Only first 15 action executions are displayed. Use -n/--last flag "
"for more results."
+ )
print(self.table.note_box("action executions", 15))
- content = (fackety_fake.getvalue().split("|")[1].strip())
+ content = fackety_fake.getvalue().split("|")[1].strip()
self.assertEqual(content, expected)
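
Editor's note: the hunks above all reformat one recurring test pattern — `httpclient.HTTPClient.get` is patched with a `mock.MagicMock` whose return value is a canned `base.FakeResponse`, so the CLI renders output without a live API. A minimal, self-contained sketch of that pattern follows; the `FakeResponse` and `HTTPClient` classes here are local stand-ins written for illustration, not the actual st2client test helpers.

    import json
    from unittest import mock


    class FakeResponse:
        """Bare-bones imitation of requests.Response for offline tests."""

        def __init__(self, text, status_code, reason, headers=None):
            self.text = text
            self.status_code = status_code
            self.reason = reason
            self.headers = headers or {}

        def json(self):
            return json.loads(self.text)


    class HTTPClient:
        def get(self, url):
            # Pretend this performs a network call in the real client.
            raise RuntimeError("should be patched out in tests")


    @mock.patch.object(
        HTTPClient,
        "get",
        mock.MagicMock(return_value=FakeResponse(json.dumps({"id": "1"}), 200, "OK")),
    )
    def demo():
        # The patched method returns the canned response instead of touching the network.
        resp = HTTPClient().get("http://example.invalid/executions/1")
        assert resp.status_code == 200 and resp.json()["id"] == "1"


    demo()

Because the mock is installed at the class level, every instance created inside the test sees the stubbed `get`, which is why the tests above never construct or inject a client explicitly.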
diff --git a/st2client/tests/unit/test_inquiry.py b/st2client/tests/unit/test_inquiry.py
index 138f1da899..4132fda0d1 100644
--- a/st2client/tests/unit/test_inquiry.py
+++ b/st2client/tests/unit/test_inquiry.py
@@ -31,12 +31,12 @@
def _randomize_inquiry_id(inquiry):
newinquiry = copy.deepcopy(inquiry)
- newinquiry['id'] = str(uuid.uuid4())
+ newinquiry["id"] = str(uuid.uuid4())
# ID can't have '1440' in it, otherwise our `count()` fails
# when inspecting the inquiry list output for test:
# test_list_inquiries_limit()
- while '1440' in newinquiry['id']:
- newinquiry['id'] = str(uuid.uuid4())
+ while "1440" in newinquiry["id"]:
+ newinquiry["id"] = str(uuid.uuid4())
return newinquiry
@@ -45,8 +45,7 @@ def _generate_inquiries(count):
class TestInquiryBase(base.BaseCLITestCase):
- """Base class for "inquiry" CLI tests
- """
+ """Base class for "inquiry" CLI tests"""
capture_output = True
@@ -54,8 +53,8 @@ def __init__(self, *args, **kwargs):
super(TestInquiryBase, self).__init__(*args, **kwargs)
self.parser = argparse.ArgumentParser()
- self.parser.add_argument('-t', '--token', dest='token')
- self.parser.add_argument('--api-key', dest='api_key')
+ self.parser.add_argument("-t", "--token", dest="token")
+ self.parser.add_argument("--api-key", dest="api_key")
self.shell = shell.Shell()
def setUp(self):
@@ -72,14 +71,12 @@ def tearDown(self):
"continue": {
"type": "boolean",
"description": "Would you like to continue the workflow?",
- "required": True
+ "required": True,
}
},
}
-RESPONSE_DEFAULT = {
- "continue": True
-}
+RESPONSE_DEFAULT = {"continue": True}
SCHEMA_MULTIPLE = {
"title": "response_data",
@@ -88,30 +85,24 @@ def tearDown(self):
"name": {
"type": "string",
"description": "What is your name?",
- "required": True
+ "required": True,
},
"pin": {
"type": "integer",
"description": "What is your PIN?",
- "required": True
+ "required": True,
},
"paradox": {
"type": "boolean",
"description": "This statement is False.",
- "required": True
- }
+ "required": True,
+ },
},
}
-RESPONSE_MULTIPLE = {
- "name": "matt",
- "pin": 1234,
- "paradox": True
-}
+RESPONSE_MULTIPLE = {"name": "matt", "pin": 1234, "paradox": True}
-RESPONSE_BAD = {
- "foo": "bar"
-}
+RESPONSE_BAD = {"foo": "bar"}
INQUIRY_1 = {
"id": "abcdef",
@@ -119,7 +110,7 @@ def tearDown(self):
"roles": [],
"users": [],
"route": "",
- "ttl": 1440
+ "ttl": 1440,
}
INQUIRY_MULTIPLE = {
@@ -128,145 +119,200 @@ def tearDown(self):
"roles": [],
"users": [],
"route": "",
- "ttl": 1440
+ "ttl": 1440,
}
class TestInquirySubcommands(TestInquiryBase):
-
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(INQUIRY_1), 200, 'OK')))
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK")
+ ),
+ )
def test_get_inquiry(self):
- """Test retrieval of a single inquiry
- """
- inquiry_id = 'abcdef'
- args = ['inquiry', 'get', inquiry_id]
+ """Test retrieval of a single inquiry"""
+ inquiry_id = "abcdef"
+ args = ["inquiry", "get", inquiry_id]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 404, 'NOT FOUND')))
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps({}), 404, "NOT FOUND")
+ ),
+ )
def test_get_inquiry_not_found(self):
- """Test retrieval of a inquiry that doesn't exist
- """
- inquiry_id = 'asdbv'
- args = ['inquiry', 'get', inquiry_id]
+ """Test retrieval of a inquiry that doesn't exist"""
+ inquiry_id = "asdbv"
+ args = ["inquiry", "get", inquiry_id]
retcode = self.shell.run(args)
- self.assertEqual('Inquiry "%s" is not found.\n\n' % inquiry_id, self.stdout.getvalue())
+ self.assertEqual(
+ 'Inquiry "%s" is not found.\n\n' % inquiry_id, self.stdout.getvalue()
+ )
self.assertEqual(retcode, 2)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps([INQUIRY_1]), 200, 'OK', {'X-Total-Count': '1'}
- ))))
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=(
+ base.FakeResponse(
+ json.dumps([INQUIRY_1]), 200, "OK", {"X-Total-Count": "1"}
+ )
+ )
+ ),
+ )
def test_list_inquiries(self):
- """Test retrieval of a list of Inquiries
- """
- args = ['inquiry', 'list']
+ """Test retrieval of a list of Inquiries"""
+ args = ["inquiry", "list"]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
- self.assertEqual(self.stdout.getvalue().count('1440'), 1)
+ self.assertEqual(self.stdout.getvalue().count("1440"), 1)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps(_generate_inquiries(50)), 200, 'OK', {'X-Total-Count': '55'}
- ))))
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=(
+ base.FakeResponse(
+ json.dumps(_generate_inquiries(50)),
+ 200,
+ "OK",
+ {"X-Total-Count": "55"},
+ )
+ )
+ ),
+ )
def test_list_inquiries_limit(self):
- """Test retrieval of a list of Inquiries while using the "limit" option
- """
- args = ['inquiry', 'list', '-n', '50']
+ """Test retrieval of a list of Inquiries while using the "limit" option"""
+ args = ["inquiry", "list", "-n", "50"]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
- self.assertEqual(self.stdout.getvalue().count('1440'), 50)
- self.assertIn('Note: Only first 50 inquiries are displayed.', self.stderr.getvalue())
+ self.assertEqual(self.stdout.getvalue().count("1440"), 50)
+ self.assertIn(
+ "Note: Only first 50 inquiries are displayed.", self.stderr.getvalue()
+ )
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps([]), 200, 'OK', {'X-Total-Count': '0'}
- ))))
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=(
+ base.FakeResponse(json.dumps([]), 200, "OK", {"X-Total-Count": "0"})
+ )
+ ),
+ )
def test_list_empty_inquiries(self):
- """Test empty list of Inquiries
- """
- args = ['inquiry', 'list']
+ """Test empty list of Inquiries"""
+ args = ["inquiry", "list"]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps(INQUIRY_1), 200, 'OK'
- ))))
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK"))
+ ),
+ )
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), 200, 'OK'
- ))))
- @mock.patch('st2client.commands.inquiry.InteractiveForm')
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=(
+ base.FakeResponse(
+ json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}),
+ 200,
+ "OK",
+ )
+ )
+ ),
+ )
+ @mock.patch("st2client.commands.inquiry.InteractiveForm")
def test_respond(self, mock_form):
- """Test interactive response
- """
+ """Test interactive response"""
form_instance = mock_form.return_value
form_instance.initiate_dialog.return_value = RESPONSE_DEFAULT
- args = ['inquiry', 'respond', 'abcdef']
+ args = ["inquiry", "respond", "abcdef"]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps(INQUIRY_1), 200, 'OK'
- ))))
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK"))
+ ),
+ )
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), 200, 'OK'
- ))))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=(
+ base.FakeResponse(
+ json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}),
+ 200,
+ "OK",
+ )
+ )
+ ),
+ )
def test_respond_response_flag(self):
- """Test response without interactive mode
- """
- args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_DEFAULT, 'abcdef']
+ """Test response without interactive mode"""
+ args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_DEFAULT, "abcdef"]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps(INQUIRY_1), 200, 'OK'
- ))))
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK"))
+ ),
+ )
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps({}), 400, '400 Client Error: Bad Request'
- ))))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=(
+ base.FakeResponse(json.dumps({}), 400, "400 Client Error: Bad Request")
+ )
+ ),
+ )
def test_respond_invalid(self):
- """Test invalid response
- """
- args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_BAD, 'abcdef']
+ """Test invalid response"""
+ args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_BAD, "abcdef"]
retcode = self.shell.run(args)
self.assertEqual(retcode, 1)
- self.assertEqual('ERROR: 400 Client Error: Bad Request', self.stdout.getvalue().strip())
+ self.assertEqual(
+ "ERROR: 400 Client Error: Bad Request", self.stdout.getvalue().strip()
+ )
def test_respond_nonexistent_inquiry(self):
- """Test responding to an inquiry that doesn't exist
- """
- inquiry_id = '134234'
- args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_DEFAULT, inquiry_id]
+ """Test responding to an inquiry that doesn't exist"""
+ inquiry_id = "134234"
+ args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_DEFAULT, inquiry_id]
retcode = self.shell.run(args)
self.assertEqual(retcode, 1)
- self.assertEqual('ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id,
- self.stdout.getvalue().strip())
+ self.assertEqual(
+ 'ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id,
+ self.stdout.getvalue().strip(),
+ )
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=(base.FakeResponse(
- json.dumps({}), 404, '404 Client Error: Not Found'
- ))))
- @mock.patch('st2client.commands.inquiry.InteractiveForm')
+ requests,
+ "get",
+ mock.MagicMock(
+ return_value=(
+ base.FakeResponse(json.dumps({}), 404, "404 Client Error: Not Found")
+ )
+ ),
+ )
+ @mock.patch("st2client.commands.inquiry.InteractiveForm")
def test_respond_nonexistent_inquiry_interactive(self, mock_form):
"""Test interactively responding to an inquiry that doesn't exist
@@ -274,11 +320,13 @@ def test_respond_nonexistent_inquiry_interactive(self, mock_form):
responding with PUT, in order to retrieve the desired schema for this inquiry.
So, we want to test that interaction separately.
"""
- inquiry_id = '253432'
+ inquiry_id = "253432"
form_instance = mock_form.return_value
form_instance.initiate_dialog.return_value = RESPONSE_DEFAULT
- args = ['inquiry', 'respond', inquiry_id]
+ args = ["inquiry", "respond", inquiry_id]
retcode = self.shell.run(args)
self.assertEqual(retcode, 1)
- self.assertEqual('ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id,
- self.stdout.getvalue().strip())
+ self.assertEqual(
+ 'ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id,
+ self.stdout.getvalue().strip(),
+ )
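
Editor's note: the `_randomize_inquiry_id` helper at the top of this file re-rolls UUIDs until none contains the substring "1440", because `test_list_inquiries_limit` counts occurrences of the TTL value 1440 in the rendered table and a stray "1440" inside an ID would skew that count. A hedged, standalone sketch of the same guard (`randomize_id` is a simplified illustration, not the helper itself):

    import copy
    import uuid

    INQUIRY = {"id": "abcdef", "ttl": 1440}


    def randomize_id(inquiry, forbidden="1440"):
        """Return a copy with a fresh UUID id that avoids the forbidden substring."""
        new = copy.deepcopy(inquiry)
        new["id"] = str(uuid.uuid4())
        while forbidden in new["id"]:  # re-roll on the rare collision
            new["id"] = str(uuid.uuid4())
        return new


    rows = [randomize_id(INQUIRY) for _ in range(50)]
    rendered = "\n".join("%s %s" % (r["id"], r["ttl"]) for r in rows)
    # Every "1440" in the output now comes from the ttl column alone.
    assert rendered.count("1440") == 50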
diff --git a/st2client/tests/unit/test_interactive.py b/st2client/tests/unit/test_interactive.py
index 24f0080232..dce4c6748d 100644
--- a/st2client/tests/unit/test_interactive.py
+++ b/st2client/tests/unit/test_interactive.py
@@ -31,37 +31,32 @@
class TestInteractive(unittest2.TestCase):
-
def assertPromptMessage(self, prompt_mock, message, msg=None):
self.assertEqual(prompt_mock.call_args[0], (message,), msg)
def assertPromptDescription(self, prompt_mock, message, msg=None):
- toolbar_factory = prompt_mock.call_args[1]['get_bottom_toolbar_tokens']
+ toolbar_factory = prompt_mock.call_args[1]["get_bottom_toolbar_tokens"]
self.assertEqual(toolbar_factory(None)[0][1], message, msg)
def assertPromptValidate(self, prompt_mock, value):
- validator = prompt_mock.call_args[1]['validator']
+ validator = prompt_mock.call_args[1]["validator"]
validator.validate(Document(text=six.text_type(value)))
def assertPromptPassword(self, prompt_mock, value, msg=None):
- self.assertEqual(prompt_mock.call_args[1]['is_password'], value, msg)
+ self.assertEqual(prompt_mock.call_args[1]["is_password"], value, msg)
def test_interactive_form(self):
reader = mock.MagicMock()
Reader = mock.MagicMock(return_value=reader)
Reader.condition = mock.MagicMock(return_value=True)
- schema = {
- 'string': {
- 'type': 'string'
- }
- }
+ schema = {"string": {"type": "string"}}
- with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]):
+ with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]):
interactive.InteractiveForm(schema).initiate_dialog()
- Reader.condition.assert_called_once_with(schema['string'])
+ Reader.condition.assert_called_once_with(schema["string"])
reader.read.assert_called_once_with()
def test_interactive_form_no_match(self):
@@ -69,35 +64,27 @@ def test_interactive_form_no_match(self):
Reader = mock.MagicMock(return_value=reader)
Reader.condition = mock.MagicMock(return_value=False)
- schema = {
- 'string': {
- 'type': 'string'
- }
- }
+ schema = {"string": {"type": "string"}}
- with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]):
+ with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]):
interactive.InteractiveForm(schema).initiate_dialog()
- Reader.condition.assert_called_once_with(schema['string'])
+ Reader.condition.assert_called_once_with(schema["string"])
reader.read.assert_not_called()
- @mock.patch('sys.stdout', new_callable=StringIO)
+ @mock.patch("sys.stdout", new_callable=StringIO)
def test_interactive_form_interrupted(self, stdout_mock):
reader = mock.MagicMock()
Reader = mock.MagicMock(return_value=reader)
Reader.condition = mock.MagicMock(return_value=True)
reader.read = mock.MagicMock(side_effect=KeyboardInterrupt)
- schema = {
- 'string': {
- 'type': 'string'
- }
- }
+ schema = {"string": {"type": "string"}}
- with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]):
+ with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]):
interactive.InteractiveForm(schema).initiate_dialog()
- self.assertEqual(stdout_mock.getvalue(), 'Dialog interrupted.\n')
+ self.assertEqual(stdout_mock.getvalue(), "Dialog interrupted.\n")
def test_interactive_form_interrupted_reraised(self):
reader = mock.MagicMock()
@@ -105,285 +92,278 @@ def test_interactive_form_interrupted_reraised(self):
Reader.condition = mock.MagicMock(return_value=True)
reader.read = mock.MagicMock(side_effect=KeyboardInterrupt)
- schema = {
- 'string': {
- 'type': 'string'
- }
- }
+ schema = {"string": {"type": "string"}}
- with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]):
- self.assertRaises(interactive.DialogInterrupted,
- interactive.InteractiveForm(schema, reraise=True).initiate_dialog)
+ with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]):
+ self.assertRaises(
+ interactive.DialogInterrupted,
+ interactive.InteractiveForm(schema, reraise=True).initiate_dialog,
+ )
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_stringreader(self, prompt_mock):
- spec = {
- 'description': 'some description',
- 'default': 'hey'
- }
- Reader = interactive.StringReader('some', spec)
+ spec = {"description": "some description", "default": "hey"}
+ Reader = interactive.StringReader("some", spec)
- prompt_mock.return_value = 'stuff'
+ prompt_mock.return_value = "stuff"
result = Reader.read()
- self.assertEqual(result, 'stuff')
- self.assertPromptMessage(prompt_mock, 'some [hey]: ')
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, 'stuff')
+ self.assertEqual(result, "stuff")
+ self.assertPromptMessage(prompt_mock, "some [hey]: ")
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "stuff")
- prompt_mock.return_value = ''
+ prompt_mock.return_value = ""
result = Reader.read()
- self.assertEqual(result, 'hey')
- self.assertPromptValidate(prompt_mock, '')
+ self.assertEqual(result, "hey")
+ self.assertPromptValidate(prompt_mock, "")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_booleanreader(self, prompt_mock):
- spec = {
- 'description': 'some description',
- 'default': False
- }
- Reader = interactive.BooleanReader('some', spec)
+ spec = {"description": "some description", "default": False}
+ Reader = interactive.BooleanReader("some", spec)
- prompt_mock.return_value = 'y'
+ prompt_mock.return_value = "y"
result = Reader.read()
self.assertEqual(result, True)
- self.assertPromptMessage(prompt_mock, 'some (boolean) [n]: ')
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, 'y')
- self.assertRaises(prompt_toolkit.validation.ValidationError,
- self.assertPromptValidate, prompt_mock, 'some')
-
- prompt_mock.return_value = ''
+ self.assertPromptMessage(prompt_mock, "some (boolean) [n]: ")
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "y")
+ self.assertRaises(
+ prompt_toolkit.validation.ValidationError,
+ self.assertPromptValidate,
+ prompt_mock,
+ "some",
+ )
+
+ prompt_mock.return_value = ""
result = Reader.read()
self.assertEqual(result, False)
- self.assertPromptValidate(prompt_mock, '')
+ self.assertPromptValidate(prompt_mock, "")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_numberreader(self, prompt_mock):
- spec = {
- 'description': 'some description',
- 'default': 3.2
- }
- Reader = interactive.NumberReader('some', spec)
+ spec = {"description": "some description", "default": 3.2}
+ Reader = interactive.NumberReader("some", spec)
- prompt_mock.return_value = '5.3'
+ prompt_mock.return_value = "5.3"
result = Reader.read()
self.assertEqual(result, 5.3)
- self.assertPromptMessage(prompt_mock, 'some (float) [3.2]: ')
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, '5.3')
- self.assertRaises(prompt_toolkit.validation.ValidationError,
- self.assertPromptValidate, prompt_mock, 'some')
-
- prompt_mock.return_value = ''
+ self.assertPromptMessage(prompt_mock, "some (float) [3.2]: ")
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "5.3")
+ self.assertRaises(
+ prompt_toolkit.validation.ValidationError,
+ self.assertPromptValidate,
+ prompt_mock,
+ "some",
+ )
+
+ prompt_mock.return_value = ""
result = Reader.read()
self.assertEqual(result, 3.2)
- self.assertPromptValidate(prompt_mock, '')
+ self.assertPromptValidate(prompt_mock, "")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_integerreader(self, prompt_mock):
- spec = {
- 'description': 'some description',
- 'default': 3
- }
- Reader = interactive.IntegerReader('some', spec)
+ spec = {"description": "some description", "default": 3}
+ Reader = interactive.IntegerReader("some", spec)
- prompt_mock.return_value = '5'
+ prompt_mock.return_value = "5"
result = Reader.read()
self.assertEqual(result, 5)
- self.assertPromptMessage(prompt_mock, 'some (integer) [3]: ')
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, '5')
- self.assertRaises(prompt_toolkit.validation.ValidationError,
- self.assertPromptValidate, prompt_mock, '5.3')
-
- prompt_mock.return_value = ''
+ self.assertPromptMessage(prompt_mock, "some (integer) [3]: ")
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "5")
+ self.assertRaises(
+ prompt_toolkit.validation.ValidationError,
+ self.assertPromptValidate,
+ prompt_mock,
+ "5.3",
+ )
+
+ prompt_mock.return_value = ""
result = Reader.read()
self.assertEqual(result, 3)
- self.assertPromptValidate(prompt_mock, '')
+ self.assertPromptValidate(prompt_mock, "")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_secretstringreader(self, prompt_mock):
- spec = {
- 'description': 'some description',
- 'default': 'hey'
- }
- Reader = interactive.SecretStringReader('some', spec)
+ spec = {"description": "some description", "default": "hey"}
+ Reader = interactive.SecretStringReader("some", spec)
- prompt_mock.return_value = 'stuff'
+ prompt_mock.return_value = "stuff"
result = Reader.read()
- self.assertEqual(result, 'stuff')
- self.assertPromptMessage(prompt_mock, 'some (secret) [hey]: ')
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, 'stuff')
+ self.assertEqual(result, "stuff")
+ self.assertPromptMessage(prompt_mock, "some (secret) [hey]: ")
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "stuff")
self.assertPromptPassword(prompt_mock, True)
- prompt_mock.return_value = ''
+ prompt_mock.return_value = ""
result = Reader.read()
- self.assertEqual(result, 'hey')
- self.assertPromptValidate(prompt_mock, '')
+ self.assertEqual(result, "hey")
+ self.assertPromptValidate(prompt_mock, "")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_enumreader(self, prompt_mock):
spec = {
- 'enum': ['some', 'thing', 'else'],
- 'description': 'some description',
- 'default': 'thing'
+ "enum": ["some", "thing", "else"],
+ "description": "some description",
+ "default": "thing",
}
- Reader = interactive.EnumReader('some', spec)
+ Reader = interactive.EnumReader("some", spec)
- prompt_mock.return_value = '2'
+ prompt_mock.return_value = "2"
result = Reader.read()
- self.assertEqual(result, 'else')
- message = 'some: \n 0 - some\n 1 - thing\n 2 - else\nChoose from 0, 1, 2 [1]: '
+ self.assertEqual(result, "else")
+ message = "some: \n 0 - some\n 1 - thing\n 2 - else\nChoose from 0, 1, 2 [1]: "
self.assertPromptMessage(prompt_mock, message)
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, '0')
- self.assertRaises(prompt_toolkit.validation.ValidationError,
- self.assertPromptValidate, prompt_mock, 'some')
- self.assertRaises(prompt_toolkit.validation.ValidationError,
- self.assertPromptValidate, prompt_mock, '5')
-
- prompt_mock.return_value = ''
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "0")
+ self.assertRaises(
+ prompt_toolkit.validation.ValidationError,
+ self.assertPromptValidate,
+ prompt_mock,
+ "some",
+ )
+ self.assertRaises(
+ prompt_toolkit.validation.ValidationError,
+ self.assertPromptValidate,
+ prompt_mock,
+ "5",
+ )
+
+ prompt_mock.return_value = ""
result = Reader.read()
- self.assertEqual(result, 'thing')
- self.assertPromptValidate(prompt_mock, '')
+ self.assertEqual(result, "thing")
+ self.assertPromptValidate(prompt_mock, "")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_arrayreader(self, prompt_mock):
- spec = {
- 'description': 'some description',
- 'default': ['a', 'b']
- }
- Reader = interactive.ArrayReader('some', spec)
+ spec = {"description": "some description", "default": ["a", "b"]}
+ Reader = interactive.ArrayReader("some", spec)
- prompt_mock.return_value = 'some,thing,else'
+ prompt_mock.return_value = "some,thing,else"
result = Reader.read()
- self.assertEqual(result, ['some', 'thing', 'else'])
- self.assertPromptMessage(prompt_mock, 'some (comma-separated list) [a,b]: ')
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, 'some,thing,else')
+ self.assertEqual(result, ["some", "thing", "else"])
+ self.assertPromptMessage(prompt_mock, "some (comma-separated list) [a,b]: ")
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "some,thing,else")
- prompt_mock.return_value = ''
+ prompt_mock.return_value = ""
result = Reader.read()
- self.assertEqual(result, ['a', 'b'])
- self.assertPromptValidate(prompt_mock, '')
+ self.assertEqual(result, ["a", "b"])
+ self.assertPromptValidate(prompt_mock, "")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_arrayreader_ends_with_comma(self, prompt_mock):
- spec = {
- 'description': 'some description',
- 'default': ['a', 'b']
- }
- Reader = interactive.ArrayReader('some', spec)
+ spec = {"description": "some description", "default": ["a", "b"]}
+ Reader = interactive.ArrayReader("some", spec)
- prompt_mock.return_value = 'some,thing,else,'
+ prompt_mock.return_value = "some,thing,else,"
result = Reader.read()
- self.assertEqual(result, ['some', 'thing', 'else', ''])
- self.assertPromptMessage(prompt_mock, 'some (comma-separated list) [a,b]: ')
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, 'some,thing,else,')
+ self.assertEqual(result, ["some", "thing", "else", ""])
+ self.assertPromptMessage(prompt_mock, "some (comma-separated list) [a,b]: ")
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "some,thing,else,")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_arrayenumreader(self, prompt_mock):
spec = {
- 'items': {
- 'enum': ['a', 'b', 'c', 'd', 'e']
- },
- 'description': 'some description',
- 'default': ['a', 'b']
+ "items": {"enum": ["a", "b", "c", "d", "e"]},
+ "description": "some description",
+ "default": ["a", "b"],
}
- Reader = interactive.ArrayEnumReader('some', spec)
+ Reader = interactive.ArrayEnumReader("some", spec)
- prompt_mock.return_value = '0,2,4'
+ prompt_mock.return_value = "0,2,4"
result = Reader.read()
- self.assertEqual(result, ['a', 'c', 'e'])
- message = 'some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: '
+ self.assertEqual(result, ["a", "c", "e"])
+ message = "some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: "
self.assertPromptMessage(prompt_mock, message)
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, '0,2,4')
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "0,2,4")
- prompt_mock.return_value = ''
+ prompt_mock.return_value = ""
result = Reader.read()
- self.assertEqual(result, ['a', 'b'])
- self.assertPromptValidate(prompt_mock, '')
+ self.assertEqual(result, ["a", "b"])
+ self.assertPromptValidate(prompt_mock, "")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_arrayenumreader_ends_with_comma(self, prompt_mock):
spec = {
- 'items': {
- 'enum': ['a', 'b', 'c', 'd', 'e']
- },
- 'description': 'some description',
- 'default': ['a', 'b']
+ "items": {"enum": ["a", "b", "c", "d", "e"]},
+ "description": "some description",
+ "default": ["a", "b"],
}
- Reader = interactive.ArrayEnumReader('some', spec)
+ Reader = interactive.ArrayEnumReader("some", spec)
- prompt_mock.return_value = '0,2,4,'
+ prompt_mock.return_value = "0,2,4,"
result = Reader.read()
- self.assertEqual(result, ['a', 'c', 'e'])
- message = 'some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: '
+ self.assertEqual(result, ["a", "c", "e"])
+ message = "some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: "
self.assertPromptMessage(prompt_mock, message)
- self.assertPromptDescription(prompt_mock, 'some description')
- self.assertPromptValidate(prompt_mock, '0,2,4,')
+ self.assertPromptDescription(prompt_mock, "some description")
+ self.assertPromptValidate(prompt_mock, "0,2,4,")
- @mock.patch.object(interactive, 'prompt')
+ @mock.patch.object(interactive, "prompt")
def test_arrayobjectreader(self, prompt_mock):
spec = {
- 'items': {
- 'type': 'object',
- 'properties': {
- 'foo': {
- 'type': 'string',
- 'description': 'some description',
+ "items": {
+ "type": "object",
+ "properties": {
+ "foo": {
+ "type": "string",
+ "description": "some description",
+ },
+ "bar": {
+ "type": "string",
+ "description": "some description",
},
- 'bar': {
- 'type': 'string',
- 'description': 'some description',
- }
- }
+ },
},
- 'description': 'some description',
+ "description": "some description",
}
- Reader = interactive.ArrayObjectReader('some', spec)
+ Reader = interactive.ArrayObjectReader("some", spec)
# This flag variable is needed to emulate the user continuing the dialog
self.is_continued = False
def side_effect(msg, **kwargs):
- if re.match(r'^~~~ Would you like to add another item to.*', msg):
+ if re.match(r"^~~~ Would you like to add another item to.*", msg):
# prompt asks for input to decide whether to continue, or not
if not self.is_continued:
# continue the configuration only once
self.is_continued = True
- return ''
+ return ""
else:
# finish the configuration
- return 'n'
+ return "n"
else:
# prompt asks for the value of a property in the object
- return 'value'
+ return "value"
prompt_mock.side_effect = side_effect
results = Reader.read()
self.assertEqual(len(results), 2)
self.assertTrue(all([len(list(x.keys())) == 2 for x in results]))
- self.assertTrue(all(['foo' in x and 'bar' in x for x in results]))
+ self.assertTrue(all(["foo" in x and "bar" in x for x in results]))
diff --git a/st2client/tests/unit/test_keyvalue.py b/st2client/tests/unit/test_keyvalue.py
index bb5bf09d60..52c240a052 100644
--- a/st2client/tests/unit/test_keyvalue.py
+++ b/st2client/tests/unit/test_keyvalue.py
@@ -29,77 +29,70 @@
LOG = logging.getLogger(__name__)
KEYVALUE = {
- 'id': 'kv_name',
- 'name': 'kv_name.',
- 'value': 'super cool value',
- 'scope': 'system'
+ "id": "kv_name",
+ "name": "kv_name.",
+ "value": "super cool value",
+ "scope": "system",
}
KEYVALUE_USER = {
- 'id': 'kv_name',
- 'name': 'kv_name.',
- 'value': 'super cool value',
- 'scope': 'system',
- 'user': 'stanley'
+ "id": "kv_name",
+ "name": "kv_name.",
+ "value": "super cool value",
+ "scope": "system",
+ "user": "stanley",
}
KEYVALUE_SECRET = {
- 'id': 'kv_name',
- 'name': 'kv_name.',
- 'value': 'super cool value',
- 'scope': 'system',
- 'secret': True
+ "id": "kv_name",
+ "name": "kv_name.",
+ "value": "super cool value",
+ "scope": "system",
+ "secret": True,
}
KEYVALUE_PRE_ENCRYPTED = {
- 'id': 'kv_name',
- 'name': 'kv_name.',
- 'value': 'AAABBBCCC1234',
- 'scope': 'system',
- 'encrypted': True,
- 'secret': True
+ "id": "kv_name",
+ "name": "kv_name.",
+ "value": "AAABBBCCC1234",
+ "scope": "system",
+ "encrypted": True,
+ "secret": True,
}
KEYVALUE_TTL = {
- 'id': 'kv_name',
- 'name': 'kv_name.',
- 'value': 'super cool value',
- 'scope': 'system',
- 'ttl': 100
+ "id": "kv_name",
+ "name": "kv_name.",
+ "value": "super cool value",
+ "scope": "system",
+ "ttl": 100,
}
KEYVALUE_OBJECT = {
- 'id': 'kv_name',
- 'name': 'kv_name.',
- 'value': {'obj': [1, True, 23.4, 'abc']},
- 'scope': 'system',
+ "id": "kv_name",
+ "name": "kv_name.",
+ "value": {"obj": [1, True, 23.4, "abc"]},
+ "scope": "system",
}
KEYVALUE_ALL = {
- 'id': 'kv_name',
- 'name': 'kv_name.',
- 'value': 'AAAAABBBBBCCCCCCDDDDD11122345',
- 'scope': 'system',
- 'user': 'stanley',
- 'secret': True,
- 'encrypted': True,
- 'ttl': 100
+ "id": "kv_name",
+ "name": "kv_name.",
+ "value": "AAAAABBBBBCCCCCCDDDDD11122345",
+ "scope": "system",
+ "user": "stanley",
+ "secret": True,
+ "encrypted": True,
+ "ttl": 100,
}
-KEYVALUE_MISSING_NAME = {
- 'id': 'kv_name',
- 'value': 'super cool value'
-}
+KEYVALUE_MISSING_NAME = {"id": "kv_name", "value": "super cool value"}
-KEYVALUE_MISSING_VALUE = {
- 'id': 'kv_name',
- 'name': 'kv_name.'
-}
+KEYVALUE_MISSING_VALUE = {"id": "kv_name", "name": "kv_name."}
class TestKeyValueBase(base.BaseCLITestCase):
- """Base class for "key" CLI tests
- """
+ """Base class for "key" CLI tests"""
capture_output = True
@@ -107,8 +100,8 @@ def __init__(self, *args, **kwargs):
super(TestKeyValueBase, self).__init__(*args, **kwargs)
self.parser = argparse.ArgumentParser()
- self.parser.add_argument('-t', '--token', dest='token')
- self.parser.add_argument('--api-key', dest='api_key')
+ self.parser.add_argument("-t", "--token", dest="token")
+ self.parser.add_argument("--api-key", dest="api_key")
self.shell = shell.Shell()
def setUp(self):
@@ -119,44 +112,49 @@ def tearDown(self):
class TestKeyValueSet(TestKeyValueBase):
-
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_PRE_ENCRYPTED), 200,
- 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, "OK"
+ )
+ ),
+ )
def test_set_keyvalue(self):
- """Test setting key/value pair with optional pre_encrypted field
- """
- args = ['key', 'set', '--encrypted', 'kv_name', 'AAABBBCCC1234']
+ """Test setting key/value pair with optional pre_encrypted field"""
+ args = ["key", "set", "--encrypted", "kv_name", "AAABBBCCC1234"]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
def test_encrypt_and_encrypted_flags_are_mutually_exclusive(self):
- args = ['key', 'set', '--encrypt', '--encrypted', 'kv_name', 'AAABBBCCC1234']
+ args = ["key", "set", "--encrypt", "--encrypted", "kv_name", "AAABBBCCC1234"]
- self.assertRaisesRegexp(SystemExit, '2', self.shell.run, args)
+ self.assertRaisesRegexp(SystemExit, "2", self.shell.run, args)
self.stderr.seek(0)
stderr = self.stderr.read()
- expected_msg = ('error: argument --encrypted: not allowed with argument -e/--encrypt')
+ expected_msg = (
+ "error: argument --encrypted: not allowed with argument -e/--encrypt"
+ )
self.assertIn(expected_msg, stderr)
class TestKeyValueLoad(TestKeyValueBase):
-
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, "OK")),
+ )
def test_load_keyvalue_json(self):
- """Test loading of key/value pair in JSON format
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of key/value pair in JSON format"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE, indent=4))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -164,17 +162,18 @@ def test_load_keyvalue_json(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, "OK")),
+ )
def test_load_keyvalue_yaml(self):
- """Test loading of key/value pair in YAML format
- """
- fd, path = tempfile.mkstemp(suffix='.yaml')
+ """Test loading of key/value pair in YAML format"""
+ fd, path = tempfile.mkstemp(suffix=".yaml")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(yaml.safe_dump(KEYVALUE))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -182,17 +181,20 @@ def test_load_keyvalue_yaml(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_USER), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(KEYVALUE_USER), 200, "OK")
+ ),
+ )
def test_load_keyvalue_user(self):
- """Test loading of key/value pair with the optional user field
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of key/value pair with the optional user field"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_USER, indent=4))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -200,17 +202,20 @@ def test_load_keyvalue_user(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_SECRET), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(KEYVALUE_SECRET), 200, "OK")
+ ),
+ )
def test_load_keyvalue_secret(self):
- """Test loading of key/value pair with the optional secret field
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of key/value pair with the optional secret field"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_SECRET, indent=4))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -218,18 +223,22 @@ def test_load_keyvalue_secret(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_PRE_ENCRYPTED), 200,
- 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, "OK"
+ )
+ ),
+ )
def test_load_keyvalue_already_encrypted(self):
- """Test loading of key/value pair with the pre-encrypted value
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of key/value pair with the pre-encrypted value"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_PRE_ENCRYPTED, indent=4))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -237,17 +246,20 @@ def test_load_keyvalue_already_encrypted(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_TTL), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(KEYVALUE_TTL), 200, "OK")
+ ),
+ )
def test_load_keyvalue_ttl(self):
- """Test loading of key/value pair with the optional ttl field
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of key/value pair with the optional ttl field"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_TTL, indent=4))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -255,23 +267,26 @@ def test_load_keyvalue_ttl(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, "OK")
+ ),
+ )
def test_load_keyvalue_object(self):
- """Test loading of key/value pair where the value is an object
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of key/value pair where the value is an object"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_OBJECT, indent=4))
# test converting with short option
- args = ['key', 'load', '-c', path]
+ args = ["key", "load", "-c", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
# test converting with long option
- args = ['key', 'load', '--convert', path]
+ args = ["key", "load", "--convert", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -279,19 +294,23 @@ def test_load_keyvalue_object(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, "OK")
+ ),
+ )
def test_load_keyvalue_object_fail(self):
"""Test failure to load key/value pair where the value is an object
- and the -c/--convert option is not passed
+ and the -c/--convert option is not passed
"""
- fd, path = tempfile.mkstemp(suffix='.json')
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_OBJECT, indent=4))
# test converting with short option
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertNotEqual(retcode, 0)
finally:
@@ -299,17 +318,20 @@ def test_load_keyvalue_object_fail(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, "OK")
+ ),
+ )
def test_load_keyvalue_all(self):
- """Test loading of key/value pair with all optional fields
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of key/value pair with all optional fields"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_ALL, indent=4))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -317,21 +339,23 @@ def test_load_keyvalue_all(self):
os.unlink(path)
@mock.patch.object(
- requests, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL),
- 200, 'OK')))
+ requests,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, "OK")
+ ),
+ )
def test_load_keyvalue_array(self):
- """Test loading an array of key/value pairs
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading an array of key/value pairs"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
array = [KEYVALUE, KEYVALUE_ALL]
json_str = json.dumps(array, indent=4)
LOG.info(json_str)
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json_str)
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 0)
finally:
@@ -339,14 +363,13 @@ def test_load_keyvalue_array(self):
os.unlink(path)
def test_load_keyvalue_missing_name(self):
- """Test loading of a key/value pair with the required field 'name' missing
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of a key/value pair with the required field 'name' missing"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_MISSING_NAME, indent=4))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 1)
finally:
@@ -354,14 +377,13 @@ def test_load_keyvalue_missing_name(self):
os.unlink(path)
def test_load_keyvalue_missing_value(self):
- """Test loading of a key/value pair with the required field 'value' missing
- """
- fd, path = tempfile.mkstemp(suffix='.json')
+ """Test loading of a key/value pair with the required field 'value' missing"""
+ fd, path = tempfile.mkstemp(suffix=".json")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(json.dumps(KEYVALUE_MISSING_VALUE, indent=4))
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 1)
finally:
@@ -369,19 +391,17 @@ def test_load_keyvalue_missing_value(self):
os.unlink(path)
def test_load_keyvalue_missing_file(self):
- """Test loading of a key/value pair with a missing file
- """
- path = '/some/file/that/doesnt/exist.json'
- args = ['key', 'load', path]
+ """Test loading of a key/value pair with a missing file"""
+ path = "/some/file/that/doesnt/exist.json"
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 1)
def test_load_keyvalue_bad_file_extension(self):
- """Test loading of a key/value pair with a bad file extension
- """
- fd, path = tempfile.mkstemp(suffix='.badext')
+ """Test loading of a key/value pair with a bad file extension"""
+ fd, path = tempfile.mkstemp(suffix=".badext")
try:
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
self.assertEqual(retcode, 1)
finally:
@@ -392,11 +412,11 @@ def test_load_keyvalue_empty_file(self):
"""
Loading K/V from an empty file shouldn't throw an error
"""
- fd, path = tempfile.mkstemp(suffix='.yaml')
+ fd, path = tempfile.mkstemp(suffix=".yaml")
try:
- args = ['key', 'load', path]
+ args = ["key", "load", path]
retcode = self.shell.run(args)
- self.assertIn('No matching items found', self.stdout.getvalue())
+ self.assertIn("No matching items found", self.stdout.getvalue())
self.assertEqual(retcode, 0)
finally:
os.close(fd)
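
Editor's note: every loader test in this file follows the same fixture lifecycle — create a temp file with `tempfile.mkstemp`, append the JSON or YAML payload, run the CLI, then close the descriptor and unlink in `finally`. A compact sketch of that lifecycle, assuming a hypothetical `run_cli` in place of `shell.Shell().run`:

    import json
    import os
    import tempfile

    KEYVALUE = {"name": "kv_name", "value": "super cool value"}


    def run_cli(args):
        # Hypothetical stand-in for shell.Shell().run(args); returns an exit code.
        path = args[-1]
        with open(path) as f:
            data = json.load(f)
        return 0 if {"name", "value"} <= set(data) else 1


    fd, path = tempfile.mkstemp(suffix=".json")
    try:
        # mkstemp already created the file, so append the payload to it.
        with open(path, "a") as f:
            f.write(json.dumps(KEYVALUE, indent=4))
        assert run_cli(["key", "load", path]) == 0
    finally:
        os.close(fd)  # close the low-level descriptor mkstemp returned
        os.unlink(path)

Keeping the cleanup in `finally` is what lets a failing assertion in the middle of the test leave no stray temp files behind.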
diff --git a/st2client/tests/unit/test_models.py b/st2client/tests/unit/test_models.py
index dd7f35d6b8..8a137afa13 100644
--- a/st2client/tests/unit/test_models.py
+++ b/st2client/tests/unit/test_models.py
@@ -29,22 +29,24 @@
class TestSerialization(unittest2.TestCase):
-
def test_resource_serialize(self):
- instance = base.FakeResource(id='123', name='abc')
+ instance = base.FakeResource(id="123", name="abc")
self.assertDictEqual(instance.serialize(), base.RESOURCES[0])
def test_resource_deserialize(self):
instance = base.FakeResource.deserialize(base.RESOURCES[0])
- self.assertEqual(instance.id, '123')
- self.assertEqual(instance.name, 'abc')
+ self.assertEqual(instance.id, "123")
+ self.assertEqual(instance.name, "abc")
class TestResourceManager(unittest2.TestCase):
-
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK")
+ ),
+ )
def test_resource_get_all(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
resources = mgr.get_all()
@@ -53,8 +55,12 @@ def test_resource_get_all(self):
self.assertListEqual(actual, expected)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK")
+ ),
+ )
def test_resource_get_all_with_limit(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
resources = mgr.get_all(limit=50)
@@ -63,135 +69,197 @@ def test_resource_get_all_with_limit(self):
self.assertListEqual(actual, expected)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_resource_get_all_failed(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
self.assertRaises(Exception, mgr.get_all)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK")
+ ),
+ )
def test_resource_get_by_id(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- resource = mgr.get_by_id('123')
+ resource = mgr.get_by_id("123")
actual = resource.serialize()
expected = json.loads(json.dumps(base.RESOURCES[0]))
self.assertEqual(actual, expected)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")),
+ )
def test_resource_get_by_id_404(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- resource = mgr.get_by_id('123')
+ resource = mgr.get_by_id("123")
self.assertIsNone(resource)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_resource_get_by_id_failed(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
self.assertRaises(Exception, mgr.get_by_id)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK',
- {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps([base.RESOURCES[0]]), 200, "OK", {}
+ )
+ ),
+ )
def test_resource_query(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- resources = mgr.query(name='abc')
+ resources = mgr.query(name="abc")
actual = [resource.serialize() for resource in resources]
expected = json.loads(json.dumps([base.RESOURCES[0]]))
self.assertEqual(actual, expected)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK',
- {'X-Total-Count': '50'})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps([base.RESOURCES[0]]), 200, "OK", {"X-Total-Count": "50"}
+ )
+ ),
+ )
def test_resource_query_with_count(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- resources, count = mgr.query_with_count(name='abc')
+ resources, count = mgr.query_with_count(name="abc")
actual = [resource.serialize() for resource in resources]
expected = json.loads(json.dumps([base.RESOURCES[0]]))
self.assertEqual(actual, expected)
self.assertEqual(count, 50)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK',
- {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps([base.RESOURCES[0]]), 200, "OK", {}
+ )
+ ),
+ )
def test_resource_query_with_limit(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- resources = mgr.query(name='abc', limit=50)
+ resources = mgr.query(name="abc", limit=50)
actual = [resource.serialize() for resource in resources]
expected = json.loads(json.dumps([base.RESOURCES[0]]))
self.assertEqual(actual, expected)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND',
- {'X-Total-Count': '30'})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ "", 404, "NOT FOUND", {"X-Total-Count": "30"}
+ )
+ ),
+ )
def test_resource_query_404(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
# No X-Total-Count
- resources = mgr.query(name='abc')
+ resources = mgr.query(name="abc")
self.assertListEqual(resources, [])
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND',
- {'X-Total-Count': '30'})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ "", 404, "NOT FOUND", {"X-Total-Count": "30"}
+ )
+ ),
+ )
def test_resource_query_with_count_404(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- resources, count = mgr.query_with_count(name='abc')
+ resources, count = mgr.query_with_count(name="abc")
self.assertListEqual(resources, [])
self.assertIsNone(count)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_resource_query_failed(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- self.assertRaises(Exception, mgr.query, name='abc')
+ self.assertRaises(Exception, mgr.query, name="abc")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK',
- {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps([base.RESOURCES[0]]), 200, "OK", {}
+ )
+ ),
+ )
def test_resource_get_by_name(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
# No X-Total-Count
- resource = mgr.get_by_name('abc')
+ resource = mgr.get_by_name("abc")
actual = resource.serialize()
expected = json.loads(json.dumps(base.RESOURCES[0]))
self.assertEqual(actual, expected)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")),
+ )
def test_resource_get_by_name_404(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- resource = mgr.get_by_name('abc')
+ resource = mgr.get_by_name("abc")
self.assertIsNone(resource)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK")
+ ),
+ )
def test_resource_get_by_name_ambiguous(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- self.assertRaises(Exception, mgr.get_by_name, 'abc')
+ self.assertRaises(Exception, mgr.get_by_name, "abc")
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_resource_get_by_name_failed(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
self.assertRaises(Exception, mgr.get_by_name)
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK")
+ ),
+ )
def test_resource_create(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
instance = base.FakeResource.deserialize('{"name": "abc"}')
@@ -199,16 +267,24 @@ def test_resource_create(self):
self.assertIsNotNone(resource)
@mock.patch.object(
- httpclient.HTTPClient, 'post',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "post",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_resource_create_failed(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
instance = base.FakeResource.deserialize('{"name": "abc"}')
self.assertRaises(Exception, mgr.create, instance)
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK")
+ ),
+ )
def test_resource_update(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
text = '{"id": "123", "name": "cba"}'
@@ -217,8 +293,12 @@ def test_resource_update(self):
self.assertIsNotNone(resource)
@mock.patch.object(
- httpclient.HTTPClient, 'put',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "put",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_resource_update_failed(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
text = '{"id": "123", "name": "cba"}'
@@ -226,39 +306,57 @@ def test_resource_update_failed(self):
self.assertRaises(Exception, mgr.update, instance)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK',
- {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps([base.RESOURCES[0]]), 200, "OK", {}
+ )
+ ),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'delete',
- mock.MagicMock(return_value=base.FakeResponse('', 204, 'NO CONTENT')))
+ httpclient.HTTPClient,
+ "delete",
+ mock.MagicMock(return_value=base.FakeResponse("", 204, "NO CONTENT")),
+ )
def test_resource_delete(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- instance = mgr.get_by_name('abc')
+ instance = mgr.get_by_name("abc")
mgr.delete(instance)
@mock.patch.object(
- httpclient.HTTPClient, 'delete',
- mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND')))
+ httpclient.HTTPClient,
+ "delete",
+ mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")),
+ )
def test_resource_delete_404(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
instance = base.FakeResource.deserialize(base.RESOURCES[0])
mgr.delete(instance)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK',
- {})))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(
+ json.dumps([base.RESOURCES[0]]), 200, "OK", {}
+ )
+ ),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'delete',
- mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "delete",
+ mock.MagicMock(
+ return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_resource_delete_failed(self):
mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT)
- instance = mgr.get_by_name('abc')
+ instance = mgr.get_by_name("abc")
self.assertRaises(Exception, mgr.delete, instance)
- @mock.patch('requests.get')
- @mock.patch('sseclient.SSEClient')
+ @mock.patch("requests.get")
+ @mock.patch("sseclient.SSEClient")
def test_stream_resource_listen(self, mock_sseclient, mock_requests):
mock_msg = mock.Mock()
mock_msg.data = json.dumps(base.RESOURCES)
@@ -267,14 +365,16 @@ def test_stream_resource_listen(self, mock_sseclient, mock_requests):
def side_effect_checking_verify_parameter_is():
return [mock_msg]
- mock_sseclient.return_value.events.side_effect = side_effect_checking_verify_parameter_is
- mgr = models.StreamManager('https://example.com', cacert='/path/ca.crt')
+ mock_sseclient.return_value.events.side_effect = (
+ side_effect_checking_verify_parameter_is
+ )
+ mgr = models.StreamManager("https://example.com", cacert="/path/ca.crt")
- resp = mgr.listen(events=['foo', 'bar'])
+ resp = mgr.listen(events=["foo", "bar"])
self.assertEqual(list(resp), [base.RESOURCES])
- call_args = tuple(['https://example.com/stream?events=foo%2Cbar'])
- call_kwargs = {'stream': True, 'verify': '/path/ca.crt'}
+ call_args = tuple(["https://example.com/stream?events=foo%2Cbar"])
+ call_kwargs = {"stream": True, "verify": "/path/ca.crt"}
self.assertEqual(mock_requests.call_args_list[0][0], call_args)
self.assertEqual(mock_requests.call_args_list[0][1], call_kwargs)
@@ -283,15 +383,16 @@ def side_effect_checking_verify_parameter_is():
def side_effect_checking_verify_parameter_is_not():
return [mock_msg]
- mock_sseclient.return_value.events.side_effect = \
+ mock_sseclient.return_value.events.side_effect = (
side_effect_checking_verify_parameter_is_not
- mgr = models.StreamManager('https://example.com')
+ )
+ mgr = models.StreamManager("https://example.com")
resp = mgr.listen()
self.assertEqual(list(resp), [base.RESOURCES])
- call_args = tuple(['https://example.com/stream?'])
- call_kwargs = {'stream': True}
+ call_args = tuple(["https://example.com/stream?"])
+ call_kwargs = {"stream": True}
self.assertEqual(mock_requests.call_args_list[1][0], call_args)
self.assertEqual(mock_requests.call_args_list[1][1], call_kwargs)
diff --git a/st2client/tests/unit/test_shell.py b/st2client/tests/unit/test_shell.py
index 8383526615..bce176b4ad 100644
--- a/st2client/tests/unit/test_shell.py
+++ b/st2client/tests/unit/test_shell.py
@@ -38,8 +38,8 @@
LOG = logging.getLogger(__name__)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, '../fixtures/st2rc.full.ini')
-CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, '../fixtures/st2rc.partial.ini')
+CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, "../fixtures/st2rc.full.ini")
+CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, "../fixtures/st2rc.partial.ini")
MOCK_CONFIG = """
[credentials]
@@ -77,352 +77,383 @@ def test_commands_usage_and_help_strings(self):
self.stderr.seek(0)
stderr = self.stderr.read()
- self.assertIn('Usage: ', stderr)
- self.assertIn('For example:', stderr)
- self.assertIn('CLI for StackStorm', stderr)
- self.assertIn('positional arguments:', stderr)
+ self.assertIn("Usage: ", stderr)
+ self.assertIn("For example:", stderr)
+ self.assertIn("CLI for StackStorm", stderr)
+ self.assertIn("positional arguments:", stderr)
self.stdout.truncate()
self.stderr.truncate()
# --help should result in the same output
try:
- self.assertEqual(self.shell.run(['--help']), 0)
+ self.assertEqual(self.shell.run(["--help"]), 0)
except SystemExit as e:
self.assertEqual(e.code, 0)
self.stdout.seek(0)
stdout = self.stdout.read()
- self.assertIn('Usage: ', stdout)
- self.assertIn('For example:', stdout)
- self.assertIn('CLI for StackStorm', stdout)
- self.assertIn('positional arguments:', stdout)
+ self.assertIn("Usage: ", stdout)
+ self.assertIn("For example:", stdout)
+ self.assertIn("CLI for StackStorm", stdout)
+ self.assertIn("positional arguments:", stdout)
self.stdout.truncate()
self.stderr.truncate()
        # Subcommand with no args
try:
- self.assertEqual(self.shell.run(['action']), 2)
+ self.assertEqual(self.shell.run(["action"]), 2)
except SystemExit as e:
self.assertEqual(e.code, 2)
self.stderr.seek(0)
stderr = self.stderr.read()
- self.assertIn('usage', stderr)
+ self.assertIn("usage", stderr)
if six.PY2:
- self.assertIn('{list,get,create,update', stderr)
- self.assertIn('error: too few arguments', stderr)
+ self.assertIn("{list,get,create,update", stderr)
+ self.assertIn("error: too few arguments", stderr)
def test_endpoints_default(self):
- base_url = 'http://127.0.0.1'
- auth_url = 'http://127.0.0.1:9100'
- api_url = 'http://127.0.0.1:9101/v1'
- stream_url = 'http://127.0.0.1:9102/v1'
- args = ['trigger', 'list']
+ base_url = "http://127.0.0.1"
+ auth_url = "http://127.0.0.1:9100"
+ api_url = "http://127.0.0.1:9101/v1"
+ stream_url = "http://127.0.0.1:9102/v1"
+ args = ["trigger", "list"]
parsed_args = self.shell.parser.parse_args(args)
client = self.shell.get_client(parsed_args)
- self.assertEqual(client.endpoints['base'], base_url)
- self.assertEqual(client.endpoints['auth'], auth_url)
- self.assertEqual(client.endpoints['api'], api_url)
- self.assertEqual(client.endpoints['stream'], stream_url)
+ self.assertEqual(client.endpoints["base"], base_url)
+ self.assertEqual(client.endpoints["auth"], auth_url)
+ self.assertEqual(client.endpoints["api"], api_url)
+ self.assertEqual(client.endpoints["stream"], stream_url)
def test_endpoints_base_url_from_cli(self):
- base_url = 'http://www.st2.com'
- auth_url = 'http://www.st2.com:9100'
- api_url = 'http://www.st2.com:9101/v1'
- stream_url = 'http://www.st2.com:9102/v1'
- args = ['--url', base_url, 'trigger', 'list']
+ base_url = "http://www.st2.com"
+ auth_url = "http://www.st2.com:9100"
+ api_url = "http://www.st2.com:9101/v1"
+ stream_url = "http://www.st2.com:9102/v1"
+ args = ["--url", base_url, "trigger", "list"]
parsed_args = self.shell.parser.parse_args(args)
client = self.shell.get_client(parsed_args)
- self.assertEqual(client.endpoints['base'], base_url)
- self.assertEqual(client.endpoints['auth'], auth_url)
- self.assertEqual(client.endpoints['api'], api_url)
- self.assertEqual(client.endpoints['stream'], stream_url)
+ self.assertEqual(client.endpoints["base"], base_url)
+ self.assertEqual(client.endpoints["auth"], auth_url)
+ self.assertEqual(client.endpoints["api"], api_url)
+ self.assertEqual(client.endpoints["stream"], stream_url)
def test_endpoints_base_url_from_env(self):
- base_url = 'http://www.st2.com'
- auth_url = 'http://www.st2.com:9100'
- api_url = 'http://www.st2.com:9101/v1'
- stream_url = 'http://www.st2.com:9102/v1'
- os.environ['ST2_BASE_URL'] = base_url
- args = ['trigger', 'list']
+ base_url = "http://www.st2.com"
+ auth_url = "http://www.st2.com:9100"
+ api_url = "http://www.st2.com:9101/v1"
+ stream_url = "http://www.st2.com:9102/v1"
+ os.environ["ST2_BASE_URL"] = base_url
+ args = ["trigger", "list"]
parsed_args = self.shell.parser.parse_args(args)
client = self.shell.get_client(parsed_args)
- self.assertEqual(client.endpoints['base'], base_url)
- self.assertEqual(client.endpoints['auth'], auth_url)
- self.assertEqual(client.endpoints['api'], api_url)
- self.assertEqual(client.endpoints['stream'], stream_url)
+ self.assertEqual(client.endpoints["base"], base_url)
+ self.assertEqual(client.endpoints["auth"], auth_url)
+ self.assertEqual(client.endpoints["api"], api_url)
+ self.assertEqual(client.endpoints["stream"], stream_url)
def test_endpoints_override_from_cli(self):
- base_url = 'http://www.st2.com'
- auth_url = 'http://www.st2.com:8888'
- api_url = 'http://www.stackstorm1.com:9101/v1'
- stream_url = 'http://www.stackstorm1.com:9102/v1'
- args = ['--url', base_url,
- '--auth-url', auth_url,
- '--api-url', api_url,
- '--stream-url', stream_url,
- 'trigger', 'list']
+ base_url = "http://www.st2.com"
+ auth_url = "http://www.st2.com:8888"
+ api_url = "http://www.stackstorm1.com:9101/v1"
+ stream_url = "http://www.stackstorm1.com:9102/v1"
+ args = [
+ "--url",
+ base_url,
+ "--auth-url",
+ auth_url,
+ "--api-url",
+ api_url,
+ "--stream-url",
+ stream_url,
+ "trigger",
+ "list",
+ ]
parsed_args = self.shell.parser.parse_args(args)
client = self.shell.get_client(parsed_args)
- self.assertEqual(client.endpoints['base'], base_url)
- self.assertEqual(client.endpoints['auth'], auth_url)
- self.assertEqual(client.endpoints['api'], api_url)
- self.assertEqual(client.endpoints['stream'], stream_url)
+ self.assertEqual(client.endpoints["base"], base_url)
+ self.assertEqual(client.endpoints["auth"], auth_url)
+ self.assertEqual(client.endpoints["api"], api_url)
+ self.assertEqual(client.endpoints["stream"], stream_url)
def test_endpoints_override_from_env(self):
- base_url = 'http://www.st2.com'
- auth_url = 'http://www.st2.com:8888'
- api_url = 'http://www.stackstorm1.com:9101/v1'
- stream_url = 'http://www.stackstorm1.com:9102/v1'
- os.environ['ST2_BASE_URL'] = base_url
- os.environ['ST2_AUTH_URL'] = auth_url
- os.environ['ST2_API_URL'] = api_url
- os.environ['ST2_STREAM_URL'] = stream_url
- args = ['trigger', 'list']
+ base_url = "http://www.st2.com"
+ auth_url = "http://www.st2.com:8888"
+ api_url = "http://www.stackstorm1.com:9101/v1"
+ stream_url = "http://www.stackstorm1.com:9102/v1"
+ os.environ["ST2_BASE_URL"] = base_url
+ os.environ["ST2_AUTH_URL"] = auth_url
+ os.environ["ST2_API_URL"] = api_url
+ os.environ["ST2_STREAM_URL"] = stream_url
+ args = ["trigger", "list"]
parsed_args = self.shell.parser.parse_args(args)
client = self.shell.get_client(parsed_args)
- self.assertEqual(client.endpoints['base'], base_url)
- self.assertEqual(client.endpoints['auth'], auth_url)
- self.assertEqual(client.endpoints['api'], api_url)
- self.assertEqual(client.endpoints['stream'], stream_url)
+ self.assertEqual(client.endpoints["base"], base_url)
+ self.assertEqual(client.endpoints["auth"], auth_url)
+ self.assertEqual(client.endpoints["api"], api_url)
+ self.assertEqual(client.endpoints["stream"], stream_url)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK")
+ ),
+ )
def test_exit_code_on_success(self):
- argv = ['trigger', 'list']
+ argv = ["trigger", "list"]
self.assertEqual(self.shell.run(argv), 0)
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(None, 500, 'INTERNAL SERVER ERROR')))
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(None, 500, "INTERNAL SERVER ERROR")
+ ),
+ )
def test_exit_code_on_error(self):
- argv = ['trigger', 'list']
+ argv = ["trigger", "list"]
self.assertEqual(self.shell.run(argv), 1)
def _validate_parser(self, args_list, is_subcommand=True):
for args in args_list:
ns = self.shell.parser.parse_args(args)
- func = (self.shell.commands[args[0]].run_and_print
- if not is_subcommand
- else self.shell.commands[args[0]].commands[args[1]].run_and_print)
+ func = (
+ self.shell.commands[args[0]].run_and_print
+ if not is_subcommand
+ else self.shell.commands[args[0]].commands[args[1]].run_and_print
+ )
self.assertEqual(ns.func, func)
def test_action(self):
args_list = [
- ['action', 'list'],
- ['action', 'get', 'abc'],
- ['action', 'create', '/tmp/action.json'],
- ['action', 'update', '123', '/tmp/action.json'],
- ['action', 'delete', 'abc'],
- ['action', 'execute', '-h'],
- ['action', 'execute', 'remote', '-h'],
- ['action', 'execute', 'remote', 'hosts=192.168.1.1', 'user=st2', 'cmd="ls -l"'],
- ['action', 'execute', 'remote-fib', 'hosts=192.168.1.1', '3', '8']
+ ["action", "list"],
+ ["action", "get", "abc"],
+ ["action", "create", "/tmp/action.json"],
+ ["action", "update", "123", "/tmp/action.json"],
+ ["action", "delete", "abc"],
+ ["action", "execute", "-h"],
+ ["action", "execute", "remote", "-h"],
+ [
+ "action",
+ "execute",
+ "remote",
+ "hosts=192.168.1.1",
+ "user=st2",
+ 'cmd="ls -l"',
+ ],
+ ["action", "execute", "remote-fib", "hosts=192.168.1.1", "3", "8"],
]
self._validate_parser(args_list)
def test_action_execution(self):
args_list = [
- ['execution', 'list'],
- ['execution', 'list', '-a', 'all'],
- ['execution', 'list', '--attr=all'],
- ['execution', 'get', '123'],
- ['execution', 'get', '123', '-d'],
- ['execution', 'get', '123', '-k', 'localhost.stdout'],
- ['execution', 're-run', '123'],
- ['execution', 're-run', '123', '--tasks', 'x', 'y', 'z'],
- ['execution', 're-run', '123', '--tasks', 'x', 'y', 'z', '--no-reset', 'x'],
- ['execution', 're-run', '123', 'a=1', 'b=x', 'c=True'],
- ['execution', 'cancel', '123'],
- ['execution', 'cancel', '123', '456'],
- ['execution', 'pause', '123'],
- ['execution', 'pause', '123', '456'],
- ['execution', 'resume', '123'],
- ['execution', 'resume', '123', '456']
+ ["execution", "list"],
+ ["execution", "list", "-a", "all"],
+ ["execution", "list", "--attr=all"],
+ ["execution", "get", "123"],
+ ["execution", "get", "123", "-d"],
+ ["execution", "get", "123", "-k", "localhost.stdout"],
+ ["execution", "re-run", "123"],
+ ["execution", "re-run", "123", "--tasks", "x", "y", "z"],
+ ["execution", "re-run", "123", "--tasks", "x", "y", "z", "--no-reset", "x"],
+ ["execution", "re-run", "123", "a=1", "b=x", "c=True"],
+ ["execution", "cancel", "123"],
+ ["execution", "cancel", "123", "456"],
+ ["execution", "pause", "123"],
+ ["execution", "pause", "123", "456"],
+ ["execution", "resume", "123"],
+ ["execution", "resume", "123", "456"],
]
self._validate_parser(args_list)
# Test mutually exclusive argument groups
- self.assertRaises(SystemExit, self._validate_parser,
- [['execution', 'get', '123', '-d', '-k', 'localhost.stdout']])
+ self.assertRaises(
+ SystemExit,
+ self._validate_parser,
+ [["execution", "get", "123", "-d", "-k", "localhost.stdout"]],
+ )
def test_key(self):
args_list = [
- ['key', 'list'],
- ['key', 'list', '-n', '2'],
- ['key', 'get', 'abc'],
- ['key', 'set', 'abc', '123'],
- ['key', 'delete', 'abc'],
- ['key', 'load', '/tmp/keys.json']
+ ["key", "list"],
+ ["key", "list", "-n", "2"],
+ ["key", "get", "abc"],
+ ["key", "set", "abc", "123"],
+ ["key", "delete", "abc"],
+ ["key", "load", "/tmp/keys.json"],
]
self._validate_parser(args_list)
def test_policy(self):
args_list = [
- ['policy', 'list'],
- ['policy', 'list', '-p', 'core'],
- ['policy', 'list', '--pack', 'core'],
- ['policy', 'list', '-r', 'core.local'],
- ['policy', 'list', '--resource-ref', 'core.local'],
- ['policy', 'list', '-pt', 'action.type1'],
- ['policy', 'list', '--policy-type', 'action.type1'],
- ['policy', 'list', '-r', 'core.local', '-pt', 'action.type1'],
- ['policy', 'list', '--resource-ref', 'core.local', '--policy-type', 'action.type1'],
- ['policy', 'get', 'abc'],
- ['policy', 'create', '/tmp/policy.json'],
- ['policy', 'update', '123', '/tmp/policy.json'],
- ['policy', 'delete', 'abc']
+ ["policy", "list"],
+ ["policy", "list", "-p", "core"],
+ ["policy", "list", "--pack", "core"],
+ ["policy", "list", "-r", "core.local"],
+ ["policy", "list", "--resource-ref", "core.local"],
+ ["policy", "list", "-pt", "action.type1"],
+ ["policy", "list", "--policy-type", "action.type1"],
+ ["policy", "list", "-r", "core.local", "-pt", "action.type1"],
+ [
+ "policy",
+ "list",
+ "--resource-ref",
+ "core.local",
+ "--policy-type",
+ "action.type1",
+ ],
+ ["policy", "get", "abc"],
+ ["policy", "create", "/tmp/policy.json"],
+ ["policy", "update", "123", "/tmp/policy.json"],
+ ["policy", "delete", "abc"],
]
self._validate_parser(args_list)
def test_policy_type(self):
args_list = [
- ['policy-type', 'list'],
- ['policy-type', 'list', '-r', 'action'],
- ['policy-type', 'list', '--resource-type', 'action'],
- ['policy-type', 'get', 'abc']
+ ["policy-type", "list"],
+ ["policy-type", "list", "-r", "action"],
+ ["policy-type", "list", "--resource-type", "action"],
+ ["policy-type", "get", "abc"],
]
self._validate_parser(args_list)
def test_pack(self):
args_list = [
- ['pack', 'list'],
- ['pack', 'get', 'abc'],
- ['pack', 'search', 'abc'],
- ['pack', 'show', 'abc'],
- ['pack', 'remove', 'abc'],
- ['pack', 'remove', 'abc', '--detail'],
- ['pack', 'install', 'abc'],
- ['pack', 'install', 'abc', '--force'],
- ['pack', 'install', 'abc', '--detail'],
- ['pack', 'config', 'abc']
+ ["pack", "list"],
+ ["pack", "get", "abc"],
+ ["pack", "search", "abc"],
+ ["pack", "show", "abc"],
+ ["pack", "remove", "abc"],
+ ["pack", "remove", "abc", "--detail"],
+ ["pack", "install", "abc"],
+ ["pack", "install", "abc", "--force"],
+ ["pack", "install", "abc", "--detail"],
+ ["pack", "config", "abc"],
]
self._validate_parser(args_list)
- @mock.patch('st2client.base.ST2_CONFIG_PATH', '/home/does/not/exist')
+ @mock.patch("st2client.base.ST2_CONFIG_PATH", "/home/does/not/exist")
def test_print_config_default_config_no_config(self):
- os.environ['ST2_CONFIG_FILE'] = '/home/does/not/exist'
- argv = ['--print-config']
+ os.environ["ST2_CONFIG_FILE"] = "/home/does/not/exist"
+ argv = ["--print-config"]
self.assertEqual(self.shell.run(argv), 3)
self.stdout.seek(0)
stdout = self.stdout.read()
- self.assertIn('username = None', stdout)
- self.assertIn('cache_token = True', stdout)
+ self.assertIn("username = None", stdout)
+ self.assertIn("cache_token = True", stdout)
def test_print_config_custom_config_as_env_variable(self):
- os.environ['ST2_CONFIG_FILE'] = CONFIG_FILE_PATH_FULL
- argv = ['--print-config']
+ os.environ["ST2_CONFIG_FILE"] = CONFIG_FILE_PATH_FULL
+ argv = ["--print-config"]
self.assertEqual(self.shell.run(argv), 3)
self.stdout.seek(0)
stdout = self.stdout.read()
- self.assertIn('username = test1', stdout)
- self.assertIn('cache_token = False', stdout)
+ self.assertIn("username = test1", stdout)
+ self.assertIn("cache_token = False", stdout)
def test_print_config_custom_config_as_command_line_argument(self):
- argv = ['--print-config', '--config-file=%s' % (CONFIG_FILE_PATH_FULL)]
+ argv = ["--print-config", "--config-file=%s" % (CONFIG_FILE_PATH_FULL)]
self.assertEqual(self.shell.run(argv), 3)
self.stdout.seek(0)
stdout = self.stdout.read()
- self.assertIn('username = test1', stdout)
- self.assertIn('cache_token = False', stdout)
+ self.assertIn("username = test1", stdout)
+ self.assertIn("cache_token = False", stdout)
def test_run(self):
args_list = [
- ['run', '-h'],
- ['run', 'abc', '-h'],
- ['run', 'remote', 'hosts=192.168.1.1', 'user=st2', 'cmd="ls -l"'],
- ['run', 'remote-fib', 'hosts=192.168.1.1', '3', '8']
+ ["run", "-h"],
+ ["run", "abc", "-h"],
+ ["run", "remote", "hosts=192.168.1.1", "user=st2", 'cmd="ls -l"'],
+ ["run", "remote-fib", "hosts=192.168.1.1", "3", "8"],
]
self._validate_parser(args_list, is_subcommand=False)
def test_runner(self):
- args_list = [
- ['runner', 'list'],
- ['runner', 'get', 'abc']
- ]
+ args_list = [["runner", "list"], ["runner", "get", "abc"]]
self._validate_parser(args_list)
def test_rule(self):
args_list = [
- ['rule', 'list'],
- ['rule', 'list', '-n', '1'],
- ['rule', 'get', 'abc'],
- ['rule', 'create', '/tmp/rule.json'],
- ['rule', 'update', '123', '/tmp/rule.json'],
- ['rule', 'delete', 'abc']
+ ["rule", "list"],
+ ["rule", "list", "-n", "1"],
+ ["rule", "get", "abc"],
+ ["rule", "create", "/tmp/rule.json"],
+ ["rule", "update", "123", "/tmp/rule.json"],
+ ["rule", "delete", "abc"],
]
self._validate_parser(args_list)
def test_trigger(self):
args_list = [
- ['trigger', 'list'],
- ['trigger', 'get', 'abc'],
- ['trigger', 'create', '/tmp/trigger.json'],
- ['trigger', 'update', '123', '/tmp/trigger.json'],
- ['trigger', 'delete', 'abc']
+ ["trigger", "list"],
+ ["trigger", "get", "abc"],
+ ["trigger", "create", "/tmp/trigger.json"],
+ ["trigger", "update", "123", "/tmp/trigger.json"],
+ ["trigger", "delete", "abc"],
]
self._validate_parser(args_list)
def test_workflow(self):
args_list = [
- ['workflow', 'inspect', '--file', '/path/to/workflow/definition'],
- ['workflow', 'inspect', '--action', 'mock.foobar']
+ ["workflow", "inspect", "--file", "/path/to/workflow/definition"],
+ ["workflow", "inspect", "--action", "mock.foobar"],
]
self._validate_parser(args_list)
- @mock.patch('sys.exit', mock.Mock())
- @mock.patch('st2client.shell.__version__', 'v2.8.0')
+ @mock.patch("sys.exit", mock.Mock())
+ @mock.patch("st2client.shell.__version__", "v2.8.0")
def test_get_version_no_package_metadata_file_stable_version(self):
# stable version, package metadata file doesn't exist on disk - no git revision should be
# included
shell = Shell()
- shell.parser.parse_args(args=['--version'])
+ shell.parser.parse_args(args=["--version"])
self.version_output.seek(0)
stderr = self.version_output.read()
- self.assertIn('v2.8.0, on Python', stderr)
+ self.assertIn("v2.8.0, on Python", stderr)
- @mock.patch('sys.exit', mock.Mock())
- @mock.patch('st2client.shell.__version__', 'v2.8.0')
+ @mock.patch("sys.exit", mock.Mock())
+ @mock.patch("st2client.shell.__version__", "v2.8.0")
def test_get_version_package_metadata_file_exists_stable_version(self):
# stable version, package metadata file exists on disk - no git revision should be included
package_metadata_path = self._write_mock_package_metadata_file()
st2client.shell.PACKAGE_METADATA_FILE_PATH = package_metadata_path
shell = Shell()
- shell.run(argv=['--version'])
+ shell.run(argv=["--version"])
self.version_output.seek(0)
stderr = self.version_output.read()
- self.assertIn('v2.8.0, on Python', stderr)
+ self.assertIn("v2.8.0, on Python", stderr)
- @mock.patch('sys.exit', mock.Mock())
- @mock.patch('st2client.shell.__version__', 'v2.9dev')
- @mock.patch('st2client.shell.PACKAGE_METADATA_FILE_PATH', '/tmp/doesnt/exist.1')
+ @mock.patch("sys.exit", mock.Mock())
+ @mock.patch("st2client.shell.__version__", "v2.9dev")
+ @mock.patch("st2client.shell.PACKAGE_METADATA_FILE_PATH", "/tmp/doesnt/exist.1")
def test_get_version_no_package_metadata_file_dev_version(self):
        # dev version, package metadata file doesn't exist on disk - no git revision should be
        # included since there is no metadata file to read the revision from
shell = Shell()
- shell.parser.parse_args(args=['--version'])
+ shell.parser.parse_args(args=["--version"])
self.version_output.seek(0)
stderr = self.version_output.read()
- self.assertIn('v2.9dev, on Python', stderr)
+ self.assertIn("v2.9dev, on Python", stderr)
- @mock.patch('sys.exit', mock.Mock())
- @mock.patch('st2client.shell.__version__', 'v2.9dev')
+ @mock.patch("sys.exit", mock.Mock())
+ @mock.patch("st2client.shell.__version__", "v2.9dev")
def test_get_version_package_metadata_file_exists_dev_version(self):
# dev version, package metadata file exists on disk - git revision should be included
# since package metadata file exists on disk and contains server.git_sha attribute
@@ -430,55 +461,67 @@ def test_get_version_package_metadata_file_exists_dev_version(self):
st2client.shell.PACKAGE_METADATA_FILE_PATH = package_metadata_path
shell = Shell()
- shell.parser.parse_args(args=['--version'])
+ shell.parser.parse_args(args=["--version"])
self.version_output.seek(0)
stderr = self.version_output.read()
- self.assertIn('v2.9dev (abcdefg), on Python', stderr)
+ self.assertIn("v2.9dev (abcdefg), on Python", stderr)
- @mock.patch('locale.getdefaultlocale', mock.Mock(return_value=['en_US']))
- @mock.patch('locale.getpreferredencoding', mock.Mock(return_value='iso'))
+ @mock.patch("locale.getdefaultlocale", mock.Mock(return_value=["en_US"]))
+ @mock.patch("locale.getpreferredencoding", mock.Mock(return_value="iso"))
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK')))
- @mock.patch('st2client.shell.LOGGER')
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK")
+ ),
+ )
+ @mock.patch("st2client.shell.LOGGER")
def test_non_unicode_encoding_locale_warning_is_printed(self, mock_logger):
shell = Shell()
- shell.run(argv=['trigger', 'list'])
+ shell.run(argv=["trigger", "list"])
call_args = mock_logger.warn.call_args[0][0]
- self.assertIn('Locale en_US with encoding iso which is not UTF-8 is used.', call_args)
+ self.assertIn(
+ "Locale en_US with encoding iso which is not UTF-8 is used.", call_args
+ )
- @mock.patch('locale.getdefaultlocale', mock.Mock(side_effect=ValueError('bar')))
- @mock.patch('locale.getpreferredencoding', mock.Mock(side_effect=ValueError('bar')))
+ @mock.patch("locale.getdefaultlocale", mock.Mock(side_effect=ValueError("bar")))
+ @mock.patch("locale.getpreferredencoding", mock.Mock(side_effect=ValueError("bar")))
@mock.patch.object(
- httpclient.HTTPClient, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK')))
- @mock.patch('st2client.shell.LOGGER')
+ httpclient.HTTPClient,
+ "get",
+ mock.MagicMock(
+ return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK")
+ ),
+ )
+ @mock.patch("st2client.shell.LOGGER")
def test_failed_to_get_locale_encoding_warning_is_printed(self, mock_logger):
shell = Shell()
- shell.run(argv=['trigger', 'list'])
+ shell.run(argv=["trigger", "list"])
call_args = mock_logger.warn.call_args[0][0]
- self.assertTrue('Locale unknown with encoding unknown which is not UTF-8 is used.' in
- call_args)
+ self.assertTrue(
+ "Locale unknown with encoding unknown which is not UTF-8 is used."
+ in call_args
+ )
def _write_mock_package_metadata_file(self):
_, package_metadata_path = tempfile.mkstemp()
- with open(package_metadata_path, 'w') as fp:
+ with open(package_metadata_path, "w") as fp:
fp.write(MOCK_PACKAGE_METADATA)
return package_metadata_path
- @unittest2.skipIf(True, 'skipping until checks are re-enabled')
+ @unittest2.skipIf(True, "skipping until checks are re-enabled")
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse("{}", 200, 'OK')))
+ requests, "get", mock.MagicMock(return_value=base.FakeResponse("{}", 200, "OK"))
+ )
def test_dont_warn_multiple_times(self):
mock_temp_dir_path = tempfile.mkdtemp()
- mock_config_dir_path = os.path.join(mock_temp_dir_path, 'testconfig')
- mock_config_path = os.path.join(mock_config_dir_path, 'config')
+ mock_config_dir_path = os.path.join(mock_temp_dir_path, "testconfig")
+ mock_config_path = os.path.join(mock_config_dir_path, "config")
# Make the temporary config directory
os.makedirs(mock_config_dir_path)
@@ -495,38 +538,46 @@ def test_dont_warn_multiple_times(self):
shell.LOG = mock.Mock()
# Test without token.
- shell.run(['--config-file', mock_config_path, 'action', 'list'])
+ shell.run(["--config-file", mock_config_path, "action", "list"])
self.assertEqual(shell.LOG.warn.call_count, 2)
self.assertEqual(
shell.LOG.warn.call_args_list[0][0][0][:63],
- 'The StackStorm configuration directory permissions are insecure')
+ "The StackStorm configuration directory permissions are insecure",
+ )
self.assertEqual(
shell.LOG.warn.call_args_list[1][0][0][:58],
- 'The StackStorm configuration file permissions are insecure')
+ "The StackStorm configuration file permissions are insecure",
+ )
self.assertEqual(shell.LOG.info.call_count, 2)
self.assertEqual(
- shell.LOG.info.call_args_list[0][0][0], "The SGID bit is not "
- "set on the StackStorm configuration directory.")
+ shell.LOG.info.call_args_list[0][0][0],
+ "The SGID bit is not " "set on the StackStorm configuration directory.",
+ )
self.assertEqual(
- shell.LOG.info.call_args_list[1][0][0], 'Skipping parsing CLI config')
+ shell.LOG.info.call_args_list[1][0][0], "Skipping parsing CLI config"
+ )
class CLITokenCachingTestCase(unittest2.TestCase):
def setUp(self):
super(CLITokenCachingTestCase, self).setUp()
self._mock_temp_dir_path = tempfile.mkdtemp()
- self._mock_config_directory_path = os.path.join(self._mock_temp_dir_path, 'testconfig')
- self._mock_config_path = os.path.join(self._mock_config_directory_path, 'config')
+ self._mock_config_directory_path = os.path.join(
+ self._mock_temp_dir_path, "testconfig"
+ )
+ self._mock_config_path = os.path.join(
+ self._mock_config_directory_path, "config"
+ )
os.makedirs(self._mock_config_directory_path)
- self._p1 = mock.patch('st2client.base.ST2_CONFIG_DIRECTORY',
- self._mock_config_directory_path)
- self._p2 = mock.patch('st2client.base.ST2_CONFIG_PATH',
- self._mock_config_path)
+ self._p1 = mock.patch(
+ "st2client.base.ST2_CONFIG_DIRECTORY", self._mock_config_directory_path
+ )
+ self._p2 = mock.patch("st2client.base.ST2_CONFIG_PATH", self._mock_config_path)
self._p1.start()
self._p2.start()
@@ -536,46 +587,46 @@ def tearDown(self):
self._p2.stop()
for var in [
- 'ST2_BASE_URL',
- 'ST2_API_URL',
- 'ST2_STREAM_URL',
- 'ST2_DATASTORE_URL',
- 'ST2_AUTH_TOKEN'
+ "ST2_BASE_URL",
+ "ST2_API_URL",
+ "ST2_STREAM_URL",
+ "ST2_DATASTORE_URL",
+ "ST2_AUTH_TOKEN",
]:
if var in os.environ:
del os.environ[var]
def _write_mock_config(self):
- with open(self._mock_config_path, 'w') as fp:
+ with open(self._mock_config_path, "w") as fp:
fp.write(MOCK_CONFIG)
def test_get_cached_auth_token_invalid_permissions(self):
shell = Shell()
client = Client()
- username = 'testu'
- password = 'testp'
+ username = "testu"
+ password = "testp"
cached_token_path = shell._get_cached_token_path_for_user(username=username)
- data = {
- 'token': 'yayvalid',
- 'expire_timestamp': (int(time.time()) + 20)
- }
- with open(cached_token_path, 'w') as fp:
+ data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)}
+ with open(cached_token_path, "w") as fp:
fp.write(json.dumps(data))
# 1. Current user doesn't have read access to the config directory
os.chmod(self._mock_config_directory_path, 0o000)
shell.LOG = mock.Mock()
- result = shell._get_cached_auth_token(client=client, username=username,
- password=password)
+ result = shell._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
self.assertEqual(result, None)
self.assertEqual(shell.LOG.warn.call_count, 1)
log_message = shell.LOG.warn.call_args[0][0]
- expected_msg = ('Unable to retrieve cached token from .*? read access to the parent '
- 'directory')
+ expected_msg = (
+ "Unable to retrieve cached token from .*? read access to the parent "
+ "directory"
+ )
self.assertRegexpMatches(log_message, expected_msg)
# 2. Read access on the directory, but not on the cached token file
@@ -583,14 +634,17 @@ def test_get_cached_auth_token_invalid_permissions(self):
os.chmod(cached_token_path, 0o000)
shell.LOG = mock.Mock()
- result = shell._get_cached_auth_token(client=client, username=username,
- password=password)
+ result = shell._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
self.assertEqual(result, None)
self.assertEqual(shell.LOG.warn.call_count, 1)
log_message = shell.LOG.warn.call_args[0][0]
- expected_msg = ('Unable to retrieve cached token from .*? read access to this file')
+ expected_msg = (
+ "Unable to retrieve cached token from .*? read access to this file"
+ )
self.assertRegexpMatches(log_message, expected_msg)
# 3. Other users also have read access to the file
@@ -598,31 +652,29 @@ def test_get_cached_auth_token_invalid_permissions(self):
os.chmod(cached_token_path, 0o444)
shell.LOG = mock.Mock()
- result = shell._get_cached_auth_token(client=client, username=username,
- password=password)
- self.assertEqual(result, 'yayvalid')
+ result = shell._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
+ self.assertEqual(result, "yayvalid")
self.assertEqual(shell.LOG.warn.call_count, 1)
log_message = shell.LOG.warn.call_args[0][0]
- expected_msg = ('Permissions .*? for cached token file .*? are too permissive.*')
+ expected_msg = "Permissions .*? for cached token file .*? are too permissive.*"
self.assertRegexpMatches(log_message, expected_msg)
def test_cache_auth_token_invalid_permissions(self):
shell = Shell()
- username = 'testu'
+ username = "testu"
cached_token_path = shell._get_cached_token_path_for_user(username=username)
expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=30)
- token_db = TokenDB(user=username, token='fyeah', expiry=expiry)
+ token_db = TokenDB(user=username, token="fyeah", expiry=expiry)
cached_token_path = shell._get_cached_token_path_for_user(username=username)
- data = {
- 'token': 'yayvalid',
- 'expire_timestamp': (int(time.time()) + 20)
- }
- with open(cached_token_path, 'w') as fp:
+ data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)}
+ with open(cached_token_path, "w") as fp:
fp.write(json.dumps(data))
# 1. Current user has no write access to the parent directory
@@ -634,8 +686,10 @@ def test_cache_auth_token_invalid_permissions(self):
self.assertEqual(shell.LOG.warn.call_count, 1)
log_message = shell.LOG.warn.call_args[0][0]
- expected_msg = ('Unable to write token to .*? doesn\'t have write access to the parent '
- 'directory')
+ expected_msg = (
+ "Unable to write token to .*? doesn't have write access to the parent "
+ "directory"
+ )
self.assertRegexpMatches(log_message, expected_msg)
# 2. Current user has no write access to the cached token file
@@ -648,86 +702,93 @@ def test_cache_auth_token_invalid_permissions(self):
self.assertEqual(shell.LOG.warn.call_count, 1)
log_message = shell.LOG.warn.call_args[0][0]
- expected_msg = ('Unable to write token to .*? doesn\'t have write access to this file')
+ expected_msg = (
+ "Unable to write token to .*? doesn't have write access to this file"
+ )
self.assertRegexpMatches(log_message, expected_msg)
def test_get_cached_auth_token_no_token_cache_file(self):
client = Client()
shell = Shell()
- username = 'testu'
- password = 'testp'
+ username = "testu"
+ password = "testp"
- result = shell._get_cached_auth_token(client=client, username=username,
- password=password)
+ result = shell._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
self.assertEqual(result, None)
def test_get_cached_auth_token_corrupted_token_cache_file(self):
client = Client()
shell = Shell()
- username = 'testu'
- password = 'testp'
+ username = "testu"
+ password = "testp"
cached_token_path = shell._get_cached_token_path_for_user(username=username)
- with open(cached_token_path, 'w') as fp:
- fp.write('CORRRRRUPTED!')
-
- expected_msg = 'File (.+) with cached token is corrupted or invalid'
- self.assertRaisesRegexp(ValueError, expected_msg, shell._get_cached_auth_token,
- client=client, username=username, password=password)
+ with open(cached_token_path, "w") as fp:
+ fp.write("CORRRRRUPTED!")
+
+ expected_msg = "File (.+) with cached token is corrupted or invalid"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ shell._get_cached_auth_token,
+ client=client,
+ username=username,
+ password=password,
+ )
def test_get_cached_auth_token_expired_token_in_cache_file(self):
client = Client()
shell = Shell()
- username = 'testu'
- password = 'testp'
+ username = "testu"
+ password = "testp"
cached_token_path = shell._get_cached_token_path_for_user(username=username)
- data = {
- 'token': 'expired',
- 'expire_timestamp': (int(time.time()) - 10)
- }
- with open(cached_token_path, 'w') as fp:
+ data = {"token": "expired", "expire_timestamp": (int(time.time()) - 10)}
+ with open(cached_token_path, "w") as fp:
fp.write(json.dumps(data))
- result = shell._get_cached_auth_token(client=client, username=username,
- password=password)
+ result = shell._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
self.assertEqual(result, None)
def test_get_cached_auth_token_valid_token_in_cache_file(self):
client = Client()
shell = Shell()
- username = 'testu'
- password = 'testp'
+ username = "testu"
+ password = "testp"
cached_token_path = shell._get_cached_token_path_for_user(username=username)
- data = {
- 'token': 'yayvalid',
- 'expire_timestamp': (int(time.time()) + 20)
- }
- with open(cached_token_path, 'w') as fp:
+ data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)}
+ with open(cached_token_path, "w") as fp:
fp.write(json.dumps(data))
- result = shell._get_cached_auth_token(client=client, username=username,
- password=password)
- self.assertEqual(result, 'yayvalid')
+ result = shell._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
+ self.assertEqual(result, "yayvalid")
def test_cache_auth_token_success(self):
client = Client()
shell = Shell()
- username = 'testu'
- password = 'testp'
+ username = "testu"
+ password = "testp"
expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=30)
- result = shell._get_cached_auth_token(client=client, username=username,
- password=password)
+ result = shell._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
self.assertEqual(result, None)
- token_db = TokenDB(user=username, token='fyeah', expiry=expiry)
+ token_db = TokenDB(user=username, token="fyeah", expiry=expiry)
shell._cache_auth_token(token_obj=token_db)
- result = shell._get_cached_auth_token(client=client, username=username,
- password=password)
- self.assertEqual(result, 'fyeah')
+ result = shell._get_cached_auth_token(
+ client=client, username=username, password=password
+ )
+ self.assertEqual(result, "fyeah")
def test_automatic_auth_skipped_on_auth_command(self):
self._write_mock_config()
@@ -735,7 +796,7 @@ def test_automatic_auth_skipped_on_auth_command(self):
shell = Shell()
shell._get_auth_token = mock.Mock()
- argv = ['auth', 'testu', '-p', 'testp']
+ argv = ["auth", "testu", "-p", "testp"]
args = shell.parser.parse_args(args=argv)
shell.get_client(args=args)
self.assertEqual(shell._get_auth_token.call_count, 0)
@@ -746,8 +807,8 @@ def test_automatic_auth_skipped_if_token_provided_as_env_variable(self):
shell = Shell()
shell._get_auth_token = mock.Mock()
- os.environ['ST2_AUTH_TOKEN'] = 'fooo'
- argv = ['action', 'list']
+ os.environ["ST2_AUTH_TOKEN"] = "fooo"
+ argv = ["action", "list"]
args = shell.parser.parse_args(args=argv)
shell.get_client(args=args)
self.assertEqual(shell._get_auth_token.call_count, 0)
@@ -758,12 +819,12 @@ def test_automatic_auth_skipped_if_token_provided_as_cli_argument(self):
shell = Shell()
shell._get_auth_token = mock.Mock()
- argv = ['action', 'list', '--token=bar']
+ argv = ["action", "list", "--token=bar"]
args = shell.parser.parse_args(args=argv)
shell.get_client(args=args)
self.assertEqual(shell._get_auth_token.call_count, 0)
- argv = ['action', 'list', '-t', 'bar']
+ argv = ["action", "list", "-t", "bar"]
args = shell.parser.parse_args(args=argv)
shell.get_client(args=args)
self.assertEqual(shell._get_auth_token.call_count, 0)
diff --git a/st2client/tests/unit/test_ssl.py b/st2client/tests/unit/test_ssl.py
index 5ed8bfbf28..5db836482b 100644
--- a/st2client/tests/unit/test_ssl.py
+++ b/st2client/tests/unit/test_ssl.py
@@ -27,17 +27,18 @@
LOG = logging.getLogger(__name__)
-USERNAME = 'stanley'
-PASSWORD = 'ShhhDontTell'
-HEADERS = {'content-type': 'application/json'}
-AUTH_URL = 'https://127.0.0.1:9100/tokens'
-GET_RULES_URL = ('http://127.0.0.1:9101/v1/rules/'
- '?include_attributes=ref,pack,description,enabled&limit=50')
-GET_RULES_URL = GET_RULES_URL.replace(',', '%2C')
+USERNAME = "stanley"
+PASSWORD = "ShhhDontTell"
+HEADERS = {"content-type": "application/json"}
+AUTH_URL = "https://127.0.0.1:9100/tokens"
+GET_RULES_URL = (
+ "http://127.0.0.1:9101/v1/rules/"
+ "?include_attributes=ref,pack,description,enabled&limit=50"
+)
+GET_RULES_URL = GET_RULES_URL.replace(",", "%2C")
class TestHttps(base.BaseCLITestCase):
-
def __init__(self, *args, **kwargs):
super(TestHttps, self).__init__(*args, **kwargs)
self.shell = shell.Shell()
@@ -46,11 +47,11 @@ def setUp(self):
super(TestHttps, self).setUp()
        # Set up the environment.
- os.environ['ST2_BASE_URL'] = 'http://127.0.0.1'
- os.environ['ST2_AUTH_URL'] = 'https://127.0.0.1:9100'
+ os.environ["ST2_BASE_URL"] = "http://127.0.0.1"
+ os.environ["ST2_AUTH_URL"] = "https://127.0.0.1:9100"
- if 'ST2_CACERT' in os.environ:
- del os.environ['ST2_CACERT']
+ if "ST2_CACERT" in os.environ:
+ del os.environ["ST2_CACERT"]
# Create a temp file to mock a cert file.
self.cacert_fd, self.cacert_path = tempfile.mkstemp()
@@ -59,58 +60,78 @@ def tearDown(self):
super(TestHttps, self).tearDown()
# Clean up environment.
- if 'ST2_CACERT' in os.environ:
- del os.environ['ST2_CACERT']
- if 'ST2_BASE_URL' in os.environ:
- del os.environ['ST2_BASE_URL']
+ if "ST2_CACERT" in os.environ:
+ del os.environ["ST2_CACERT"]
+ if "ST2_BASE_URL" in os.environ:
+ del os.environ["ST2_BASE_URL"]
# Clean up temp files.
os.close(self.cacert_fd)
os.unlink(self.cacert_path)
@mock.patch.object(
- requests, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK')))
+ requests,
+ "post",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")),
+ )
def test_decorate_https_without_cacert(self):
- self.shell.run(['auth', USERNAME, '-p', PASSWORD])
- kwargs = {'verify': False, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)}
+ self.shell.run(["auth", USERNAME, "-p", PASSWORD])
+ kwargs = {"verify": False, "headers": HEADERS, "auth": (USERNAME, PASSWORD)}
requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs)
@mock.patch.object(
- requests, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK')))
+ requests,
+ "post",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")),
+ )
def test_decorate_https_with_cacert_from_cli(self):
- self.shell.run(['--cacert', self.cacert_path, 'auth', USERNAME, '-p', PASSWORD])
- kwargs = {'verify': self.cacert_path, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)}
+ self.shell.run(["--cacert", self.cacert_path, "auth", USERNAME, "-p", PASSWORD])
+ kwargs = {
+ "verify": self.cacert_path,
+ "headers": HEADERS,
+ "auth": (USERNAME, PASSWORD),
+ }
requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs)
@mock.patch.object(
- requests, 'post',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK')))
+ requests,
+ "post",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")),
+ )
def test_decorate_https_with_cacert_from_env(self):
- os.environ['ST2_CACERT'] = self.cacert_path
- self.shell.run(['auth', USERNAME, '-p', PASSWORD])
- kwargs = {'verify': self.cacert_path, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)}
+ os.environ["ST2_CACERT"] = self.cacert_path
+ self.shell.run(["auth", USERNAME, "-p", PASSWORD])
+ kwargs = {
+ "verify": self.cacert_path,
+ "headers": HEADERS,
+ "auth": (USERNAME, PASSWORD),
+ }
requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, 'OK')))
+ requests,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, "OK")),
+ )
def test_decorate_http_without_cacert(self):
- self.shell.run(['rule', 'list'])
+ self.shell.run(["rule", "list"])
requests.get.assert_called_with(GET_RULES_URL)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK')))
+ requests,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")),
+ )
def test_decorate_http_with_cacert_from_cli(self):
- self.shell.run(['--cacert', self.cacert_path, 'rule', 'list'])
+ self.shell.run(["--cacert", self.cacert_path, "rule", "list"])
requests.get.assert_called_with(GET_RULES_URL)
@mock.patch.object(
- requests, 'get',
- mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK')))
+ requests,
+ "get",
+ mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")),
+ )
def test_decorate_http_with_cacert_from_env(self):
- os.environ['ST2_CACERT'] = self.cacert_path
- self.shell.run(['rule', 'list'])
+ os.environ["ST2_CACERT"] = self.cacert_path
+ self.shell.run(["rule", "list"])
requests.get.assert_called_with(GET_RULES_URL)
diff --git a/st2client/tests/unit/test_trace_commands.py b/st2client/tests/unit/test_trace_commands.py
index 99d60598a4..ea3b552d47 100644
--- a/st2client/tests/unit/test_trace_commands.py
+++ b/st2client/tests/unit/test_trace_commands.py
@@ -23,23 +23,38 @@
class TraceCommandTestCase(base.BaseCLITestCase):
-
def test_trace_get_filter_trace_components_executions(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions',
- [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}])
- setattr(trace, 'rules',
- [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}])
- setattr(trace, 'trigger_instances',
- [{'object_id': 't1', 'caused_by': {}},
- {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}])
+ setattr(
+ trace,
+ "action_executions",
+ [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}],
+ )
+ setattr(
+ trace,
+ "rules",
+ [
+ {
+ "object_id": "r1",
+ "caused_by": {"id": "t1", "type": "trigger_instance"},
+ }
+ ],
+ )
+ setattr(
+ trace,
+ "trigger_instances",
+ [
+ {"object_id": "t1", "caused_by": {}},
+ {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}},
+ ],
+ )
args = argparse.Namespace()
- setattr(args, 'execution', 'e1')
- setattr(args, 'show_executions', False)
- setattr(args, 'show_rules', False)
- setattr(args, 'show_trigger_instances', False)
- setattr(args, 'hide_noop_triggers', False)
+ setattr(args, "execution", "e1")
+ setattr(args, "show_executions", False)
+ setattr(args, "show_rules", False)
+ setattr(args, "show_trigger_instances", False)
+ setattr(args, "hide_noop_triggers", False)
trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args)
self.assertEqual(len(trace.action_executions), 1)
@@ -48,22 +63,38 @@ def test_trace_get_filter_trace_components_executions(self):
def test_trace_get_filter_trace_components_rules(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions',
- [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}])
- setattr(trace, 'rules',
- [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}])
- setattr(trace, 'trigger_instances',
- [{'object_id': 't1', 'caused_by': {}},
- {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}])
+ setattr(
+ trace,
+ "action_executions",
+ [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}],
+ )
+ setattr(
+ trace,
+ "rules",
+ [
+ {
+ "object_id": "r1",
+ "caused_by": {"id": "t1", "type": "trigger_instance"},
+ }
+ ],
+ )
+ setattr(
+ trace,
+ "trigger_instances",
+ [
+ {"object_id": "t1", "caused_by": {}},
+ {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}},
+ ],
+ )
args = argparse.Namespace()
- setattr(args, 'execution', None)
- setattr(args, 'rule', 'r1')
- setattr(args, 'trigger_instance', None)
- setattr(args, 'show_executions', False)
- setattr(args, 'show_rules', False)
- setattr(args, 'show_trigger_instances', False)
- setattr(args, 'hide_noop_triggers', False)
+ setattr(args, "execution", None)
+ setattr(args, "rule", "r1")
+ setattr(args, "trigger_instance", None)
+ setattr(args, "show_executions", False)
+ setattr(args, "show_rules", False)
+ setattr(args, "show_trigger_instances", False)
+ setattr(args, "hide_noop_triggers", False)
trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args)
self.assertEqual(len(trace.action_executions), 0)
@@ -72,22 +103,38 @@ def test_trace_get_filter_trace_components_rules(self):
def test_trace_get_filter_trace_components_trigger_instances(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions',
- [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}])
- setattr(trace, 'rules',
- [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}])
- setattr(trace, 'trigger_instances',
- [{'object_id': 't1', 'caused_by': {}},
- {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}])
+ setattr(
+ trace,
+ "action_executions",
+ [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}],
+ )
+ setattr(
+ trace,
+ "rules",
+ [
+ {
+ "object_id": "r1",
+ "caused_by": {"id": "t1", "type": "trigger_instance"},
+ }
+ ],
+ )
+ setattr(
+ trace,
+ "trigger_instances",
+ [
+ {"object_id": "t1", "caused_by": {}},
+ {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}},
+ ],
+ )
args = argparse.Namespace()
- setattr(args, 'execution', None)
- setattr(args, 'rule', None)
- setattr(args, 'trigger_instance', 't1')
- setattr(args, 'show_executions', False)
- setattr(args, 'show_rules', False)
- setattr(args, 'show_trigger_instances', False)
- setattr(args, 'hide_noop_triggers', False)
+ setattr(args, "execution", None)
+ setattr(args, "rule", None)
+ setattr(args, "trigger_instance", "t1")
+ setattr(args, "show_executions", False)
+ setattr(args, "show_rules", False)
+ setattr(args, "show_trigger_instances", False)
+ setattr(args, "hide_noop_triggers", False)
trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args)
self.assertEqual(len(trace.action_executions), 0)
@@ -96,15 +143,15 @@ def test_trace_get_filter_trace_components_trigger_instances(self):
def test_trace_get_apply_display_filters_show_executions(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions', ['1'])
- setattr(trace, 'rules', ['1'])
- setattr(trace, 'trigger_instances', ['1'])
+ setattr(trace, "action_executions", ["1"])
+ setattr(trace, "rules", ["1"])
+ setattr(trace, "trigger_instances", ["1"])
args = argparse.Namespace()
- setattr(args, 'show_executions', True)
- setattr(args, 'show_rules', False)
- setattr(args, 'show_trigger_instances', False)
- setattr(args, 'hide_noop_triggers', False)
+ setattr(args, "show_executions", True)
+ setattr(args, "show_rules", False)
+ setattr(args, "show_trigger_instances", False)
+ setattr(args, "hide_noop_triggers", False)
trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args)
self.assertTrue(trace.action_executions)
@@ -113,15 +160,15 @@ def test_trace_get_apply_display_filters_show_executions(self):
def test_trace_get_apply_display_filters_show_rules(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions', ['1'])
- setattr(trace, 'rules', ['1'])
- setattr(trace, 'trigger_instances', ['1'])
+ setattr(trace, "action_executions", ["1"])
+ setattr(trace, "rules", ["1"])
+ setattr(trace, "trigger_instances", ["1"])
args = argparse.Namespace()
- setattr(args, 'show_executions', False)
- setattr(args, 'show_rules', True)
- setattr(args, 'show_trigger_instances', False)
- setattr(args, 'hide_noop_triggers', False)
+ setattr(args, "show_executions", False)
+ setattr(args, "show_rules", True)
+ setattr(args, "show_trigger_instances", False)
+ setattr(args, "hide_noop_triggers", False)
trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args)
self.assertFalse(trace.action_executions)
@@ -130,15 +177,15 @@ def test_trace_get_apply_display_filters_show_rules(self):
def test_trace_get_apply_display_filters_show_trigger_instances(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions', ['1'])
- setattr(trace, 'rules', ['1'])
- setattr(trace, 'trigger_instances', ['1'])
+ setattr(trace, "action_executions", ["1"])
+ setattr(trace, "rules", ["1"])
+ setattr(trace, "trigger_instances", ["1"])
args = argparse.Namespace()
- setattr(args, 'show_executions', False)
- setattr(args, 'show_rules', False)
- setattr(args, 'show_trigger_instances', True)
- setattr(args, 'hide_noop_triggers', False)
+ setattr(args, "show_executions", False)
+ setattr(args, "show_rules", False)
+ setattr(args, "show_trigger_instances", True)
+ setattr(args, "hide_noop_triggers", False)
trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args)
self.assertFalse(trace.action_executions)
@@ -147,15 +194,15 @@ def test_trace_get_apply_display_filters_show_trigger_instances(self):
def test_trace_get_apply_display_filters_show_multiple(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions', ['1'])
- setattr(trace, 'rules', ['1'])
- setattr(trace, 'trigger_instances', ['1'])
+ setattr(trace, "action_executions", ["1"])
+ setattr(trace, "rules", ["1"])
+ setattr(trace, "trigger_instances", ["1"])
args = argparse.Namespace()
- setattr(args, 'show_executions', True)
- setattr(args, 'show_rules', True)
- setattr(args, 'show_trigger_instances', False)
- setattr(args, 'hide_noop_triggers', False)
+ setattr(args, "show_executions", True)
+ setattr(args, "show_rules", True)
+ setattr(args, "show_trigger_instances", False)
+ setattr(args, "hide_noop_triggers", False)
trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args)
self.assertTrue(trace.action_executions)
@@ -164,15 +211,15 @@ def test_trace_get_apply_display_filters_show_multiple(self):
def test_trace_get_apply_display_filters_show_all(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions', ['1'])
- setattr(trace, 'rules', ['1'])
- setattr(trace, 'trigger_instances', ['1'])
+ setattr(trace, "action_executions", ["1"])
+ setattr(trace, "rules", ["1"])
+ setattr(trace, "trigger_instances", ["1"])
args = argparse.Namespace()
- setattr(args, 'show_executions', False)
- setattr(args, 'show_rules', False)
- setattr(args, 'show_trigger_instances', False)
- setattr(args, 'hide_noop_triggers', False)
+ setattr(args, "show_executions", False)
+ setattr(args, "show_rules", False)
+ setattr(args, "show_trigger_instances", False)
+ setattr(args, "hide_noop_triggers", False)
trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args)
self.assertEqual(len(trace.action_executions), 1)
@@ -181,19 +228,35 @@ def test_trace_get_apply_display_filters_show_all(self):
def test_trace_get_apply_display_filters_hide_noop(self):
trace = trace_models.Trace()
- setattr(trace, 'action_executions',
- [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}])
- setattr(trace, 'rules',
- [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}])
- setattr(trace, 'trigger_instances',
- [{'object_id': 't1', 'caused_by': {}},
- {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}])
+ setattr(
+ trace,
+ "action_executions",
+ [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}],
+ )
+ setattr(
+ trace,
+ "rules",
+ [
+ {
+ "object_id": "r1",
+ "caused_by": {"id": "t1", "type": "trigger_instance"},
+ }
+ ],
+ )
+ setattr(
+ trace,
+ "trigger_instances",
+ [
+ {"object_id": "t1", "caused_by": {}},
+ {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}},
+ ],
+ )
args = argparse.Namespace()
- setattr(args, 'show_executions', False)
- setattr(args, 'show_rules', False)
- setattr(args, 'show_trigger_instances', False)
- setattr(args, 'hide_noop_triggers', True)
+ setattr(args, "show_executions", False)
+ setattr(args, "show_rules", False)
+ setattr(args, "show_trigger_instances", False)
+ setattr(args, "hide_noop_triggers", True)
trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args)
self.assertEqual(len(trace.action_executions), 1)
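
For reference, argparse.Namespace accepts keyword arguments directly, so the
setattr() blocks above have a one-expression equivalent. A minimal sketch, not
part of the change itself:

    import argparse

    # equivalent to the five setattr() calls used in the tests above
    args = argparse.Namespace(
        execution="e1",
        show_executions=False,
        show_rules=False,
        show_trigger_instances=False,
        hide_noop_triggers=False,
    )
    print(args.execution)  # -> e1
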
diff --git a/st2client/tests/unit/test_util_date.py b/st2client/tests/unit/test_util_date.py
index e29b840ed7..2cdeab95fc 100644
--- a/st2client/tests/unit/test_util_date.py
+++ b/st2client/tests/unit/test_util_date.py
@@ -30,31 +30,31 @@ def test_format_dt(self):
dt = datetime.datetime(2015, 10, 20, 8, 0, 0)
dt = add_utc_tz(dt)
result = format_dt(dt)
- self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC')
+ self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC")
def test_format_isodate(self):
# No timezone, defaults to UTC
- value = 'Tue, 20 Oct 2015 08:00:00 UTC'
+ value = "Tue, 20 Oct 2015 08:00:00 UTC"
result = format_isodate(value=value)
- self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC')
+ self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC")
# Timezone provided
- value = 'Tue, 20 Oct 2015 08:00:00 UTC'
- result = format_isodate(value=value, timezone='Europe/Ljubljana')
- self.assertEqual(result, 'Tue, 20 Oct 2015 10:00:00 CEST')
+ value = "Tue, 20 Oct 2015 08:00:00 UTC"
+ result = format_isodate(value=value, timezone="Europe/Ljubljana")
+ self.assertEqual(result, "Tue, 20 Oct 2015 10:00:00 CEST")
- @mock.patch('st2client.utils.date.get_config')
+ @mock.patch("st2client.utils.date.get_config")
def test_format_isodate_for_user_timezone(self, mock_get_config):
# No timezone, defaults to UTC
mock_get_config.return_value = {}
- value = 'Tue, 20 Oct 2015 08:00:00 UTC'
+ value = "Tue, 20 Oct 2015 08:00:00 UTC"
result = format_isodate_for_user_timezone(value=value)
- self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC')
+ self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC")
# Timezone provided
- mock_get_config.return_value = {'cli': {'timezone': 'Europe/Ljubljana'}}
+ mock_get_config.return_value = {"cli": {"timezone": "Europe/Ljubljana"}}
- value = 'Tue, 20 Oct 2015 08:00:00 UTC'
+ value = "Tue, 20 Oct 2015 08:00:00 UTC"
result = format_isodate_for_user_timezone(value=value)
- self.assertEqual(result, 'Tue, 20 Oct 2015 10:00:00 CEST')
+ self.assertEqual(result, "Tue, 20 Oct 2015 10:00:00 CEST")
diff --git a/st2client/tests/unit/test_util_json.py b/st2client/tests/unit/test_util_json.py
index f44a4b9bf9..2333128c2e 100644
--- a/st2client/tests/unit/test_util_json.py
+++ b/st2client/tests/unit/test_util_json.py
@@ -25,76 +25,67 @@
LOG = logging.getLogger(__name__)
DOC = {
- 'a01': 1,
- 'b01': 2,
- 'c01': {
- 'c11': 3,
- 'd12': 4,
- 'c13': {
- 'c21': 5,
- 'c22': 6
- },
- 'c14': [7, 8, 9]
- }
+ "a01": 1,
+ "b01": 2,
+ "c01": {"c11": 3, "d12": 4, "c13": {"c21": 5, "c22": 6}, "c14": [7, 8, 9]},
}
DOC_IP_ADDRESS = {
- 'ips': {
- "192.168.1.1": {
- "hostname": "router.domain.tld"
- },
- "192.168.1.10": {
- "hostname": "server.domain.tld"
- }
+ "ips": {
+ "192.168.1.1": {"hostname": "router.domain.tld"},
+ "192.168.1.10": {"hostname": "server.domain.tld"},
}
}
class TestGetValue(unittest2.TestCase):
-
def test_dot_notation(self):
- self.assertEqual(jsutil.get_value(DOC, 'a01'), 1)
- self.assertEqual(jsutil.get_value(DOC, 'c01.c11'), 3)
- self.assertEqual(jsutil.get_value(DOC, 'c01.c13.c22'), 6)
- self.assertEqual(jsutil.get_value(DOC, 'c01.c13'), {'c21': 5, 'c22': 6})
- self.assertListEqual(jsutil.get_value(DOC, 'c01.c14'), [7, 8, 9])
+ self.assertEqual(jsutil.get_value(DOC, "a01"), 1)
+ self.assertEqual(jsutil.get_value(DOC, "c01.c11"), 3)
+ self.assertEqual(jsutil.get_value(DOC, "c01.c13.c22"), 6)
+ self.assertEqual(jsutil.get_value(DOC, "c01.c13"), {"c21": 5, "c22": 6})
+ self.assertListEqual(jsutil.get_value(DOC, "c01.c14"), [7, 8, 9])
def test_dot_notation_with_val_error(self):
self.assertRaises(ValueError, jsutil.get_value, DOC, None)
- self.assertRaises(ValueError, jsutil.get_value, DOC, '')
- self.assertRaises(ValueError, jsutil.get_value, json.dumps(DOC), 'a01')
+ self.assertRaises(ValueError, jsutil.get_value, DOC, "")
+ self.assertRaises(ValueError, jsutil.get_value, json.dumps(DOC), "a01")
def test_dot_notation_with_key_error(self):
- self.assertIsNone(jsutil.get_value(DOC, 'd01'))
- self.assertIsNone(jsutil.get_value(DOC, 'a01.a11'))
- self.assertIsNone(jsutil.get_value(DOC, 'c01.c11.c21.c31'))
- self.assertIsNone(jsutil.get_value(DOC, 'c01.c14.c31'))
+ self.assertIsNone(jsutil.get_value(DOC, "d01"))
+ self.assertIsNone(jsutil.get_value(DOC, "a01.a11"))
+ self.assertIsNone(jsutil.get_value(DOC, "c01.c11.c21.c31"))
+ self.assertIsNone(jsutil.get_value(DOC, "c01.c14.c31"))
def test_ip_address(self):
- self.assertEqual(jsutil.get_value(DOC_IP_ADDRESS, 'ips."192.168.1.1"'),
- {"hostname": "router.domain.tld"})
+ self.assertEqual(
+ jsutil.get_value(DOC_IP_ADDRESS, 'ips."192.168.1.1"'),
+ {"hostname": "router.domain.tld"},
+ )
def test_chars_nums_dashes_underscores_calls_simple(self):
- for char in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_':
+ for char in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_":
with mock.patch("st2client.utils.jsutil._get_value_simple") as mock_simple:
jsutil.get_value(DOC, char)
mock_simple.assert_called_with(DOC, char)
def test_symbols_calls_complex(self):
- for char in '`~!@#$%^&&*()=+{}[]|\\;:\'"<>,./?':
- with mock.patch("st2client.utils.jsutil._get_value_complex") as mock_complex:
+ for char in "`~!@#$%^&&*()=+{}[]|\\;:'\"<>,./?":
+ with mock.patch(
+ "st2client.utils.jsutil._get_value_complex"
+ ) as mock_complex:
jsutil.get_value(DOC, char)
mock_complex.assert_called_with(DOC, char)
@mock.patch("st2client.utils.jsutil._get_value_simple")
def test_single_key_calls_simple(self, mock__get_value_simple):
- jsutil.get_value(DOC, 'a01')
- mock__get_value_simple.assert_called_with(DOC, 'a01')
+ jsutil.get_value(DOC, "a01")
+ mock__get_value_simple.assert_called_with(DOC, "a01")
@mock.patch("st2client.utils.jsutil._get_value_simple")
def test_dot_notation_calls_simple(self, mock__get_value_simple):
- jsutil.get_value(DOC, 'c01.c11')
- mock__get_value_simple.assert_called_with(DOC, 'c01.c11')
+ jsutil.get_value(DOC, "c01.c11")
+ mock__get_value_simple.assert_called_with(DOC, "c01.c11")
@mock.patch("st2client.utils.jsutil._get_value_complex")
def test_ip_address_calls_complex(self, mock__get_value_complex):
@@ -103,54 +94,64 @@ def test_ip_address_calls_complex(self, mock__get_value_complex):
@mock.patch("st2client.utils.jsutil._get_value_complex")
def test_beginning_dot_calls_complex(self, mock__get_value_complex):
- jsutil.get_value(DOC, '.c01.c11')
- mock__get_value_complex.assert_called_with(DOC, '.c01.c11')
+ jsutil.get_value(DOC, ".c01.c11")
+ mock__get_value_complex.assert_called_with(DOC, ".c01.c11")
@mock.patch("st2client.utils.jsutil._get_value_complex")
def test_ending_dot_calls_complex(self, mock__get_value_complex):
- jsutil.get_value(DOC, 'c01.c11.')
- mock__get_value_complex.assert_called_with(DOC, 'c01.c11.')
+ jsutil.get_value(DOC, "c01.c11.")
+ mock__get_value_complex.assert_called_with(DOC, "c01.c11.")
@mock.patch("st2client.utils.jsutil._get_value_complex")
def test_double_dot_calls_complex(self, mock__get_value_complex):
- jsutil.get_value(DOC, 'c01..c11')
- mock__get_value_complex.assert_called_with(DOC, 'c01..c11')
+ jsutil.get_value(DOC, "c01..c11")
+ mock__get_value_complex.assert_called_with(DOC, "c01..c11")
class TestGetKeyValuePairs(unittest2.TestCase):
-
def test_select_kvps(self):
- self.assertEqual(jsutil.get_kvps(DOC, ['a01']),
- {'a01': 1})
- self.assertEqual(jsutil.get_kvps(DOC, ['c01.c11']),
- {'c01': {'c11': 3}})
- self.assertEqual(jsutil.get_kvps(DOC, ['c01.c13.c22']),
- {'c01': {'c13': {'c22': 6}}})
- self.assertEqual(jsutil.get_kvps(DOC, ['c01.c13']),
- {'c01': {'c13': {'c21': 5, 'c22': 6}}})
- self.assertEqual(jsutil.get_kvps(DOC, ['c01.c14']),
- {'c01': {'c14': [7, 8, 9]}})
- self.assertEqual(jsutil.get_kvps(DOC, ['a01', 'c01.c11', 'c01.c13.c21']),
- {'a01': 1, 'c01': {'c11': 3, 'c13': {'c21': 5}}})
- self.assertEqual(jsutil.get_kvps(DOC_IP_ADDRESS,
- ['ips."192.168.1.1"',
- 'ips."192.168.1.10".hostname']),
- {'ips':
- {'"192':
- {'168':
- {'1':
- {'1"': {'hostname': 'router.domain.tld'},
- '10"': {'hostname': 'server.domain.tld'}}}}}})
+ self.assertEqual(jsutil.get_kvps(DOC, ["a01"]), {"a01": 1})
+ self.assertEqual(jsutil.get_kvps(DOC, ["c01.c11"]), {"c01": {"c11": 3}})
+ self.assertEqual(
+ jsutil.get_kvps(DOC, ["c01.c13.c22"]), {"c01": {"c13": {"c22": 6}}}
+ )
+ self.assertEqual(
+ jsutil.get_kvps(DOC, ["c01.c13"]), {"c01": {"c13": {"c21": 5, "c22": 6}}}
+ )
+ self.assertEqual(jsutil.get_kvps(DOC, ["c01.c14"]), {"c01": {"c14": [7, 8, 9]}})
+ self.assertEqual(
+ jsutil.get_kvps(DOC, ["a01", "c01.c11", "c01.c13.c21"]),
+ {"a01": 1, "c01": {"c11": 3, "c13": {"c21": 5}}},
+ )
+ self.assertEqual(
+ jsutil.get_kvps(
+ DOC_IP_ADDRESS, ['ips."192.168.1.1"', 'ips."192.168.1.10".hostname']
+ ),
+ {
+ "ips": {
+ '"192': {
+ "168": {
+ "1": {
+ '1"': {"hostname": "router.domain.tld"},
+ '10"': {"hostname": "server.domain.tld"},
+ }
+ }
+ }
+ }
+ },
+ )
def test_select_kvps_with_val_error(self):
self.assertRaises(ValueError, jsutil.get_kvps, DOC, [None])
- self.assertRaises(ValueError, jsutil.get_kvps, DOC, [''])
- self.assertRaises(ValueError, jsutil.get_kvps, json.dumps(DOC), ['a01'])
+ self.assertRaises(ValueError, jsutil.get_kvps, DOC, [""])
+ self.assertRaises(ValueError, jsutil.get_kvps, json.dumps(DOC), ["a01"])
def test_select_kvps_with_key_error(self):
- self.assertEqual(jsutil.get_kvps(DOC, ['d01']), {})
- self.assertEqual(jsutil.get_kvps(DOC, ['a01.a11']), {})
- self.assertEqual(jsutil.get_kvps(DOC, ['c01.c11.c21.c31']), {})
- self.assertEqual(jsutil.get_kvps(DOC, ['c01.c14.c31']), {})
- self.assertEqual(jsutil.get_kvps(DOC, ['a01', 'c01.c11', 'c01.c13.c23']),
- {'a01': 1, 'c01': {'c11': 3}})
+ self.assertEqual(jsutil.get_kvps(DOC, ["d01"]), {})
+ self.assertEqual(jsutil.get_kvps(DOC, ["a01.a11"]), {})
+ self.assertEqual(jsutil.get_kvps(DOC, ["c01.c11.c21.c31"]), {})
+ self.assertEqual(jsutil.get_kvps(DOC, ["c01.c14.c31"]), {})
+ self.assertEqual(
+ jsutil.get_kvps(DOC, ["a01", "c01.c11", "c01.c13.c23"]),
+ {"a01": 1, "c01": {"c11": 3}},
+ )
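
The tests above pin down get_value's contract: dot-notation paths walk nested
dicts, missing segments yield None, and keys that themselves contain dots (the
quoted IP-address form) are routed to the complex code path. A minimal sketch
of the simple path only, assuming plain dict traversal:

    def get_value_simple_sketch(doc, key):
        # walk a dot-separated path; return None on any missing segment
        value = doc
        for part in key.split("."):
            if not isinstance(value, dict) or part not in value:
                return None
            value = value[part]
        return value

    doc = {"a01": 1, "c01": {"c13": {"c21": 5, "c22": 6}}}
    assert get_value_simple_sketch(doc, "c01.c13.c22") == 6
    assert get_value_simple_sketch(doc, "d01") is None
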
diff --git a/st2client/tests/unit/test_util_misc.py b/st2client/tests/unit/test_util_misc.py
index 6a2cf3a8fc..2e33156adc 100644
--- a/st2client/tests/unit/test_util_misc.py
+++ b/st2client/tests/unit/test_util_misc.py
@@ -21,37 +21,37 @@
class MiscUtilTestCase(unittest2.TestCase):
def test_merge_dicts(self):
- d1 = {'a': 1}
- d2 = {'a': 2}
- expected = {'a': 2}
+ d1 = {"a": 1}
+ d2 = {"a": 2}
+ expected = {"a": 2}
result = merge_dicts(d1, d2)
self.assertEqual(result, expected)
- d1 = {'a': 1}
- d2 = {'b': 1}
- expected = {'a': 1, 'b': 1}
+ d1 = {"a": 1}
+ d2 = {"b": 1}
+ expected = {"a": 1, "b": 1}
result = merge_dicts(d1, d2)
self.assertEqual(result, expected)
- d1 = {'a': 1}
- d2 = {'a': 3, 'b': 1}
- expected = {'a': 3, 'b': 1}
+ d1 = {"a": 1}
+ d2 = {"a": 3, "b": 1}
+ expected = {"a": 3, "b": 1}
result = merge_dicts(d1, d2)
self.assertEqual(result, expected)
- d1 = {'a': 1, 'm': None}
- d2 = {'a': None, 'b': 1, 'c': None}
- expected = {'a': 1, 'b': 1, 'c': None, 'm': None}
+ d1 = {"a": 1, "m": None}
+ d2 = {"a": None, "b": 1, "c": None}
+ expected = {"a": 1, "b": 1, "c": None, "m": None}
result = merge_dicts(d1, d2)
self.assertEqual(result, expected)
- d1 = {'a': 1, 'b': {'a': 1, 'b': 2, 'c': 3}}
- d2 = {'b': {'b': 100}}
- expected = {'a': 1, 'b': {'a': 1, 'b': 100, 'c': 3}}
+ d1 = {"a": 1, "b": {"a": 1, "b": 2, "c": 3}}
+ d2 = {"b": {"b": 100}}
+ expected = {"a": 1, "b": {"a": 1, "b": 100, "c": 3}}
result = merge_dicts(d1, d2)
self.assertEqual(result, expected)
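
The cases above characterize merge_dicts fully: d2 wins on conflicts, except
that a None in d2 never clobbers an existing value, and nested dicts are merged
recursively. A sketch consistent with these expectations (not necessarily the
real implementation):

    def merge_dicts_sketch(d1, d2):
        result = dict(d1)
        for key, value in d2.items():
            if isinstance(value, dict) and isinstance(result.get(key), dict):
                result[key] = merge_dicts_sketch(result[key], value)  # recurse
            elif value is None and key in result:
                continue  # None in d2 does not overwrite an existing value
            else:
                result[key] = value
        return result

    assert merge_dicts_sketch({"a": 1, "m": None}, {"a": None, "b": 1}) == {
        "a": 1,
        "b": 1,
        "m": None,
    }
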
diff --git a/st2client/tests/unit/test_util_strutil.py b/st2client/tests/unit/test_util_strutil.py
index 2d442013de..585e88c389 100644
--- a/st2client/tests/unit/test_util_strutil.py
+++ b/st2client/tests/unit/test_util_strutil.py
@@ -26,17 +26,17 @@ class StrUtilTestCase(unittest2.TestCase):
def test_unescape(self):
in_str = 'Action execution result double escape \\"stuffs\\".\\r\\n'
- expected = 'Action execution result double escape \"stuffs\".\r\n'
+ expected = 'Action execution result double escape "stuffs".\r\n'
out_str = strutil.unescape(in_str)
self.assertEqual(out_str, expected)
def test_unicode_string(self):
- in_str = '\u8c03\u7528CMS\u63a5\u53e3\u5220\u9664\u865a\u62df\u76ee\u5f55'
+ in_str = "\u8c03\u7528CMS\u63a5\u53e3\u5220\u9664\u865a\u62df\u76ee\u5f55"
out_str = strutil.unescape(in_str)
self.assertEqual(out_str, in_str)
def test_strip_carriage_returns(self):
- in_str = 'Windows editors introduce\r\nlike a noob in 2017.'
+ in_str = "Windows editors introduce\r\nlike a noob in 2017."
out_str = strutil.strip_carriage_returns(in_str)
- exp_str = 'Windows editors introduce\nlike a noob in 2017.'
+ exp_str = "Windows editors introduce\nlike a noob in 2017."
self.assertEqual(out_str, exp_str)
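
One behavior consistent with the carriage-return test above -- hedged, since
the real helper may normalize differently -- is a straight replace:

    def strip_carriage_returns_sketch(value):
        # collapse Windows line endings into Unix ones
        return value.replace("\r\n", "\n")

    assert strip_carriage_returns_sketch("a\r\nb") == "a\nb"
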
diff --git a/st2client/tests/unit/test_util_terminal.py b/st2client/tests/unit/test_util_terminal.py
index c9b6d82b27..29a8386b0b 100644
--- a/st2client/tests/unit/test_util_terminal.py
+++ b/st2client/tests/unit/test_util_terminal.py
@@ -23,20 +23,20 @@
from st2client.utils.terminal import DEFAULT_TERMINAL_SIZE_COLUMNS
from st2client.utils.terminal import get_terminal_size_columns
-__all__ = [
- 'TerminalUtilsTestCase'
-]
+__all__ = ["TerminalUtilsTestCase"]
class TerminalUtilsTestCase(unittest2.TestCase):
def setUp(self):
super(TerminalUtilsTestCase, self).setUp()
- if 'COLUMNS' in os.environ:
- del os.environ['COLUMNS']
+ if "COLUMNS" in os.environ:
+ del os.environ["COLUMNS"]
- @mock.patch.dict(os.environ, {'LINES': '111', 'COLUMNS': '222'})
- def test_get_terminal_size_columns_columns_environment_variable_has_precedence(self):
+ @mock.patch.dict(os.environ, {"LINES": "111", "COLUMNS": "222"})
+ def test_get_terminal_size_columns_columns_environment_variable_has_precedence(
+ self,
+ ):
# Verify that the COLUMNS environment variable has precedence over other approaches
columns = get_terminal_size_columns()
@@ -44,16 +44,16 @@ def test_get_terminal_size_columns_columns_environment_variable_has_precedence(s
# make sure that os.environ['COLUMNS'] isn't set so it can't override/screw-up this test
@mock.patch.dict(os.environ, {})
- @mock.patch('fcntl.ioctl', mock.Mock(return_value='dummy'))
- @mock.patch('struct.unpack', mock.Mock(return_value=(333, 444)))
+ @mock.patch("fcntl.ioctl", mock.Mock(return_value="dummy"))
+ @mock.patch("struct.unpack", mock.Mock(return_value=(333, 444)))
def test_get_terminal_size_columns_stdout_is_used(self):
columns = get_terminal_size_columns()
self.assertEqual(columns, 444)
- @mock.patch('struct.unpack', mock.Mock(side_effect=Exception('a')))
- @mock.patch('subprocess.Popen')
+ @mock.patch("struct.unpack", mock.Mock(side_effect=Exception("a")))
+ @mock.patch("subprocess.Popen")
def test_get_terminal_size_subprocess_popen_is_used(self, mock_popen):
- mock_communicate = mock.Mock(return_value=['555 666'])
+ mock_communicate = mock.Mock(return_value=["555 666"])
mock_process = mock.Mock()
mock_process.returncode = 0
@@ -64,8 +64,8 @@ def test_get_terminal_size_subprocess_popen_is_used(self, mock_popen):
columns = get_terminal_size_columns()
self.assertEqual(columns, 666)
- @mock.patch('struct.unpack', mock.Mock(side_effect=Exception('a')))
- @mock.patch('subprocess.Popen', mock.Mock(side_effect=Exception('b')))
+ @mock.patch("struct.unpack", mock.Mock(side_effect=Exception("a")))
+ @mock.patch("subprocess.Popen", mock.Mock(side_effect=Exception("b")))
def test_get_terminal_size_default_values_are_used(self):
columns = get_terminal_size_columns()
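
Taken together, these tests describe a fallback chain: the COLUMNS environment
variable wins, then a TIOCGWINSZ ioctl on stdout, then a subprocess (presumably
`stty size`, given the "555 666" fixture), then the documented default. A
condensed sketch of the first two stages plus the default (names here are
illustrative, and the ioctl stage is POSIX-only):

    import fcntl
    import os
    import struct
    import sys
    import termios

    def get_columns_sketch(default=80):
        try:
            return int(os.environ["COLUMNS"])  # stage 1: environment wins
        except (KeyError, ValueError):
            pass
        try:
            buf = fcntl.ioctl(sys.stdout.fileno(), termios.TIOCGWINSZ, b"\x00" * 4)
            rows, cols = struct.unpack("hh", buf)  # stage 2: kernel window size
            return cols
        except Exception:
            return default  # final stage: fall back to the default
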
diff --git a/st2client/tests/unit/test_workflow.py b/st2client/tests/unit/test_workflow.py
index 3896a27bc2..79d580f85d 100644
--- a/st2client/tests/unit/test_workflow.py
+++ b/st2client/tests/unit/test_workflow.py
@@ -31,13 +31,13 @@
LOG = logging.getLogger(__name__)
MOCK_ACTION = {
- 'ref': 'mock.foobar',
- 'runner_type': 'mock-runner',
- 'pack': 'mock',
- 'name': 'foobar',
- 'parameters': {},
- 'enabled': True,
- 'entry_point': 'workflows/foobar.yaml'
+ "ref": "mock.foobar",
+ "runner_type": "mock-runner",
+ "pack": "mock",
+ "name": "foobar",
+ "parameters": {},
+ "enabled": True,
+ "entry_point": "workflows/foobar.yaml",
}
MOCK_WF_DEF = """
@@ -56,73 +56,88 @@ def get_by_ref(**kwargs):
class WorkflowCommandTestCase(st2cli_tests.BaseCLITestCase):
-
def __init__(self, *args, **kwargs):
super(WorkflowCommandTestCase, self).__init__(*args, **kwargs)
self.shell = shell.Shell()
@mock.patch.object(
- httpclient.HTTPClient, 'post_raw',
- mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post_raw",
+ mock.MagicMock(
+ return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK")
+ ),
+ )
def test_inspect_file(self):
- fd, path = tempfile.mkstemp(suffix='.yaml')
+ fd, path = tempfile.mkstemp(suffix=".yaml")
try:
- with open(path, 'a') as f:
+ with open(path, "a") as f:
f.write(MOCK_WF_DEF)
- retcode = self.shell.run(['workflow', 'inspect', '--file', path])
+ retcode = self.shell.run(["workflow", "inspect", "--file", path])
self.assertEqual(retcode, 0)
httpclient.HTTPClient.post_raw.assert_called_with(
- '/inspect',
- MOCK_WF_DEF,
- headers={'content-type': 'text/plain'}
+ "/inspect", MOCK_WF_DEF, headers={"content-type": "text/plain"}
)
finally:
os.close(fd)
os.unlink(path)
@mock.patch.object(
- httpclient.HTTPClient, 'post_raw',
- mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post_raw",
+ mock.MagicMock(
+ return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK")
+ ),
+ )
def test_inspect_bad_file(self):
- retcode = self.shell.run(['workflow', 'inspect', '--file', '/tmp/foobar'])
+ retcode = self.shell.run(["workflow", "inspect", "--file", "/tmp/foobar"])
self.assertEqual(retcode, 1)
- self.assertIn('does not exist', self.stdout.getvalue())
+ self.assertIn("does not exist", self.stdout.getvalue())
self.assertFalse(httpclient.HTTPClient.post_raw.called)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(side_effect=get_by_ref))
+ models.ResourceManager,
+ "get_by_ref_or_id",
+ mock.MagicMock(side_effect=get_by_ref),
+ )
@mock.patch.object(
- workflow.WorkflowInspectionCommand, 'get_file_content',
- mock.MagicMock(return_value=MOCK_WF_DEF))
+ workflow.WorkflowInspectionCommand,
+ "get_file_content",
+ mock.MagicMock(return_value=MOCK_WF_DEF),
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post_raw',
- mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post_raw",
+ mock.MagicMock(
+ return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK")
+ ),
+ )
def test_inspect_action(self):
- retcode = self.shell.run(['workflow', 'inspect', '--action', 'mock.foobar'])
+ retcode = self.shell.run(["workflow", "inspect", "--action", "mock.foobar"])
self.assertEqual(retcode, 0)
httpclient.HTTPClient.post_raw.assert_called_with(
- '/inspect',
- MOCK_WF_DEF,
- headers={'content-type': 'text/plain'}
+ "/inspect", MOCK_WF_DEF, headers={"content-type": "text/plain"}
)
@mock.patch.object(
- models.ResourceManager, 'get_by_ref_or_id',
- mock.MagicMock(return_value=None))
+ models.ResourceManager, "get_by_ref_or_id", mock.MagicMock(return_value=None)
+ )
@mock.patch.object(
- httpclient.HTTPClient, 'post_raw',
- mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK')))
+ httpclient.HTTPClient,
+ "post_raw",
+ mock.MagicMock(
+ return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK")
+ ),
+ )
def test_inspect_bad_action(self):
- retcode = self.shell.run(['workflow', 'inspect', '--action', 'mock.foobar'])
+ retcode = self.shell.run(["workflow", "inspect", "--action", "mock.foobar"])
self.assertEqual(retcode, 1)
- self.assertIn('Unable to identify action', self.stdout.getvalue())
+ self.assertIn("Unable to identify action", self.stdout.getvalue())
self.assertFalse(httpclient.HTTPClient.post_raw.called)
diff --git a/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py b/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py
index db5cbdcf67..b8de86a661 100755
--- a/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py
+++ b/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py
@@ -35,16 +35,20 @@ def migrate_datastore():
try:
for kvp in key_value_items:
- kvp_id = getattr(kvp, 'id', None)
- secret = getattr(kvp, 'secret', False)
- scope = getattr(kvp, 'scope', SYSTEM_SCOPE)
- new_kvp_db = KeyValuePairDB(id=kvp_id, name=kvp.name,
- expire_timestamp=kvp.expire_timestamp,
- value=kvp.value, secret=secret,
- scope=scope)
+ kvp_id = getattr(kvp, "id", None)
+ secret = getattr(kvp, "secret", False)
+ scope = getattr(kvp, "scope", SYSTEM_SCOPE)
+ new_kvp_db = KeyValuePairDB(
+ id=kvp_id,
+ name=kvp.name,
+ expire_timestamp=kvp.expire_timestamp,
+ value=kvp.value,
+ secret=secret,
+ scope=scope,
+ )
KeyValuePair.add_or_update(new_kvp_db)
except:
- print('ERROR: Failed migrating datastore item with name: %s' % kvp.name)
+ print("ERROR: Failed migrating datastore item with name: %s" % kvp.name)
tb.print_exc()
raise
@@ -58,10 +62,10 @@ def main():
# Migrate rules.
try:
migrate_datastore()
- print('SUCCESS: Datastore items migrated successfully.')
+ print("SUCCESS: Datastore items migrated successfully.")
exit_code = 0
except:
- print('ABORTED: Datastore migration aborted on first failure.')
+ print("ABORTED: Datastore migration aborted on first failure.")
exit_code = 1
# Disconnect from db.
@@ -69,5 +73,5 @@ def main():
sys.exit(exit_code)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
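
The migration's core trick is the defaulting via getattr(): fields that
pre-1.5 documents never stored come back as sensible defaults instead of
raising AttributeError. A tiny sketch with an illustrative stand-in object:

    class LegacyItem(object):
        # stand-in for a pre-migration document that lacks the new fields
        name = "lang"

    item = LegacyItem()
    print(getattr(item, "secret", False))  # -> False (field absent on old docs)
    print(getattr(item, "id", None))       # -> None
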
diff --git a/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py b/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py
index 24275f80dc..a1a500ad96 100755
--- a/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py
+++ b/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py
@@ -32,9 +32,9 @@ def migrate_datastore():
try:
for kvp in key_value_items:
- kvp_id = getattr(kvp, 'id', None)
- secret = getattr(kvp, 'secret', False)
- scope = getattr(kvp, 'scope', SYSTEM_SCOPE)
+ kvp_id = getattr(kvp, "id", None)
+ secret = getattr(kvp, "secret", False)
+ scope = getattr(kvp, "scope", SYSTEM_SCOPE)
if scope == USER_SCOPE:
scope = FULL_USER_SCOPE
@@ -42,13 +42,17 @@ def migrate_datastore():
if scope == SYSTEM_SCOPE:
scope = FULL_SYSTEM_SCOPE
- new_kvp_db = KeyValuePairDB(id=kvp_id, name=kvp.name,
- expire_timestamp=kvp.expire_timestamp,
- value=kvp.value, secret=secret,
- scope=scope)
+ new_kvp_db = KeyValuePairDB(
+ id=kvp_id,
+ name=kvp.name,
+ expire_timestamp=kvp.expire_timestamp,
+ value=kvp.value,
+ secret=secret,
+ scope=scope,
+ )
KeyValuePair.add_or_update(new_kvp_db)
except:
- print('ERROR: Failed migrating datastore item with name: %s' % kvp.name)
+ print("ERROR: Failed migrating datastore item with name: %s" % kvp.name)
tb.print_exc()
raise
@@ -62,10 +66,10 @@ def main():
# Migrate rules.
try:
migrate_datastore()
- print('SUCCESS: Datastore items migrated successfully.')
+ print("SUCCESS: Datastore items migrated successfully.")
exit_code = 0
except:
- print('ABORTED: Datastore migration aborted on first failure.')
+ print("ABORTED: Datastore migration aborted on first failure.")
exit_code = 1
# Disconnect from db.
@@ -73,5 +77,5 @@ def main():
sys.exit(exit_code)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py b/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py
index bb4ee666b9..9d09789413 100755
--- a/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py
+++ b/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py
@@ -39,12 +39,14 @@ def main():
try:
handler = scheduler_handler.get_handler()
handler._cleanup_policy_delayed()
- LOG.info('SUCCESS: Completed clean up of executions with deprecated policy-delayed status.')
+ LOG.info(
+ "SUCCESS: Completed clean up of executions with deprecated policy-delayed status."
+ )
exit_code = 0
except Exception as e:
LOG.error(
- 'ABORTED: Clean up of executions with deprecated policy-delayed status aborted on '
- 'first failure. %s' % e.message
+ "ABORTED: Clean up of executions with deprecated policy-delayed status aborted on "
+ "first failure. %s" % e.message
)
exit_code = 1
@@ -53,5 +55,5 @@ def main():
sys.exit(exit_code)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/st2common/bin/paramiko_ssh_evenlets_tester.py b/st2common/bin/paramiko_ssh_evenlets_tester.py
index 49a42545f8..af30196de1 100755
--- a/st2common/bin/paramiko_ssh_evenlets_tester.py
+++ b/st2common/bin/paramiko_ssh_evenlets_tester.py
@@ -18,6 +18,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import argparse
@@ -34,49 +35,54 @@ def main(user, pkey, password, hosts_str, cmd, file_path, dir_path, delete_dir):
if file_path:
if not os.path.exists(file_path):
- raise Exception('File not found.')
- results = client.put(file_path, '/home/lakshmi/test_file', mode="0660")
- pp.pprint('Copy results: \n%s' % results)
- results = client.run('ls -rlth')
- pp.pprint('ls results: \n%s' % results)
+ raise Exception("File not found.")
+ results = client.put(file_path, "/home/lakshmi/test_file", mode="0660")
+ pp.pprint("Copy results: \n%s" % results)
+ results = client.run("ls -rlth")
+ pp.pprint("ls results: \n%s" % results)
if dir_path:
if not os.path.exists(dir_path):
- raise Exception('File not found.')
- results = client.put(dir_path, '/home/lakshmi/', mode="0660")
- pp.pprint('Copy results: \n%s' % results)
- results = client.run('ls -rlth')
- pp.pprint('ls results: \n%s' % results)
+ raise Exception("File not found.")
+ results = client.put(dir_path, "/home/lakshmi/", mode="0660")
+ pp.pprint("Copy results: \n%s" % results)
+ results = client.run("ls -rlth")
+ pp.pprint("ls results: \n%s" % results)
if cmd:
results = client.run(cmd)
- pp.pprint('cmd results: \n%s' % results)
+ pp.pprint("cmd results: \n%s" % results)
if delete_dir:
results = client.delete_dir(delete_dir, force=True)
- pp.pprint('Delete results: \n%s' % results)
+ pp.pprint("Delete results: \n%s" % results)
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='Parallel SSH tester.')
- parser.add_argument('--hosts', required=True,
- help='List of hosts to connect to')
- parser.add_argument('--private-key', required=False,
- help='Private key to use.')
- parser.add_argument('--password', required=False,
- help='Password.')
- parser.add_argument('--user', required=True,
- help='SSH user name.')
- parser.add_argument('--cmd', required=False,
- help='Command to run on host.')
- parser.add_argument('--file', required=False,
- help='Path of file to copy to remote host.')
- parser.add_argument('--dir', required=False,
- help='Path of dir to copy to remote host.')
- parser.add_argument('--delete-dir', required=False,
- help='Path of dir to delete on remote host.')
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Parallel SSH tester.")
+ parser.add_argument("--hosts", required=True, help="List of hosts to connect to")
+ parser.add_argument("--private-key", required=False, help="Private key to use.")
+ parser.add_argument("--password", required=False, help="Password.")
+ parser.add_argument("--user", required=True, help="SSH user name.")
+ parser.add_argument("--cmd", required=False, help="Command to run on host.")
+ parser.add_argument(
+ "--file", required=False, help="Path of file to copy to remote host."
+ )
+ parser.add_argument(
+ "--dir", required=False, help="Path of dir to copy to remote host."
+ )
+ parser.add_argument(
+ "--delete-dir", required=False, help="Path of dir to delete on remote host."
+ )
args = parser.parse_args()
- main(user=args.user, pkey=args.private_key, password=args.password,
- hosts_str=args.hosts, cmd=args.cmd,
- file_path=args.file, dir_path=args.dir, delete_dir=args.delete_dir)
+ main(
+ user=args.user,
+ pkey=args.private_key,
+ password=args.password,
+ hosts_str=args.hosts,
+ cmd=args.cmd,
+ file_path=args.file,
+ dir_path=args.dir,
+ delete_dir=args.delete_dir,
+ )
diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests
index bed2826760..9f7c2306ab 100755
--- a/st2common/bin/st2-run-pack-tests
+++ b/st2common/bin/st2-run-pack-tests
@@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then
# Base options to enable test coverage reporting
# --with-coverage : enables coverage reporting
- # --cover-erase : removes old coverage reports before starting
+ # --cover-erase : removes old coverage reports before starting
NOSE_OPTS+=(--with-coverage --cover-erase)
# Now, by default nosetests reports test coverage for every module found
diff --git a/st2common/dist_utils.py b/st2common/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2common/dist_utils.py
+++ b/st2common/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
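
The _get_link helper above recognizes VCS-style requirement lines by their
prefix and pulls the distribution name out of the "#egg=" fragment. A quick
illustration of the fallback regex on a hypothetical line:

    import re

    line = "-e git+https://github.com/example/somepack.git#egg=somepack"
    print(re.findall(".*#egg=(.+?)$", line))  # -> ['somepack']
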
diff --git a/st2common/setup.py b/st2common/setup.py
index f68679af8c..908884260d 100644
--- a/st2common/setup.py
+++ b/st2common/setup.py
@@ -23,10 +23,10 @@
from dist_utils import apply_vagrant_workaround
from dist_utils import get_version_string
-ST2_COMPONENT = 'st2common'
+ST2_COMPONENT = "st2common"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
-INIT_FILE = os.path.join(BASE_DIR, 'st2common/__init__.py')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
+INIT_FILE = os.path.join(BASE_DIR, "st2common/__init__.py")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
@@ -34,41 +34,43 @@
setup(
name=ST2_COMPONENT,
version=get_version_string(INIT_FILE),
- description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="{} StackStorm event-driven automation platform component".format(
+ ST2_COMPONENT
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
+ packages=find_packages(exclude=["setuptools", "tests"]),
scripts=[
- 'bin/st2-bootstrap-rmq',
- 'bin/st2-cleanup-db',
- 'bin/st2-register-content',
- 'bin/st2-purge-executions',
- 'bin/st2-purge-trigger-instances',
- 'bin/st2-run-pack-tests',
- 'bin/st2ctl',
- 'bin/st2-generate-symmetric-crypto-key',
- 'bin/st2-self-check',
- 'bin/st2-track-result',
- 'bin/st2-validate-pack-config',
- 'bin/st2-pack-install',
- 'bin/st2-pack-download',
- 'bin/st2-pack-setup-virtualenv'
+ "bin/st2-bootstrap-rmq",
+ "bin/st2-cleanup-db",
+ "bin/st2-register-content",
+ "bin/st2-purge-executions",
+ "bin/st2-purge-trigger-instances",
+ "bin/st2-run-pack-tests",
+ "bin/st2ctl",
+ "bin/st2-generate-symmetric-crypto-key",
+ "bin/st2-self-check",
+ "bin/st2-track-result",
+ "bin/st2-validate-pack-config",
+ "bin/st2-pack-install",
+ "bin/st2-pack-download",
+ "bin/st2-pack-setup-virtualenv",
],
entry_points={
- 'st2common.metrics.driver': [
- 'statsd = st2common.metrics.drivers.statsd_driver:StatsdDriver',
- 'noop = st2common.metrics.drivers.noop_driver:NoopDriver',
- 'echo = st2common.metrics.drivers.echo_driver:EchoDriver'
+ "st2common.metrics.driver": [
+ "statsd = st2common.metrics.drivers.statsd_driver:StatsdDriver",
+ "noop = st2common.metrics.drivers.noop_driver:NoopDriver",
+ "echo = st2common.metrics.drivers.echo_driver:EchoDriver",
],
- 'st2common.rbac.backend': [
- 'noop = st2common.rbac.backends.noop:NoOpRBACBackend'
+ "st2common.rbac.backend": [
+ "noop = st2common.rbac.backends.noop:NoOpRBACBackend"
],
- }
+ },
)
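
The entry points declared above are how pluggable drivers get discovered at
runtime (StackStorm loads them through a plugin manager; the pkg_resources
view below is only the low-level equivalent). A sketch assuming st2common is
installed in the current environment:

    import pkg_resources

    # enumerate the metrics drivers registered under the entry point group above
    for ep in pkg_resources.iter_entry_points("st2common.metrics.driver"):
        print(ep.name, "->", ep.module_name)
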
diff --git a/st2common/st2common/__init__.py b/st2common/st2common/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/st2common/st2common/__init__.py
+++ b/st2common/st2common/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/st2common/st2common/bootstrap/actionsregistrar.py b/st2common/st2common/bootstrap/actionsregistrar.py
index c21788fb14..f5265bac48 100644
--- a/st2common/st2common/bootstrap/actionsregistrar.py
+++ b/st2common/st2common/bootstrap/actionsregistrar.py
@@ -30,10 +30,7 @@
import st2common.util.action_db as action_utils
import st2common.validators.api.action as action_validator
-__all__ = [
- 'ActionsRegistrar',
- 'register_actions'
-]
+__all__ = ["ActionsRegistrar", "register_actions"]
LOG = logging.getLogger(__name__)
@@ -53,15 +50,18 @@ def register_from_packs(self, base_dirs):
self.register_packs(base_dirs=base_dirs)
registered_count = 0
- content = self._pack_loader.get_content(base_dirs=base_dirs,
- content_type='actions')
+ content = self._pack_loader.get_content(
+ base_dirs=base_dirs, content_type="actions"
+ )
for pack, actions_dir in six.iteritems(content):
if not actions_dir:
- LOG.debug('Pack %s does not contain actions.', pack)
+ LOG.debug("Pack %s does not contain actions.", pack)
continue
try:
- LOG.debug('Registering actions from pack %s:, dir: %s', pack, actions_dir)
+ LOG.debug(
+ "Registering actions from pack %s:, dir: %s", pack, actions_dir
+ )
actions = self._get_actions_from_pack(actions_dir)
count = self._register_actions_from_pack(pack=pack, actions=actions)
registered_count += count
@@ -69,7 +69,9 @@ def register_from_packs(self, base_dirs):
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all actions from pack: %s', actions_dir)
+ LOG.exception(
+ "Failed registering all actions from pack: %s", actions_dir
+ )
return registered_count
@@ -80,10 +82,11 @@ def register_from_pack(self, pack_dir):
:return: Number of actions registered.
:rtype: ``int``
"""
- pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir
+ pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir
_, pack = os.path.split(pack_dir)
- actions_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir,
- content_type='actions')
+ actions_dir = self._pack_loader.get_content_from_pack(
+ pack_dir=pack_dir, content_type="actions"
+ )
# Register pack first
self.register_pack(pack_name=pack, pack_dir=pack_dir)
@@ -92,16 +95,18 @@ def register_from_pack(self, pack_dir):
if not actions_dir:
return registered_count
- LOG.debug('Registering actions from pack %s:, dir: %s', pack, actions_dir)
+ LOG.debug("Registering actions from pack %s:, dir: %s", pack, actions_dir)
try:
actions = self._get_actions_from_pack(actions_dir=actions_dir)
- registered_count = self._register_actions_from_pack(pack=pack, actions=actions)
+ registered_count = self._register_actions_from_pack(
+ pack=pack, actions=actions
+ )
except Exception as e:
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all actions from pack: %s', actions_dir)
+ LOG.exception("Failed registering all actions from pack: %s", actions_dir)
return registered_count
@@ -109,29 +114,33 @@ def _get_actions_from_pack(self, actions_dir):
actions = self.get_resources_from_pack(resources_dir=actions_dir)
# Exclude global actions configuration file
- config_files = ['actions/config' + ext for ext in self.ALLOWED_EXTENSIONS]
+ config_files = ["actions/config" + ext for ext in self.ALLOWED_EXTENSIONS]
for config_file in config_files:
- actions = [file_path for file_path in actions if config_file not in file_path]
+ actions = [
+ file_path for file_path in actions if config_file not in file_path
+ ]
return actions
def _register_action(self, pack, action):
content = self._meta_loader.load(action)
- pack_field = content.get('pack', None)
+ pack_field = content.get("pack", None)
if not pack_field:
- content['pack'] = pack
+ content["pack"] = pack
pack_field = pack
if pack_field != pack:
- raise Exception('Model is in pack "%s" but field "pack" is different: %s' %
- (pack, pack_field))
+ raise Exception(
+ 'Model is in pack "%s" but field "pack" is different: %s'
+ % (pack, pack_field)
+ )
# Add in "metadata_file" attribute which stores path to the pack metadata file relative to
# the pack directory
- metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack,
- file_path=action,
- use_pack_cache=True)
- content['metadata_file'] = metadata_file
+ metadata_file = content_utils.get_relative_path_to_pack_file(
+ pack_ref=pack, file_path=action, use_pack_cache=True
+ )
+ content["metadata_file"] = metadata_file
action_api = ActionAPI(**content)
@@ -141,25 +150,29 @@ def _register_action(self, pack, action):
# We throw a more user-friendly exception on invalid parameter name
msg = six.text_type(e)
- is_invalid_parameter_name = 'does not match any of the regexes: ' in msg
+ is_invalid_parameter_name = "does not match any of the regexes: " in msg
if is_invalid_parameter_name:
- match = re.search('\'(.+?)\' does not match any of the regexes', msg)
+ match = re.search("'(.+?)' does not match any of the regexes", msg)
if match:
parameter_name = match.groups()[0]
else:
- parameter_name = 'unknown'
+ parameter_name = "unknown"
- new_msg = ('Parameter name "%s" is invalid. Valid characters for parameter name '
- 'are [a-zA-Z0-0_].' % (parameter_name))
- new_msg += '\n\n' + msg
+ new_msg = (
+ 'Parameter name "%s" is invalid. Valid characters for parameter name '
+ "are [a-zA-Z0-0_]." % (parameter_name)
+ )
+ new_msg += "\n\n" + msg
raise jsonschema.ValidationError(new_msg)
raise e
# Use in-memory cached RunnerTypeDB objects to reduce load on the database
if self._use_runners_cache:
- runner_type_db = self._runner_type_db_cache.get(action_api.runner_type, None)
+ runner_type_db = self._runner_type_db_cache.get(
+ action_api.runner_type, None
+ )
if not runner_type_db:
runner_type_db = action_validator.get_runner_model(action_api)
@@ -170,36 +183,47 @@ def _register_action(self, pack, action):
action_validator.validate_action(action_api, runner_type_db=runner_type_db)
model = ActionAPI.to_model(action_api)
- action_ref = ResourceReference.to_string_reference(pack=pack, name=str(content['name']))
+ action_ref = ResourceReference.to_string_reference(
+ pack=pack, name=str(content["name"])
+ )
existing = action_utils.get_action_by_ref(action_ref)
if not existing:
- LOG.debug('Action %s not found. Creating new one with: %s', action_ref, content)
+ LOG.debug(
+ "Action %s not found. Creating new one with: %s", action_ref, content
+ )
else:
- LOG.debug('Action %s found. Will be updated from: %s to: %s',
- action_ref, existing, model)
+ LOG.debug(
+ "Action %s found. Will be updated from: %s to: %s",
+ action_ref,
+ existing,
+ model,
+ )
model.id = existing.id
try:
model = Action.add_or_update(model)
- extra = {'action_db': model}
- LOG.audit('Action updated. Action %s from %s.', model, action, extra=extra)
+ extra = {"action_db": model}
+ LOG.audit("Action updated. Action %s from %s.", model, action, extra=extra)
except Exception:
- LOG.exception('Failed to write action to db %s.', model.name)
+ LOG.exception("Failed to write action to db %s.", model.name)
raise
def _register_actions_from_pack(self, pack, actions):
registered_count = 0
for action in actions:
try:
- LOG.debug('Loading action from %s.', action)
+ LOG.debug("Loading action from %s.", action)
self._register_action(pack=pack, action=action)
except Exception as e:
if self._fail_on_failure:
- msg = ('Failed to register action "%s" from pack "%s": %s' % (action, pack,
- six.text_type(e)))
+ msg = 'Failed to register action "%s" from pack "%s": %s' % (
+ action,
+ pack,
+ six.text_type(e),
+ )
raise ValueError(msg)
- LOG.exception('Unable to register action: %s', action)
+ LOG.exception("Unable to register action: %s", action)
continue
else:
registered_count += 1
@@ -207,16 +231,18 @@ def _register_actions_from_pack(self, pack, actions):
return registered_count
-def register_actions(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
- fail_on_failure=False):
+def register_actions(
+ packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False
+):
if packs_base_paths:
assert isinstance(packs_base_paths, list)
if not packs_base_paths:
packs_base_paths = content_utils.get_packs_base_paths()
- registrar = ActionsRegistrar(use_pack_cache=use_pack_cache,
- fail_on_failure=fail_on_failure)
+ registrar = ActionsRegistrar(
+ use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure
+ )
if pack_dir:
result = registrar.register_from_pack(pack_dir=pack_dir)
diff --git a/st2common/st2common/bootstrap/aliasesregistrar.py b/st2common/st2common/bootstrap/aliasesregistrar.py
index dbc9c3b0fc..c9d4ef7017 100644
--- a/st2common/st2common/bootstrap/aliasesregistrar.py
+++ b/st2common/st2common/bootstrap/aliasesregistrar.py
@@ -27,10 +27,7 @@
from st2common.persistence.actionalias import ActionAlias
from st2common.exceptions.db import StackStormDBObjectNotFoundError
-__all__ = [
- 'AliasesRegistrar',
- 'register_aliases'
-]
+__all__ = ["AliasesRegistrar", "register_aliases"]
LOG = logging.getLogger(__name__)
@@ -50,15 +47,18 @@ def register_from_packs(self, base_dirs):
self.register_packs(base_dirs=base_dirs)
registered_count = 0
- content = self._pack_loader.get_content(base_dirs=base_dirs,
- content_type='aliases')
+ content = self._pack_loader.get_content(
+ base_dirs=base_dirs, content_type="aliases"
+ )
for pack, aliases_dir in six.iteritems(content):
if not aliases_dir:
- LOG.debug('Pack %s does not contain aliases.', pack)
+ LOG.debug("Pack %s does not contain aliases.", pack)
continue
try:
- LOG.debug('Registering aliases from pack %s:, dir: %s', pack, aliases_dir)
+ LOG.debug(
+ "Registering aliases from pack %s:, dir: %s", pack, aliases_dir
+ )
aliases = self._get_aliases_from_pack(aliases_dir)
count = self._register_aliases_from_pack(pack=pack, aliases=aliases)
registered_count += count
@@ -66,7 +66,9 @@ def register_from_packs(self, base_dirs):
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all aliases from pack: %s', aliases_dir)
+ LOG.exception(
+ "Failed registering all aliases from pack: %s", aliases_dir
+ )
return registered_count
@@ -77,10 +79,11 @@ def register_from_pack(self, pack_dir):
:return: Number of aliases registered.
:rtype: ``int``
"""
- pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir
+ pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir
_, pack = os.path.split(pack_dir)
- aliases_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir,
- content_type='aliases')
+ aliases_dir = self._pack_loader.get_content_from_pack(
+ pack_dir=pack_dir, content_type="aliases"
+ )
# Register pack first
self.register_pack(pack_name=pack, pack_dir=pack_dir)
@@ -89,16 +92,18 @@ def register_from_pack(self, pack_dir):
if not aliases_dir:
return registered_count
- LOG.debug('Registering aliases from pack %s:, dir: %s', pack, aliases_dir)
+ LOG.debug("Registering aliases from pack %s:, dir: %s", pack, aliases_dir)
try:
aliases = self._get_aliases_from_pack(aliases_dir=aliases_dir)
- registered_count = self._register_aliases_from_pack(pack=pack, aliases=aliases)
+ registered_count = self._register_aliases_from_pack(
+ pack=pack, aliases=aliases
+ )
except Exception as e:
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all aliases from pack: %s', aliases_dir)
+ LOG.exception("Failed registering all aliases from pack: %s", aliases_dir)
return registered_count
return registered_count
@@ -106,7 +111,9 @@ def register_from_pack(self, pack_dir):
def _get_aliases_from_pack(self, aliases_dir):
return self.get_resources_from_pack(resources_dir=aliases_dir)
- def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=False):
+ def _get_action_alias_db(
+ self, pack, action_alias, ignore_metadata_file_error=False
+ ):
"""
Retrieve ActionAliasDB object.
@@ -115,25 +122,27 @@ def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=Fa
:type ignore_metadata_file_error: ``bool``
"""
content = self._meta_loader.load(action_alias)
- pack_field = content.get('pack', None)
+ pack_field = content.get("pack", None)
if not pack_field:
- content['pack'] = pack
+ content["pack"] = pack
pack_field = pack
if pack_field != pack:
- raise Exception('Model is in pack "%s" but field "pack" is different: %s' %
- (pack, pack_field))
+ raise Exception(
+ 'Model is in pack "%s" but field "pack" is different: %s'
+ % (pack, pack_field)
+ )
# Add in "metadata_file" attribute which stores path to the pack metadata file relative to
# the pack directory
try:
- metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack,
- file_path=action_alias,
- use_pack_cache=True)
+ metadata_file = content_utils.get_relative_path_to_pack_file(
+ pack_ref=pack, file_path=action_alias, use_pack_cache=True
+ )
except ValueError as e:
if not ignore_metadata_file_error:
raise e
else:
- content['metadata_file'] = metadata_file
+ content["metadata_file"] = metadata_file
action_alias_api = ActionAliasAPI(**content)
action_alias_api.validate()
@@ -142,28 +151,35 @@ def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=Fa
return action_alias_db
def _register_action_alias(self, pack, action_alias):
- action_alias_db = self._get_action_alias_db(pack=pack,
- action_alias=action_alias)
+ action_alias_db = self._get_action_alias_db(
+ pack=pack, action_alias=action_alias
+ )
try:
action_alias_db.id = ActionAlias.get_by_name(action_alias_db.name).id
except StackStormDBObjectNotFoundError:
- LOG.debug('ActionAlias %s not found. Creating new one.', action_alias)
+ LOG.debug("ActionAlias %s not found. Creating new one.", action_alias)
action_ref = action_alias_db.action_ref
action_db = Action.get_by_ref(action_ref)
if not action_db:
- LOG.warning('Action %s not found in DB. Did you forget to register the action?',
- action_ref)
+ LOG.warning(
+ "Action %s not found in DB. Did you forget to register the action?",
+ action_ref,
+ )
try:
action_alias_db = ActionAlias.add_or_update(action_alias_db)
- extra = {'action_alias_db': action_alias_db}
- LOG.audit('Action alias updated. Action alias %s from %s.', action_alias_db,
- action_alias, extra=extra)
+ extra = {"action_alias_db": action_alias_db}
+ LOG.audit(
+ "Action alias updated. Action alias %s from %s.",
+ action_alias_db,
+ action_alias,
+ extra=extra,
+ )
except Exception:
- LOG.exception('Failed to create action alias %s.', action_alias_db.name)
+ LOG.exception("Failed to create action alias %s.", action_alias_db.name)
raise
def _register_aliases_from_pack(self, pack, aliases):
@@ -171,15 +187,18 @@ def _register_aliases_from_pack(self, pack, aliases):
for alias in aliases:
try:
- LOG.debug('Loading alias from %s.', alias)
+ LOG.debug("Loading alias from %s.", alias)
self._register_action_alias(pack, alias)
except Exception as e:
if self._fail_on_failure:
- msg = ('Failed to register alias "%s" from pack "%s": %s' % (alias, pack,
- six.text_type(e)))
+ msg = 'Failed to register alias "%s" from pack "%s": %s' % (
+ alias,
+ pack,
+ six.text_type(e),
+ )
raise ValueError(msg)
- LOG.exception('Unable to register alias: %s', alias)
+ LOG.exception("Unable to register alias: %s", alias)
continue
else:
registered_count += 1
@@ -187,8 +206,9 @@ def _register_aliases_from_pack(self, pack, aliases):
return registered_count
-def register_aliases(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
- fail_on_failure=False):
+def register_aliases(
+ packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False
+):
if packs_base_paths:
assert isinstance(packs_base_paths, list)
@@ -196,8 +216,9 @@ def register_aliases(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
if not packs_base_paths:
packs_base_paths = content_utils.get_packs_base_paths()
- registrar = AliasesRegistrar(use_pack_cache=use_pack_cache,
- fail_on_failure=fail_on_failure)
+ registrar = AliasesRegistrar(
+ use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure
+ )
if pack_dir:
result = registrar.register_from_pack(pack_dir=pack_dir)
diff --git a/st2common/st2common/bootstrap/base.py b/st2common/st2common/bootstrap/base.py
index 1757a2fa8e..1070a3af38 100644
--- a/st2common/st2common/bootstrap/base.py
+++ b/st2common/st2common/bootstrap/base.py
@@ -32,9 +32,7 @@
from st2common.util.pack import get_pack_ref_from_metadata
from st2common.exceptions.db import StackStormDBObjectNotFoundError
-__all__ = [
- 'ResourceRegistrar'
-]
+__all__ = ["ResourceRegistrar"]
LOG = logging.getLogger(__name__)
@@ -44,16 +42,15 @@
# a long running process.
REGISTERED_PACKS_CACHE = {}
-EXCLUDE_FILE_PATTERNS = [
- '*.pyc',
- '.git/*'
-]
+EXCLUDE_FILE_PATTERNS = ["*.pyc", ".git/*"]
class ResourceRegistrar(object):
ALLOWED_EXTENSIONS = []
- def __init__(self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False):
+ def __init__(
+ self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False
+ ):
"""
:param use_pack_cache: True to cache which packs have been registered in memory and make
sure packs are only registered once.
@@ -81,10 +78,10 @@ def get_resources_from_pack(self, resources_dir):
for ext in self.ALLOWED_EXTENSIONS:
resources_glob = resources_dir
- if resources_dir.endswith('/'):
+ if resources_dir.endswith("/"):
resources_glob = resources_dir + ext
else:
- resources_glob = resources_dir + '/*' + ext
+ resources_glob = resources_dir + "/*" + ext
resource_files = glob.glob(resources_glob)
resources.extend(resource_files)
@@ -121,7 +118,7 @@ def register_pack(self, pack_name, pack_dir):
# This pack has already been registered during this register content run
return
- LOG.debug('Registering pack: %s' % (pack_name))
+ LOG.debug("Registering pack: %s" % (pack_name))
REGISTERED_PACKS_CACHE[pack_name] = True
try:
@@ -148,19 +145,26 @@ def _register_pack(self, pack_name, pack_dir):
# Display a warning if pack contains deprecated config.yaml file. Support for those files
# was removed in v2.4.0.
- config_path = os.path.join(pack_dir, 'config.yaml')
+ config_path = os.path.join(pack_dir, "config.yaml")
if os.path.isfile(config_path):
- LOG.error('Pack "%s" contains a deprecated config.yaml file (%s). '
- 'Support for "config.yaml" files has been deprecated in StackStorm v1.6.0 '
- 'in favor of config.schema.yaml config schema files and config files in '
- '/opt/stackstorm/configs/ directory. Support for config.yaml files has '
- 'been removed in the release (v2.4.0) so please migrate. For more '
- 'information please refer to %s ' % (pack_db.name, config_path,
- 'https://docs.stackstorm.com/reference/pack_configs.html'))
+ LOG.error(
+ 'Pack "%s" contains a deprecated config.yaml file (%s). '
+ 'Support for "config.yaml" files has been deprecated in StackStorm v1.6.0 '
+ "in favor of config.schema.yaml config schema files and config files in "
+ "/opt/stackstorm/configs/ directory. Support for config.yaml files has "
+ "been removed in the release (v2.4.0) so please migrate. For more "
+ "information please refer to %s "
+ % (
+ pack_db.name,
+ config_path,
+ "https://docs.stackstorm.com/reference/pack_configs.html",
+ )
+ )
# 2. Register corresponding pack config schema
- config_schema_db = self._register_pack_config_schema_db(pack_name=pack_name,
- pack_dir=pack_dir)
+ config_schema_db = self._register_pack_config_schema_db(
+ pack_name=pack_name, pack_dir=pack_dir
+ )
return pack_db, config_schema_db
@@ -173,25 +177,28 @@ def _register_pack_db(self, pack_name, pack_dir):
# which are in sub-directories)
# 2. If attribute is not available, but pack name is and pack name meets the valid name
# criteria, we use that
- content['ref'] = get_pack_ref_from_metadata(metadata=content,
- pack_directory_name=pack_name)
+ content["ref"] = get_pack_ref_from_metadata(
+ metadata=content, pack_directory_name=pack_name
+ )
# Include a list of pack files
- pack_file_list = get_file_list(directory=pack_dir, exclude_patterns=EXCLUDE_FILE_PATTERNS)
- content['files'] = pack_file_list
- content['path'] = pack_dir
+ pack_file_list = get_file_list(
+ directory=pack_dir, exclude_patterns=EXCLUDE_FILE_PATTERNS
+ )
+ content["files"] = pack_file_list
+ content["path"] = pack_dir
pack_api = PackAPI(**content)
pack_api.validate()
pack_db = PackAPI.to_model(pack_api)
try:
- pack_db.id = Pack.get_by_ref(content['ref']).id
+ pack_db.id = Pack.get_by_ref(content["ref"]).id
except StackStormDBObjectNotFoundError:
- LOG.debug('Pack %s not found. Creating new one.', pack_name)
+ LOG.debug("Pack %s not found. Creating new one.", pack_name)
pack_db = Pack.add_or_update(pack_db)
- LOG.debug('Pack %s registered.' % (pack_name))
+ LOG.debug("Pack %s registered." % (pack_name))
return pack_db
def _register_pack_config_schema_db(self, pack_name, pack_dir):
@@ -204,11 +211,13 @@ def _register_pack_config_schema_db(self, pack_name, pack_dir):
values = self._meta_loader.load(config_schema_path)
if not values:
- raise ValueError('Config schema "%s" is empty and invalid.' % (config_schema_path))
+ raise ValueError(
+ 'Config schema "%s" is empty and invalid.' % (config_schema_path)
+ )
content = {}
- content['pack'] = pack_name
- content['attributes'] = values
+ content["pack"] = pack_name
+ content["attributes"] = values
config_schema_api = ConfigSchemaAPI(**content)
config_schema_api = config_schema_api.validate()
@@ -217,8 +226,10 @@ def _register_pack_config_schema_db(self, pack_name, pack_dir):
try:
config_schema_db.id = ConfigSchema.get_by_pack(pack_name).id
except StackStormDBObjectNotFoundError:
- LOG.debug('Config schema for pack %s not found. Creating new one.', pack_name)
+ LOG.debug(
+ "Config schema for pack %s not found. Creating new one.", pack_name
+ )
config_schema_db = ConfigSchema.add_or_update(config_schema_db)
- LOG.debug('Config schema for pack %s registered.' % (pack_name))
+ LOG.debug("Config schema for pack %s registered." % (pack_name))
return config_schema_db
diff --git a/st2common/st2common/bootstrap/configsregistrar.py b/st2common/st2common/bootstrap/configsregistrar.py
index fc7e05eb98..3cbc5283fc 100644
--- a/st2common/st2common/bootstrap/configsregistrar.py
+++ b/st2common/st2common/bootstrap/configsregistrar.py
@@ -28,9 +28,7 @@
from st2common.persistence.pack import Config
from st2common.exceptions.db import StackStormDBObjectNotFoundError
-__all__ = [
- 'ConfigsRegistrar'
-]
+__all__ = ["ConfigsRegistrar"]
LOG = logging.getLogger(__name__)
@@ -44,11 +42,18 @@ class ConfigsRegistrar(ResourceRegistrar):
ALLOWED_EXTENSIONS = ALLOWED_EXTS
- def __init__(self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False,
- validate_configs=True):
- super(ConfigsRegistrar, self).__init__(use_pack_cache=use_pack_cache,
- use_runners_cache=use_runners_cache,
- fail_on_failure=fail_on_failure)
+ def __init__(
+ self,
+ use_pack_cache=True,
+ use_runners_cache=False,
+ fail_on_failure=False,
+ validate_configs=True,
+ ):
+ super(ConfigsRegistrar, self).__init__(
+ use_pack_cache=use_pack_cache,
+ use_runners_cache=use_runners_cache,
+ fail_on_failure=fail_on_failure,
+ )
self._validate_configs = validate_configs
@@ -68,21 +73,29 @@ def register_from_packs(self, base_dirs):
if not os.path.isfile(config_path):
# Config for that pack doesn't exist
- LOG.debug('No config found for pack "%s" (file "%s" is not present).', pack_name,
- config_path)
+ LOG.debug(
+ 'No config found for pack "%s" (file "%s" is not present).',
+ pack_name,
+ config_path,
+ )
continue
try:
self._register_config_for_pack(pack=pack_name, config_path=config_path)
except Exception as e:
if self._fail_on_failure:
- msg = ('Failed to register config "%s" for pack "%s": %s' % (config_path,
- pack_name,
- six.text_type(e)))
+ msg = 'Failed to register config "%s" for pack "%s": %s' % (
+ config_path,
+ pack_name,
+ six.text_type(e),
+ )
raise ValueError(msg)
- LOG.exception('Failed to register config for pack "%s": %s', pack_name,
- six.text_type(e))
+ LOG.exception(
+ 'Failed to register config for pack "%s": %s',
+ pack_name,
+ six.text_type(e),
+ )
else:
registered_count += 1
@@ -92,7 +105,7 @@ def register_from_pack(self, pack_dir):
"""
Register config for a provided pack.
"""
- pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir
+ pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir
_, pack_name = os.path.split(pack_dir)
# Register pack first
@@ -106,8 +119,8 @@ def register_from_pack(self, pack_dir):
return 1
def _get_config_path_for_pack(self, pack_name):
- configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/')
- config_path = os.path.join(configs_path, '%s.yaml' % (pack_name))
+ configs_path = os.path.join(cfg.CONF.system.base_path, "configs/")
+ config_path = os.path.join(configs_path, "%s.yaml" % (pack_name))
return config_path
@@ -115,8 +128,8 @@ def _register_config_for_pack(self, pack, config_path):
content = {}
values = self._meta_loader.load(config_path)
- content['pack'] = pack
- content['values'] = values
+ content["pack"] = pack
+ content["values"] = values
config_api = ConfigAPI(**content)
config_api.validate(validate_against_schema=self._validate_configs)
@@ -136,17 +149,22 @@ def save_model(config_api):
try:
config_db = Config.add_or_update(config_db)
- extra = {'config_db': config_db}
+ extra = {"config_db": config_db}
LOG.audit('Config for pack "%s" is updated.', config_db.pack, extra=extra)
except Exception:
- LOG.exception('Failed to save config for pack %s.', pack)
+ LOG.exception("Failed to save config for pack %s.", pack)
raise
return config_db
-def register_configs(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
- fail_on_failure=False, validate_configs=True):
+def register_configs(
+ packs_base_paths=None,
+ pack_dir=None,
+ use_pack_cache=True,
+ fail_on_failure=False,
+ validate_configs=True,
+):
if packs_base_paths:
assert isinstance(packs_base_paths, list)
@@ -154,9 +172,11 @@ def register_configs(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
if not packs_base_paths:
packs_base_paths = content_utils.get_packs_base_paths()
- registrar = ConfigsRegistrar(use_pack_cache=use_pack_cache,
- fail_on_failure=fail_on_failure,
- validate_configs=validate_configs)
+ registrar = ConfigsRegistrar(
+ use_pack_cache=use_pack_cache,
+ fail_on_failure=fail_on_failure,
+ validate_configs=validate_configs,
+ )
if pack_dir:
result = registrar.register_from_pack(pack_dir=pack_dir)
diff --git a/st2common/st2common/bootstrap/policiesregistrar.py b/st2common/st2common/bootstrap/policiesregistrar.py
index b963eaf097..4f6f247694 100644
--- a/st2common/st2common/bootstrap/policiesregistrar.py
+++ b/st2common/st2common/bootstrap/policiesregistrar.py
@@ -30,11 +30,7 @@
from st2common.util import loader
-__all__ = [
- 'PolicyRegistrar',
- 'register_policy_types',
- 'register_policies'
-]
+__all__ = ["PolicyRegistrar", "register_policy_types", "register_policies"]
LOG = logging.getLogger(__name__)
@@ -55,15 +51,18 @@ def register_from_packs(self, base_dirs):
self.register_packs(base_dirs=base_dirs)
registered_count = 0
- content = self._pack_loader.get_content(base_dirs=base_dirs,
- content_type='policies')
+ content = self._pack_loader.get_content(
+ base_dirs=base_dirs, content_type="policies"
+ )
for pack, policies_dir in six.iteritems(content):
if not policies_dir:
- LOG.debug('Pack %s does not contain policies.', pack)
+ LOG.debug("Pack %s does not contain policies.", pack)
continue
try:
- LOG.debug('Registering policies from pack %s:, dir: %s', pack, policies_dir)
+ LOG.debug(
+ "Registering policies from pack %s:, dir: %s", pack, policies_dir
+ )
policies = self._get_policies_from_pack(policies_dir)
count = self._register_policies_from_pack(pack=pack, policies=policies)
registered_count += count
@@ -71,7 +70,9 @@ def register_from_packs(self, base_dirs):
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all policies from pack: %s', policies_dir)
+ LOG.exception(
+ "Failed registering all policies from pack: %s", policies_dir
+ )
return registered_count
@@ -82,11 +83,12 @@ def register_from_pack(self, pack_dir):
:rtype: ``int``
"""
- pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir
+ pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir
_, pack = os.path.split(pack_dir)
- policies_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir,
- content_type='policies')
+ policies_dir = self._pack_loader.get_content_from_pack(
+ pack_dir=pack_dir, content_type="policies"
+ )
# Register pack first
self.register_pack(pack_name=pack, pack_dir=pack_dir)
@@ -95,16 +97,18 @@ def register_from_pack(self, pack_dir):
if not policies_dir:
return registered_count
- LOG.debug('Registering policies from pack %s, dir: %s', pack, policies_dir)
+ LOG.debug("Registering policies from pack %s, dir: %s", pack, policies_dir)
try:
policies = self._get_policies_from_pack(policies_dir=policies_dir)
- registered_count = self._register_policies_from_pack(pack=pack, policies=policies)
+ registered_count = self._register_policies_from_pack(
+ pack=pack, policies=policies
+ )
except Exception as e:
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all policies from pack: %s', policies_dir)
+ LOG.exception("Failed registering all policies from pack: %s", policies_dir)
return registered_count
return registered_count
@@ -117,15 +121,18 @@ def _register_policies_from_pack(self, pack, policies):
for policy in policies:
try:
- LOG.debug('Loading policy from %s.', policy)
+ LOG.debug("Loading policy from %s.", policy)
self._register_policy(pack=pack, policy=policy)
except Exception as e:
if self._fail_on_failure:
- msg = ('Failed to register policy "%s" from pack "%s": %s' % (policy, pack,
- six.text_type(e)))
+ msg = 'Failed to register policy "%s" from pack "%s": %s' % (
+ policy,
+ pack,
+ six.text_type(e),
+ )
raise ValueError(msg)
- LOG.exception('Unable to register policy: %s', policy)
+ LOG.exception("Unable to register policy: %s", policy)
continue
else:
registered_count += 1
@@ -134,20 +141,22 @@ def _register_policies_from_pack(self, pack, policies):
def _register_policy(self, pack, policy):
content = self._meta_loader.load(policy)
- pack_field = content.get('pack', None)
+ pack_field = content.get("pack", None)
if not pack_field:
- content['pack'] = pack
+ content["pack"] = pack
pack_field = pack
if pack_field != pack:
- raise Exception('Model is in pack "%s" but field "pack" is different: %s' %
- (pack, pack_field))
+ raise Exception(
+ 'Model is in pack "%s" but field "pack" is different: %s'
+ % (pack, pack_field)
+ )
# Add in "metadata_file" attribute which stores path to the pack metadata file relative to
# the pack directory
- metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack,
- file_path=policy,
- use_pack_cache=True)
- content['metadata_file'] = metadata_file
+ metadata_file = content_utils.get_relative_path_to_pack_file(
+ pack_ref=pack, file_path=policy, use_pack_cache=True
+ )
+ content["metadata_file"] = metadata_file
policy_api = PolicyAPI(**content)
policy_api = policy_api.validate()
@@ -160,21 +169,21 @@ def _register_policy(self, pack, policy):
try:
policy_db = Policy.add_or_update(policy_db)
- extra = {'policy_db': policy_db}
+ extra = {"policy_db": policy_db}
LOG.audit('Policy "%s" is updated.', policy_db.ref, extra=extra)
except Exception:
- LOG.exception('Failed to create policy %s.', policy_api.name)
+ LOG.exception("Failed to create policy %s.", policy_api.name)
raise
def register_policy_types(module):
registered_count = 0
mod_path = os.path.dirname(os.path.realpath(sys.modules[module.__name__].__file__))
- path = os.path.join(mod_path, 'policies/meta')
+ path = os.path.join(mod_path, "policies/meta")
files = []
for ext in ALLOWED_EXTS:
- exp = '%s/*%s' % (path, ext)
+ exp = "%s/*%s" % (path, ext)
files += glob.glob(exp)
for f in files:
@@ -189,11 +198,13 @@ def register_policy_types(module):
if existing_entry:
policy_type_db.id = existing_entry.id
except StackStormDBObjectNotFoundError:
- LOG.debug('Policy type "%s" is not found. Creating new entry.',
- policy_type_db.ref)
+ LOG.debug(
+ 'Policy type "%s" is not found. Creating new entry.',
+ policy_type_db.ref,
+ )
policy_type_db = PolicyType.add_or_update(policy_type_db)
- extra = {'policy_type_db': policy_type_db}
+ extra = {"policy_type_db": policy_type_db}
LOG.audit('Policy type "%s" is updated.', policy_type_db.ref, extra=extra)
registered_count += 1
@@ -203,16 +214,18 @@ def register_policy_types(module):
return registered_count
-def register_policies(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
- fail_on_failure=False):
+def register_policies(
+ packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False
+):
if packs_base_paths:
assert isinstance(packs_base_paths, list)
if not packs_base_paths:
packs_base_paths = content_utils.get_packs_base_paths()
- registrar = PolicyRegistrar(use_pack_cache=use_pack_cache,
- fail_on_failure=fail_on_failure)
+ registrar = PolicyRegistrar(
+ use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure
+ )
if pack_dir:
result = registrar.register_from_pack(pack_dir=pack_dir)
diff --git a/st2common/st2common/bootstrap/rulesregistrar.py b/st2common/st2common/bootstrap/rulesregistrar.py
index c50b0d5eae..505f3e5337 100644
--- a/st2common/st2common/bootstrap/rulesregistrar.py
+++ b/st2common/st2common/bootstrap/rulesregistrar.py
@@ -25,14 +25,14 @@
from st2common.models.api.rule import RuleAPI
from st2common.models.system.common import ResourceReference
from st2common.persistence.rule import Rule
-from st2common.services.triggers import cleanup_trigger_db_for_rule, increment_trigger_ref_count
+from st2common.services.triggers import (
+ cleanup_trigger_db_for_rule,
+ increment_trigger_ref_count,
+)
from st2common.exceptions.db import StackStormDBObjectNotFoundError
import st2common.content.utils as content_utils
-__all__ = [
- 'RulesRegistrar',
- 'register_rules'
-]
+__all__ = ["RulesRegistrar", "register_rules"]
LOG = logging.getLogger(__name__)
@@ -49,14 +49,15 @@ def register_from_packs(self, base_dirs):
self.register_packs(base_dirs=base_dirs)
registered_count = 0
- content = self._pack_loader.get_content(base_dirs=base_dirs,
- content_type='rules')
+ content = self._pack_loader.get_content(
+ base_dirs=base_dirs, content_type="rules"
+ )
for pack, rules_dir in six.iteritems(content):
if not rules_dir:
- LOG.debug('Pack %s does not contain rules.', pack)
+ LOG.debug("Pack %s does not contain rules.", pack)
continue
try:
- LOG.debug('Registering rules from pack: %s', pack)
+ LOG.debug("Registering rules from pack: %s", pack)
rules = self._get_rules_from_pack(rules_dir)
count = self._register_rules_from_pack(pack, rules)
registered_count += count
@@ -64,7 +65,7 @@ def register_from_packs(self, base_dirs):
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all rules from pack: %s', rules_dir)
+ LOG.exception("Failed registering all rules from pack: %s", rules_dir)
return registered_count
@@ -75,10 +76,11 @@ def register_from_pack(self, pack_dir):
:return: Number of rules registered.
:rtype: ``int``
"""
- pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir
+ pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir
_, pack = os.path.split(pack_dir)
- rules_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir,
- content_type='rules')
+ rules_dir = self._pack_loader.get_content_from_pack(
+ pack_dir=pack_dir, content_type="rules"
+ )
# Register pack first
self.register_pack(pack_name=pack, pack_dir=pack_dir)
@@ -87,7 +89,7 @@ def register_from_pack(self, pack_dir):
if not rules_dir:
return registered_count
- LOG.debug('Registering rules from pack %s:, dir: %s', pack, rules_dir)
+ LOG.debug("Registering rules from pack %s:, dir: %s", pack, rules_dir)
try:
rules = self._get_rules_from_pack(rules_dir=rules_dir)
@@ -96,7 +98,7 @@ def register_from_pack(self, pack_dir):
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all rules from pack: %s', rules_dir)
+ LOG.exception("Failed registering all rules from pack: %s", rules_dir)
return registered_count
@@ -108,21 +110,23 @@ def _register_rules_from_pack(self, pack, rules):
# TODO: Refactor this monstrosity
for rule in rules:
- LOG.debug('Loading rule from %s.', rule)
+ LOG.debug("Loading rule from %s.", rule)
try:
content = self._meta_loader.load(rule)
- pack_field = content.get('pack', None)
+ pack_field = content.get("pack", None)
if not pack_field:
- content['pack'] = pack
+ content["pack"] = pack
pack_field = pack
if pack_field != pack:
- raise Exception('Model is in pack "%s" but field "pack" is different: %s' %
- (pack, pack_field))
+ raise Exception(
+ 'Model is in pack "%s" but field "pack" is different: %s'
+ % (pack, pack_field)
+ )
- metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack,
- file_path=rule,
- use_pack_cache=True)
- content['metadata_file'] = metadata_file
+ metadata_file = content_utils.get_relative_path_to_pack_file(
+ pack_ref=pack, file_path=rule, use_pack_cache=True
+ )
+ content["metadata_file"] = metadata_file
rule_api = RuleAPI(**content)
rule_api.validate()
@@ -134,35 +138,48 @@ def _register_rules_from_pack(self, pack, rules):
# delete so we don't have duplicates.
if pack_field != DEFAULT_PACK_NAME:
try:
- rule_ref = ResourceReference.to_string_reference(name=content['name'],
- pack=DEFAULT_PACK_NAME)
- LOG.debug('Looking for rule %s in pack %s', content['name'],
- DEFAULT_PACK_NAME)
+ rule_ref = ResourceReference.to_string_reference(
+ name=content["name"], pack=DEFAULT_PACK_NAME
+ )
+ LOG.debug(
+ "Looking for rule %s in pack %s",
+ content["name"],
+ DEFAULT_PACK_NAME,
+ )
existing = Rule.get_by_ref(rule_ref)
- LOG.debug('Existing = %s', existing)
+ LOG.debug("Existing = %s", existing)
if existing:
- LOG.debug('Found rule in pack default: %s; Deleting.', rule_ref)
+ LOG.debug(
+ "Found rule in pack default: %s; Deleting.", rule_ref
+ )
Rule.delete(existing)
except:
- LOG.exception('Exception deleting rule from %s pack.', DEFAULT_PACK_NAME)
+ LOG.exception(
+ "Exception deleting rule from %s pack.", DEFAULT_PACK_NAME
+ )
try:
- rule_ref = ResourceReference.to_string_reference(name=content['name'],
- pack=content['pack'])
+ rule_ref = ResourceReference.to_string_reference(
+ name=content["name"], pack=content["pack"]
+ )
existing = Rule.get_by_ref(rule_ref)
if existing:
rule_db.id = existing.id
- LOG.debug('Found existing rule: %s with id: %s', rule_ref, existing.id)
+ LOG.debug(
+ "Found existing rule: %s with id: %s", rule_ref, existing.id
+ )
except StackStormDBObjectNotFoundError:
- LOG.debug('Rule %s not found. Creating new one.', rule)
+ LOG.debug("Rule %s not found. Creating new one.", rule)
try:
rule_db = Rule.add_or_update(rule_db)
increment_trigger_ref_count(rule_api=rule_api)
- extra = {'rule_db': rule_db}
- LOG.audit('Rule updated. Rule %s from %s.', rule_db, rule, extra=extra)
+ extra = {"rule_db": rule_db}
+ LOG.audit(
+ "Rule updated. Rule %s from %s.", rule_db, rule, extra=extra
+ )
except Exception:
- LOG.exception('Failed to create rule %s.', rule_api.name)
+ LOG.exception("Failed to create rule %s.", rule_api.name)
# If there was an existing rule then the ref count was updated in
# to_model so it needs to be adjusted down here. Also, update could
@@ -171,27 +188,32 @@ def _register_rules_from_pack(self, pack, rules):
cleanup_trigger_db_for_rule(existing)
except Exception as e:
if self._fail_on_failure:
- msg = ('Failed to register rule "%s" from pack "%s": %s' % (rule, pack,
- six.text_type(e)))
+ msg = 'Failed to register rule "%s" from pack "%s": %s' % (
+ rule,
+ pack,
+ six.text_type(e),
+ )
raise ValueError(msg)
- LOG.exception('Failed registering rule from %s.', rule)
+ LOG.exception("Failed registering rule from %s.", rule)
else:
registered_count += 1
return registered_count
-def register_rules(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
- fail_on_failure=False):
+def register_rules(
+ packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False
+):
if packs_base_paths:
assert isinstance(packs_base_paths, list)
if not packs_base_paths:
packs_base_paths = content_utils.get_packs_base_paths()
- registrar = RulesRegistrar(use_pack_cache=use_pack_cache,
- fail_on_failure=fail_on_failure)
+ registrar = RulesRegistrar(
+ use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure
+ )
if pack_dir:
result = registrar.register_from_pack(pack_dir=pack_dir)
diff --git a/st2common/st2common/bootstrap/ruletypesregistrar.py b/st2common/st2common/bootstrap/ruletypesregistrar.py
index 735294cd23..90d4018a40 100644
--- a/st2common/st2common/bootstrap/ruletypesregistrar.py
+++ b/st2common/st2common/bootstrap/ruletypesregistrar.py
@@ -22,41 +22,36 @@
from st2common.persistence.rule import RuleType
from st2common.exceptions.db import StackStormDBObjectNotFoundError
-__all__ = [
- 'register_rule_types',
- 'RULE_TYPES'
-]
+__all__ = ["register_rule_types", "RULE_TYPES"]
LOG = logging.getLogger(__name__)
RULE_TYPES = [
{
- 'name': RULE_TYPE_STANDARD,
- 'description': 'standard rule that is always applicable.',
- 'enabled': True,
- 'parameters': {
- }
+ "name": RULE_TYPE_STANDARD,
+ "description": "standard rule that is always applicable.",
+ "enabled": True,
+ "parameters": {},
},
{
- 'name': RULE_TYPE_BACKSTOP,
- 'description': 'Rule that applies when no other rule has matched for a specific Trigger.',
- 'enabled': True,
- 'parameters': {
- }
+ "name": RULE_TYPE_BACKSTOP,
+ "description": "Rule that applies when no other rule has matched for a specific Trigger.",
+ "enabled": True,
+ "parameters": {},
},
]
def register_rule_types():
- LOG.debug('Start : register default RuleTypes.')
+ LOG.debug("Start : register default RuleTypes.")
registered_count = 0
for rule_type in RULE_TYPES:
rule_type = copy.deepcopy(rule_type)
try:
- rule_type_db = RuleType.get_by_name(rule_type['name'])
+ rule_type_db = RuleType.get_by_name(rule_type["name"])
update = True
except StackStormDBObjectNotFoundError:
rule_type_db = None
@@ -72,16 +67,16 @@ def register_rule_types():
try:
rule_type_db = RuleType.add_or_update(rule_type_model)
- extra = {'rule_type_db': rule_type_db}
+ extra = {"rule_type_db": rule_type_db}
if update:
- LOG.audit('RuleType updated. RuleType %s', rule_type_db, extra=extra)
+ LOG.audit("RuleType updated. RuleType %s", rule_type_db, extra=extra)
else:
- LOG.audit('RuleType created. RuleType %s', rule_type_db, extra=extra)
+ LOG.audit("RuleType created. RuleType %s", rule_type_db, extra=extra)
except Exception:
- LOG.exception('Unable to register RuleType %s.', rule_type['name'])
+ LOG.exception("Unable to register RuleType %s.", rule_type["name"])
else:
registered_count += 1
- LOG.debug('End : register default RuleTypes.')
+ LOG.debug("End : register default RuleTypes.")
return registered_count
diff --git a/st2common/st2common/bootstrap/runnersregistrar.py b/st2common/st2common/bootstrap/runnersregistrar.py
index 3aa93da9b1..bb99389433 100644
--- a/st2common/st2common/bootstrap/runnersregistrar.py
+++ b/st2common/st2common/bootstrap/runnersregistrar.py
@@ -26,7 +26,7 @@
from st2common.util.action_db import get_runnertype_by_name
__all__ = [
- 'register_runner_types',
+ "register_runner_types",
]
@@ -37,7 +37,7 @@ def register_runners(experimental=False, fail_on_failure=True):
"""
Register runners
"""
- LOG.debug('Start : register runners')
+ LOG.debug("Start : register runners")
runner_count = 0
manager = ExtensionManager(namespace=RUNNERS_NAMESPACE, invoke_on_load=False)
@@ -46,28 +46,30 @@ def register_runners(experimental=False, fail_on_failure=True):
for name in extension_names:
LOG.debug('Found runner "%s"' % (name))
- manager = DriverManager(namespace=RUNNERS_NAMESPACE, invoke_on_load=False, name=name)
+ manager = DriverManager(
+ namespace=RUNNERS_NAMESPACE, invoke_on_load=False, name=name
+ )
runner_metadata = manager.driver.get_metadata()
runner_count += register_runner(runner_metadata, experimental)
- LOG.debug('End : register runners')
+ LOG.debug("End : register runners")
return runner_count
def register_runner(runner_type, experimental):
# For backward compatibility reasons, we also register runners under the old names
- runner_names = [runner_type['name']] + runner_type.get('aliases', [])
+ runner_names = [runner_type["name"]] + runner_type.get("aliases", [])
for runner_name in runner_names:
- runner_type['name'] = runner_name
- runner_experimental = runner_type.get('experimental', False)
+ runner_type["name"] = runner_name
+ runner_experimental = runner_type.get("experimental", False)
if runner_experimental and not experimental:
LOG.debug('Skipping experimental runner "%s"' % (runner_name))
continue
# Remove additional, non db-model attributes
- non_db_attributes = ['experimental', 'aliases']
+ non_db_attributes = ["experimental", "aliases"]
for attribute in non_db_attributes:
if attribute in runner_type:
del runner_type[attribute]
@@ -81,13 +83,13 @@ def register_runner(runner_type, experimental):
# Note: We don't want to overwrite "enabled" attribute which is already in the database
# (aka we don't want to re-enable runner which has been disabled by the user)
- if runner_type_db and runner_type_db['enabled'] != runner_type['enabled']:
- runner_type['enabled'] = runner_type_db['enabled']
+ if runner_type_db and runner_type_db["enabled"] != runner_type["enabled"]:
+ runner_type["enabled"] = runner_type_db["enabled"]
# If package is not provided, assume it's the same as module name for backward
# compatibility reasons
- if not runner_type.get('runner_package', None):
- runner_type['runner_package'] = runner_type['runner_module']
+ if not runner_type.get("runner_package", None):
+ runner_type["runner_package"] = runner_type["runner_module"]
runner_type_api = RunnerTypeAPI(**runner_type)
runner_type_api.validate()
@@ -100,13 +102,17 @@ def register_runner(runner_type, experimental):
runner_type_db = RunnerType.add_or_update(runner_type_model)
- extra = {'runner_type_db': runner_type_db}
+ extra = {"runner_type_db": runner_type_db}
if update:
- LOG.audit('RunnerType updated. RunnerType %s', runner_type_db, extra=extra)
+ LOG.audit(
+ "RunnerType updated. RunnerType %s", runner_type_db, extra=extra
+ )
else:
- LOG.audit('RunnerType created. RunnerType %s', runner_type_db, extra=extra)
+ LOG.audit(
+ "RunnerType created. RunnerType %s", runner_type_db, extra=extra
+ )
except Exception:
- LOG.exception('Unable to register runner type %s.', runner_type['name'])
+ LOG.exception("Unable to register runner type %s.", runner_type["name"])
return 0
return 1
diff --git a/st2common/st2common/bootstrap/sensorsregistrar.py b/st2common/st2common/bootstrap/sensorsregistrar.py
index 5181270d79..8a91e23eea 100644
--- a/st2common/st2common/bootstrap/sensorsregistrar.py
+++ b/st2common/st2common/bootstrap/sensorsregistrar.py
@@ -26,10 +26,7 @@
from st2common.models.api.sensor import SensorTypeAPI
from st2common.persistence.sensor import SensorType
-__all__ = [
- 'SensorsRegistrar',
- 'register_sensors'
-]
+__all__ = ["SensorsRegistrar", "register_sensors"]
LOG = logging.getLogger(__name__)
@@ -51,15 +48,18 @@ def register_from_packs(self, base_dirs):
self.register_packs(base_dirs=base_dirs)
registered_count = 0
- content = self._pack_loader.get_content(base_dirs=base_dirs,
- content_type='sensors')
+ content = self._pack_loader.get_content(
+ base_dirs=base_dirs, content_type="sensors"
+ )
for pack, sensors_dir in six.iteritems(content):
if not sensors_dir:
- LOG.debug('Pack %s does not contain sensors.', pack)
+ LOG.debug("Pack %s does not contain sensors.", pack)
continue
try:
- LOG.debug('Registering sensors from pack %s:, dir: %s', pack, sensors_dir)
+ LOG.debug(
+ "Registering sensors from pack %s:, dir: %s", pack, sensors_dir
+ )
sensors = self._get_sensors_from_pack(sensors_dir)
count = self._register_sensors_from_pack(pack=pack, sensors=sensors)
registered_count += count
@@ -67,8 +67,11 @@ def register_from_packs(self, base_dirs):
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all sensors from pack "%s": %s', sensors_dir,
- six.text_type(e))
+ LOG.exception(
+ 'Failed registering all sensors from pack "%s": %s',
+ sensors_dir,
+ six.text_type(e),
+ )
return registered_count
@@ -79,10 +82,11 @@ def register_from_pack(self, pack_dir):
:return: Number of sensors registered.
:rtype: ``int``
"""
- pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir
+ pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir
_, pack = os.path.split(pack_dir)
- sensors_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir,
- content_type='sensors')
+ sensors_dir = self._pack_loader.get_content_from_pack(
+ pack_dir=pack_dir, content_type="sensors"
+ )
# Register pack first
self.register_pack(pack_name=pack, pack_dir=pack_dir)
@@ -91,17 +95,22 @@ def register_from_pack(self, pack_dir):
if not sensors_dir:
return registered_count
- LOG.debug('Registering sensors from pack %s:, dir: %s', pack, sensors_dir)
+ LOG.debug("Registering sensors from pack %s:, dir: %s", pack, sensors_dir)
try:
sensors = self._get_sensors_from_pack(sensors_dir=sensors_dir)
- registered_count = self._register_sensors_from_pack(pack=pack, sensors=sensors)
+ registered_count = self._register_sensors_from_pack(
+ pack=pack, sensors=sensors
+ )
except Exception as e:
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all sensors from pack "%s": %s', sensors_dir,
- six.text_type(e))
+ LOG.exception(
+ 'Failed registering all sensors from pack "%s": %s',
+ sensors_dir,
+ six.text_type(e),
+ )
return registered_count
@@ -115,11 +124,16 @@ def _register_sensors_from_pack(self, pack, sensors):
self._register_sensor_from_pack(pack=pack, sensor=sensor)
except Exception as e:
if self._fail_on_failure:
- msg = ('Failed to register sensor "%s" from pack "%s": %s' % (sensor, pack,
- six.text_type(e)))
+ msg = 'Failed to register sensor "%s" from pack "%s": %s' % (
+ sensor,
+ pack,
+ six.text_type(e),
+ )
raise ValueError(msg)
- LOG.debug('Failed to register sensor "%s": %s', sensor, six.text_type(e))
+ LOG.debug(
+ 'Failed to register sensor "%s": %s', sensor, six.text_type(e)
+ )
else:
LOG.debug('Sensor "%s" successfully registered', sensor)
registered_count += 1
@@ -129,33 +143,35 @@ def _register_sensors_from_pack(self, pack, sensors):
def _register_sensor_from_pack(self, pack, sensor):
sensor_metadata_file_path = sensor
- LOG.debug('Loading sensor from %s.', sensor_metadata_file_path)
+ LOG.debug("Loading sensor from %s.", sensor_metadata_file_path)
content = self._meta_loader.load(file_path=sensor_metadata_file_path)
- pack_field = content.get('pack', None)
+ pack_field = content.get("pack", None)
if not pack_field:
- content['pack'] = pack
+ content["pack"] = pack
pack_field = pack
if pack_field != pack:
- raise Exception('Model is in pack "%s" but field "pack" is different: %s' %
- (pack, pack_field))
+ raise Exception(
+ 'Model is in pack "%s" but field "pack" is different: %s'
+ % (pack, pack_field)
+ )
- entry_point = content.get('entry_point', None)
+ entry_point = content.get("entry_point", None)
if not entry_point:
- raise ValueError('Sensor definition missing entry_point')
+ raise ValueError("Sensor definition missing entry_point")
# Add in "metadata_file" attribute which stores path to the pack metadata file relative to
# the pack directory
- metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack,
- file_path=sensor,
- use_pack_cache=True)
- content['metadata_file'] = metadata_file
+ metadata_file = content_utils.get_relative_path_to_pack_file(
+ pack_ref=pack, file_path=sensor, use_pack_cache=True
+ )
+ content["metadata_file"] = metadata_file
sensors_dir = os.path.dirname(sensor_metadata_file_path)
sensor_file_path = os.path.join(sensors_dir, entry_point)
- artifact_uri = 'file://%s' % (sensor_file_path)
- content['artifact_uri'] = artifact_uri
- content['entry_point'] = entry_point
+ artifact_uri = "file://%s" % (sensor_file_path)
+ content["artifact_uri"] = artifact_uri
+ content["entry_point"] = entry_point
sensor_api = SensorTypeAPI(**content)
sensor_model = SensorTypeAPI.to_model(sensor_api)
@@ -163,28 +179,33 @@ def _register_sensor_from_pack(self, pack, sensor):
sensor_types = SensorType.query(pack=sensor_model.pack, name=sensor_model.name)
if len(sensor_types) >= 1:
sensor_type = sensor_types[0]
- LOG.debug('Found existing sensor id:%s with name:%s. Will update it.',
- sensor_type.id, sensor_type.name)
+ LOG.debug(
+ "Found existing sensor id:%s with name:%s. Will update it.",
+ sensor_type.id,
+ sensor_type.name,
+ )
sensor_model.id = sensor_type.id
try:
sensor_model = SensorType.add_or_update(sensor_model)
except:
- LOG.exception('Failed creating sensor model for %s', sensor)
+ LOG.exception("Failed creating sensor model for %s", sensor)
return sensor_model
-def register_sensors(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
- fail_on_failure=False):
+def register_sensors(
+ packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False
+):
if packs_base_paths:
assert isinstance(packs_base_paths, list)
if not packs_base_paths:
packs_base_paths = content_utils.get_packs_base_paths()
- registrar = SensorsRegistrar(use_pack_cache=use_pack_cache,
- fail_on_failure=fail_on_failure)
+ registrar = SensorsRegistrar(
+ use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure
+ )
if pack_dir:
result = registrar.register_from_pack(pack_dir=pack_dir)
diff --git a/st2common/st2common/bootstrap/triggersregistrar.py b/st2common/st2common/bootstrap/triggersregistrar.py
index 4f95a6d0a3..180c9cb885 100644
--- a/st2common/st2common/bootstrap/triggersregistrar.py
+++ b/st2common/st2common/bootstrap/triggersregistrar.py
@@ -24,10 +24,7 @@
import st2common.content.utils as content_utils
from st2common.models.utils import sensor_type_utils
-__all__ = [
- 'TriggersRegistrar',
- 'register_triggers'
-]
+__all__ = ["TriggersRegistrar", "register_triggers"]
LOG = logging.getLogger(__name__)
@@ -47,15 +44,18 @@ def register_from_packs(self, base_dirs):
self.register_packs(base_dirs=base_dirs)
registered_count = 0
- content = self._pack_loader.get_content(base_dirs=base_dirs,
- content_type='triggers')
+ content = self._pack_loader.get_content(
+ base_dirs=base_dirs, content_type="triggers"
+ )
for pack, triggers_dir in six.iteritems(content):
if not triggers_dir:
- LOG.debug('Pack %s does not contain triggers.', pack)
+ LOG.debug("Pack %s does not contain triggers.", pack)
continue
try:
- LOG.debug('Registering triggers from pack %s:, dir: %s', pack, triggers_dir)
+ LOG.debug(
+ "Registering triggers from pack %s:, dir: %s", pack, triggers_dir
+ )
triggers = self._get_triggers_from_pack(triggers_dir)
count = self._register_triggers_from_pack(pack=pack, triggers=triggers)
registered_count += count
@@ -63,8 +63,11 @@ def register_from_packs(self, base_dirs):
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all triggers from pack "%s": %s', triggers_dir,
- six.text_type(e))
+ LOG.exception(
+ 'Failed registering all triggers from pack "%s": %s',
+ triggers_dir,
+ six.text_type(e),
+ )
return registered_count
@@ -75,10 +78,11 @@ def register_from_pack(self, pack_dir):
:return: Number of triggers registered.
:rtype: ``int``
"""
- pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir
+ pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir
_, pack = os.path.split(pack_dir)
- triggers_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir,
- content_type='triggers')
+ triggers_dir = self._pack_loader.get_content_from_pack(
+ pack_dir=pack_dir, content_type="triggers"
+ )
# Register pack first
self.register_pack(pack_name=pack, pack_dir=pack_dir)
@@ -87,17 +91,22 @@ def register_from_pack(self, pack_dir):
if not triggers_dir:
return registered_count
- LOG.debug('Registering triggers from pack %s:, dir: %s', pack, triggers_dir)
+ LOG.debug("Registering triggers from pack %s:, dir: %s", pack, triggers_dir)
try:
triggers = self._get_triggers_from_pack(triggers_dir=triggers_dir)
- registered_count = self._register_triggers_from_pack(pack=pack, triggers=triggers)
+ registered_count = self._register_triggers_from_pack(
+ pack=pack, triggers=triggers
+ )
except Exception as e:
if self._fail_on_failure:
raise e
- LOG.exception('Failed registering all triggers from pack "%s": %s', triggers_dir,
- six.text_type(e))
+ LOG.exception(
+ 'Failed registering all triggers from pack "%s": %s',
+ triggers_dir,
+ six.text_type(e),
+ )
return registered_count
@@ -107,20 +116,27 @@ def _get_triggers_from_pack(self, triggers_dir):
def _register_triggers_from_pack(self, pack, triggers):
registered_count = 0
- pack_base_path = content_utils.get_pack_base_path(pack_name=pack,
- include_trailing_slash=True)
+ pack_base_path = content_utils.get_pack_base_path(
+ pack_name=pack, include_trailing_slash=True
+ )
for trigger in triggers:
try:
- self._register_trigger_from_pack(pack_base_path=pack_base_path, pack=pack,
- trigger=trigger)
+ self._register_trigger_from_pack(
+ pack_base_path=pack_base_path, pack=pack, trigger=trigger
+ )
except Exception as e:
if self._fail_on_failure:
- msg = ('Failed to register trigger "%s" from pack "%s": %s' % (trigger, pack,
- six.text_type(e)))
+ msg = 'Failed to register trigger "%s" from pack "%s": %s' % (
+ trigger,
+ pack,
+ six.text_type(e),
+ )
raise ValueError(msg)
- LOG.debug('Failed to register trigger "%s": %s', trigger, six.text_type(e))
+ LOG.debug(
+ 'Failed to register trigger "%s": %s', trigger, six.text_type(e)
+ )
else:
LOG.debug('Trigger "%s" successfully registered', trigger)
registered_count += 1
@@ -130,37 +146,41 @@ def _register_triggers_from_pack(self, pack, triggers):
def _register_trigger_from_pack(self, pack_base_path, pack, trigger):
trigger_metadata_file_path = trigger
- LOG.debug('Loading trigger from %s.', trigger_metadata_file_path)
+ LOG.debug("Loading trigger from %s.", trigger_metadata_file_path)
content = self._meta_loader.load(file_path=trigger_metadata_file_path)
- pack_field = content.get('pack', None)
+ pack_field = content.get("pack", None)
if not pack_field:
- content['pack'] = pack
+ content["pack"] = pack
pack_field = pack
if pack_field != pack:
- raise Exception('Model is in pack "%s" but field "pack" is different: %s' %
- (pack, pack_field))
+ raise Exception(
+ 'Model is in pack "%s" but field "pack" is different: %s'
+ % (pack, pack_field)
+ )
# Add in "metadata_file" attribute which stores path to the pack metadata file relative to
# the pack directory
- metadata_file = trigger.replace(pack_base_path, '')
- content['metadata_file'] = metadata_file
+ metadata_file = trigger.replace(pack_base_path, "")
+ content["metadata_file"] = metadata_file
trigger_types = [content]
result = sensor_type_utils.create_trigger_types(trigger_types=trigger_types)
return result[0] if result else None
-def register_triggers(packs_base_paths=None, pack_dir=None, use_pack_cache=True,
- fail_on_failure=False):
+def register_triggers(
+ packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False
+):
if packs_base_paths:
assert isinstance(packs_base_paths, list)
if not packs_base_paths:
packs_base_paths = content_utils.get_packs_base_paths()
- registrar = TriggersRegistrar(use_pack_cache=use_pack_cache,
- fail_on_failure=fail_on_failure)
+ registrar = TriggersRegistrar(
+ use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure
+ )
if pack_dir:
result = registrar.register_from_pack(pack_dir=pack_dir)
diff --git a/st2common/st2common/callback/base.py b/st2common/st2common/callback/base.py
index ae1b55e501..a48fcbecb9 100644
--- a/st2common/st2common/callback/base.py
+++ b/st2common/st2common/callback/base.py
@@ -21,7 +21,7 @@
__all__ = [
- 'AsyncActionExecutionCallbackHandler',
+ "AsyncActionExecutionCallbackHandler",
]
@@ -30,7 +30,6 @@
@six.add_metaclass(abc.ABCMeta)
class AsyncActionExecutionCallbackHandler(object):
-
@staticmethod
@abc.abstractmethod
def callback(liveaction):
diff --git a/st2common/st2common/cmd/download_pack.py b/st2common/st2common/cmd/download_pack.py
index b22a3f0467..5ef0fb72b9 100644
--- a/st2common/st2common/cmd/download_pack.py
+++ b/st2common/st2common/cmd/download_pack.py
@@ -24,23 +24,34 @@
from st2common.util.pack_management import download_pack
from st2common.util.pack_management import get_and_set_proxy_config
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _register_cli_opts():
cli_opts = [
- cfg.MultiStrOpt('pack', default=None, required=True, positional=True,
- help='Name of the pack to install (download).'),
- cfg.BoolOpt('verify-ssl', default=True,
- help=('Verify SSL certificate of the Git repo from which the pack is '
- 'installed.')),
- cfg.BoolOpt('force', default=False,
- help='True to force pack download and ignore download '
- 'lock file if it exists.'),
+ cfg.MultiStrOpt(
+ "pack",
+ default=None,
+ required=True,
+ positional=True,
+ help="Name of the pack to install (download).",
+ ),
+ cfg.BoolOpt(
+ "verify-ssl",
+ default=True,
+ help=(
+ "Verify SSL certificate of the Git repo from which the pack is "
+ "installed."
+ ),
+ ),
+ cfg.BoolOpt(
+ "force",
+ default=False,
+ help="True to force pack download and ignore download "
+ "lock file if it exists.",
+ ),
]
do_register_cli_opts(cli_opts)
@@ -49,8 +60,12 @@ def main(argv):
_register_cli_opts()
# Parse CLI args, set up logging
- common_setup(config=config, setup_db=False, register_mq_exchanges=False,
- register_internal_trigger_types=False)
+ common_setup(
+ config=config,
+ setup_db=False,
+ register_mq_exchanges=False,
+ register_internal_trigger_types=False,
+ )
packs = cfg.CONF.pack
verify_ssl = cfg.CONF.verify_ssl
@@ -60,8 +75,13 @@ def main(argv):
for pack in packs:
LOG.info('Installing pack "%s"' % (pack))
- result = download_pack(pack=pack, verify_ssl=verify_ssl, force=force,
- proxy_config=proxy_config, force_permissions=True)
+ result = download_pack(
+ pack=pack,
+ verify_ssl=verify_ssl,
+ force=force,
+ proxy_config=proxy_config,
+ force_permissions=True,
+ )
# Raw pack name excluding the version
pack_name = result[1]
diff --git a/st2common/st2common/cmd/generate_api_spec.py b/st2common/st2common/cmd/generate_api_spec.py
index 1b0a65ec8f..7ff7757b71 100644
--- a/st2common/st2common/cmd/generate_api_spec.py
+++ b/st2common/st2common/cmd/generate_api_spec.py
@@ -25,9 +25,7 @@
from st2common.script_setup import setup as common_setup
from st2common.script_setup import teardown as common_teardown
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
@@ -37,7 +35,7 @@ def setup():
def generate_spec():
- spec_string = spec_loader.generate_spec('st2common', 'openapi.yaml.j2')
+ spec_string = spec_loader.generate_spec("st2common", "openapi.yaml.j2")
print(spec_string)
@@ -52,7 +50,7 @@ def main():
generate_spec()
ret = 0
except Exception:
- LOG.error('Failed to generate openapi.yaml file', exc_info=True)
+ LOG.error("Failed to generate openapi.yaml file", exc_info=True)
ret = 1
finally:
teartown()
diff --git a/st2common/st2common/cmd/install_pack.py b/st2common/st2common/cmd/install_pack.py
index 861d0d4041..42c2267012 100644
--- a/st2common/st2common/cmd/install_pack.py
+++ b/st2common/st2common/cmd/install_pack.py
@@ -25,23 +25,34 @@
from st2common.util.pack_management import get_and_set_proxy_config
from st2common.util.virtualenvs import setup_pack_virtualenv
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _register_cli_opts():
cli_opts = [
- cfg.MultiStrOpt('pack', default=None, required=True, positional=True,
- help='Name of the pack to install.'),
- cfg.BoolOpt('verify-ssl', default=True,
- help=('Verify SSL certificate of the Git repo from which the pack is '
- 'downloaded.')),
- cfg.BoolOpt('force', default=False,
- help='True to force pack installation and ignore install '
- 'lock file if it exists.'),
+ cfg.MultiStrOpt(
+ "pack",
+ default=None,
+ required=True,
+ positional=True,
+ help="Name of the pack to install.",
+ ),
+ cfg.BoolOpt(
+ "verify-ssl",
+ default=True,
+ help=(
+ "Verify SSL certificate of the Git repo from which the pack is "
+ "downloaded."
+ ),
+ ),
+ cfg.BoolOpt(
+ "force",
+ default=False,
+ help="True to force pack installation and ignore install "
+ "lock file if it exists.",
+ ),
]
do_register_cli_opts(cli_opts)
@@ -50,8 +61,12 @@ def main(argv):
_register_cli_opts()
# Parse CLI args, set up logging
- common_setup(config=config, setup_db=False, register_mq_exchanges=False,
- register_internal_trigger_types=False)
+ common_setup(
+ config=config,
+ setup_db=False,
+ register_mq_exchanges=False,
+ register_internal_trigger_types=False,
+ )
packs = cfg.CONF.pack
verify_ssl = cfg.CONF.verify_ssl
@@ -62,8 +77,13 @@ def main(argv):
for pack in packs:
# 1. Download the pack
LOG.info('Installing pack "%s"' % (pack))
- result = download_pack(pack=pack, verify_ssl=verify_ssl, force=force,
- proxy_config=proxy_config, force_permissions=True)
+ result = download_pack(
+ pack=pack,
+ verify_ssl=verify_ssl,
+ force=force,
+ proxy_config=proxy_config,
+ force_permissions=True,
+ )
# Raw pack name excluding the version
pack_name = result[1]
@@ -78,9 +98,13 @@ def main(argv):
# 2. Setup pack virtual environment
LOG.info('Setting up virtualenv for pack "%s"' % (pack_name))
- setup_pack_virtualenv(pack_name=pack_name, update=False, logger=LOG,
- proxy_config=proxy_config,
- no_download=True)
+ setup_pack_virtualenv(
+ pack_name=pack_name,
+ update=False,
+ logger=LOG,
+ proxy_config=proxy_config,
+ no_download=True,
+ )
LOG.info('Successfully set up virtualenv for pack "%s"' % (pack_name))
return 0
diff --git a/st2common/st2common/cmd/purge_executions.py b/st2common/st2common/cmd/purge_executions.py
index dcf7b47b40..27225d661c 100755
--- a/st2common/st2common/cmd/purge_executions.py
+++ b/st2common/st2common/cmd/purge_executions.py
@@ -38,25 +38,30 @@
from st2common.constants.exit_codes import FAILURE_EXIT_CODE
from st2common.garbage_collection.executions import purge_executions
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _register_cli_opts():
cli_opts = [
- cfg.StrOpt('timestamp', default=None,
- help='Will delete execution and liveaction models older than ' +
- 'this UTC timestamp. ' +
- 'Example value: 2015-03-13T19:01:27.255542Z.'),
- cfg.StrOpt('action-ref', default='',
- help='action-ref to delete executions for.'),
- cfg.BoolOpt('purge-incomplete', default=False,
- help='Purge all models irrespective of their ``status``.' +
- 'By default, only executions in completed states such as "succeeeded" ' +
- ', "failed", "canceled" and "timed_out" are deleted.'),
+ cfg.StrOpt(
+ "timestamp",
+ default=None,
+ help="Will delete execution and liveaction models older than "
+ + "this UTC timestamp. "
+ + "Example value: 2015-03-13T19:01:27.255542Z.",
+ ),
+ cfg.StrOpt(
+ "action-ref", default="", help="action-ref to delete executions for."
+ ),
+ cfg.BoolOpt(
+ "purge-incomplete",
+ default=False,
+ help="Purge all models irrespective of their ``status``."
+ + 'By default, only executions in completed states such as "succeeeded" '
+ + ', "failed", "canceled" and "timed_out" are deleted.',
+ ),
]
do_register_cli_opts(cli_opts)
@@ -71,15 +76,19 @@ def main():
purge_incomplete = cfg.CONF.purge_incomplete
if not timestamp:
- LOG.error('Please supply a timestamp for purging models. Aborting.')
+ LOG.error("Please supply a timestamp for purging models. Aborting.")
return 1
else:
- timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ')
+ timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
timestamp = timestamp.replace(tzinfo=pytz.UTC)
try:
- purge_executions(logger=LOG, timestamp=timestamp, action_ref=action_ref,
- purge_incomplete=purge_incomplete)
+ purge_executions(
+ logger=LOG,
+ timestamp=timestamp,
+ action_ref=action_ref,
+ purge_incomplete=purge_incomplete,
+ )
except Exception as e:
LOG.exception(six.text_type(e))
return FAILURE_EXIT_CODE
diff --git a/st2common/st2common/cmd/purge_trigger_instances.py b/st2common/st2common/cmd/purge_trigger_instances.py
index e0908e9f8d..529b786678 100755
--- a/st2common/st2common/cmd/purge_trigger_instances.py
+++ b/st2common/st2common/cmd/purge_trigger_instances.py
@@ -38,19 +38,20 @@
from st2common.constants.exit_codes import FAILURE_EXIT_CODE
from st2common.garbage_collection.trigger_instances import purge_trigger_instances
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _register_cli_opts():
cli_opts = [
- cfg.StrOpt('timestamp', default=None,
- help='Will delete trigger instances older than ' +
- 'this UTC timestamp. ' +
- 'Example value: 2015-03-13T19:01:27.255542Z')
+ cfg.StrOpt(
+ "timestamp",
+ default=None,
+ help="Will delete trigger instances older than "
+ + "this UTC timestamp. "
+ + "Example value: 2015-03-13T19:01:27.255542Z",
+ )
]
do_register_cli_opts(cli_opts)
@@ -63,10 +64,10 @@ def main():
timestamp = cfg.CONF.timestamp
if not timestamp:
- LOG.error('Please supply a timestamp for purging models. Aborting.')
+ LOG.error("Please supply a timestamp for purging models. Aborting.")
return 1
else:
- timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ')
+ timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
timestamp = timestamp.replace(tzinfo=pytz.UTC)
# Purge models.
diff --git a/st2common/st2common/cmd/setup_pack_virtualenv.py b/st2common/st2common/cmd/setup_pack_virtualenv.py
index 626bb389af..514b1cf2e0 100644
--- a/st2common/st2common/cmd/setup_pack_virtualenv.py
+++ b/st2common/st2common/cmd/setup_pack_virtualenv.py
@@ -22,23 +22,31 @@
from st2common.util.pack_management import get_and_set_proxy_config
from st2common.util.virtualenvs import setup_pack_virtualenv
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _register_cli_opts():
cli_opts = [
- cfg.MultiStrOpt('pack', default=None, required=True, positional=True,
- help='Name of the pack to setup the virtual environment for.'),
- cfg.BoolOpt('update', default=False,
- help=('Check this option if the virtual environment already exists and if you '
- 'only want to perform an update and installation of new dependencies. If '
- 'you don\'t check this option, the virtual environment will be destroyed '
- 'then re-created. If you check this and the virtual environment doesn\'t '
- 'exist, it will create it..')),
+ cfg.MultiStrOpt(
+ "pack",
+ default=None,
+ required=True,
+ positional=True,
+ help="Name of the pack to setup the virtual environment for.",
+ ),
+ cfg.BoolOpt(
+ "update",
+ default=False,
+ help=(
+ "Check this option if the virtual environment already exists and if you "
+ "only want to perform an update and installation of new dependencies. If "
+ "you don't check this option, the virtual environment will be destroyed "
+ "then re-created. If you check this and the virtual environment doesn't "
+ "exist, it will create it.."
+ ),
+ ),
]
do_register_cli_opts(cli_opts)
@@ -47,8 +55,12 @@ def main(argv):
_register_cli_opts()
# Parse CLI args, set up logging
- common_setup(config=config, setup_db=False, register_mq_exchanges=False,
- register_internal_trigger_types=False)
+ common_setup(
+ config=config,
+ setup_db=False,
+ register_mq_exchanges=False,
+ register_internal_trigger_types=False,
+ )
packs = cfg.CONF.pack
update = cfg.CONF.update
@@ -58,9 +70,13 @@ def main(argv):
for pack in packs:
# Setup pack virtual environment
LOG.info('Setting up virtualenv for pack "%s"' % (pack))
- setup_pack_virtualenv(pack_name=pack, update=update, logger=LOG,
- proxy_config=proxy_config,
- no_download=True)
+ setup_pack_virtualenv(
+ pack_name=pack,
+ update=update,
+ logger=LOG,
+ proxy_config=proxy_config,
+ no_download=True,
+ )
LOG.info('Successfully set up virtualenv for pack "%s"' % (pack))
return 0
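
For context on the `cfg.MultiStrOpt(..., positional=True)` pattern used here: oslo.config collects repeated positional CLI arguments into a list. A rough standalone sketch (pack names are illustrative, and `required` is omitted to keep the example minimal):

    from oslo_config import cfg

    conf = cfg.ConfigOpts()
    conf.register_cli_opt(
        cfg.MultiStrOpt(
            "pack",
            positional=True,
            default=None,
            help="Pack name(s) to set up virtualenvs for.",
        )
    )
    conf(args=["libcloud", "slack"])  # e.g. st2-pack-setup-virtualenv libcloud slack
    print(conf.pack)  # ['libcloud', 'slack']
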
diff --git a/st2common/st2common/cmd/validate_api_spec.py b/st2common/st2common/cmd/validate_api_spec.py
index 743b3e467a..4f317db4a4 100644
--- a/st2common/st2common/cmd/validate_api_spec.py
+++ b/st2common/st2common/cmd/validate_api_spec.py
@@ -33,19 +33,20 @@
import six
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
cfg.CONF.register_cli_opt(
- cfg.StrOpt('spec-file', short='f', required=False,
- default='st2common/st2common/openapi.yaml')
+ cfg.StrOpt(
+ "spec-file",
+ short="f",
+ required=False,
+ default="st2common/st2common/openapi.yaml",
+ )
)
cfg.CONF.register_cli_opt(
- cfg.BoolOpt('generate', short='-c', required=False,
- default=False)
+ cfg.BoolOpt("generate", short="-c", required=False, default=False)
)
LOG = logging.getLogger(__name__)
@@ -56,12 +57,12 @@ def setup():
def _validate_definitions(spec):
- defs = spec.get('definitions', None)
+ defs = spec.get("definitions", None)
error = False
verbose = cfg.CONF.verbose
for (model, definition) in six.iteritems(defs):
- api_model = definition.get('x-api-model', None)
+ api_model = definition.get("x-api-model", None)
if not api_model:
msg = (
@@ -69,7 +70,7 @@ def _validate_definitions(spec):
)
if verbose:
- LOG.info('Supplied definition for model %s: \n\n%s.', model, definition)
+ LOG.info("Supplied definition for model %s: \n\n%s.", model, definition)
error = True
LOG.error(msg)
@@ -82,18 +83,20 @@ def validate_spec():
generate_spec = cfg.CONF.generate
if not os.path.exists(spec_file) and not generate_spec:
- msg = ('No spec file found in location %s. ' % spec_file +
- 'Provide a valid spec file or ' +
- 'pass --generate-api-spec to genrate a spec.')
+ msg = (
+ "No spec file found in location %s. " % spec_file
+ + "Provide a valid spec file or "
+ + "pass --generate-api-spec to genrate a spec."
+ )
raise Exception(msg)
if generate_spec:
if not spec_file:
- raise Exception('Supply a path to write to spec file to.')
+ raise Exception("Supply a path to write to spec file to.")
- spec_string = spec_loader.generate_spec('st2common', 'openapi.yaml.j2')
+ spec_string = spec_loader.generate_spec("st2common", "openapi.yaml.j2")
- with open(spec_file, 'w') as f:
+ with open(spec_file, "w") as f:
f.write(spec_string)
f.flush()
@@ -112,13 +115,15 @@ def main():
try:
# 1. Validate there are no duplicates keys in the YAML file
- spec_loader.load_spec('st2common', 'openapi.yaml.j2', allow_duplicate_keys=False)
+ spec_loader.load_spec(
+ "st2common", "openapi.yaml.j2", allow_duplicate_keys=False
+ )
# 2. Validate schema (currently broken)
# validate_spec()
ret = 0
except Exception:
- LOG.error('Failed to validate openapi.yaml file', exc_info=True)
+ LOG.error("Failed to validate openapi.yaml file", exc_info=True)
ret = 1
finally:
teartown()
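
The `allow_duplicate_keys=False` check above is the interesting part of this file: plain YAML loaders silently keep the last duplicate key. A rough standalone analogue with PyYAML (illustrative only; st2's `spec_loader` wraps its own loader):

    import yaml

    class UniqueKeyLoader(yaml.SafeLoader):
        """Raise on duplicate mapping keys instead of silently keeping the last one."""

        def construct_mapping(self, node, deep=False):
            seen = set()
            for key_node, _value_node in node.value:
                key = self.construct_object(key_node, deep=deep)
                if key in seen:
                    raise yaml.YAMLError("duplicate key: %r" % key)
                seen.add(key)
            return super().construct_mapping(node, deep=deep)

    yaml.load("a: 1\na: 2\n", Loader=UniqueKeyLoader)  # raises yaml.YAMLError
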
diff --git a/st2common/st2common/cmd/validate_config.py b/st2common/st2common/cmd/validate_config.py
index 2bd5b58d0d..6b7bedd32f 100644
--- a/st2common/st2common/cmd/validate_config.py
+++ b/st2common/st2common/cmd/validate_config.py
@@ -31,9 +31,7 @@
from st2common.constants.exit_codes import FAILURE_EXIT_CODE
from st2common.util.pack import validate_config_against_schema
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
def _do_register_cli_opts(opts, ignore_errors=False):
@@ -47,10 +45,18 @@ def _do_register_cli_opts(opts, ignore_errors=False):
def _register_cli_opts():
cli_opts = [
- cfg.StrOpt('schema-path', default=None, required=True,
- help='Path to the config schema to use for validation.'),
- cfg.StrOpt('config-path', default=None, required=True,
- help='Path to the config file to validate.'),
+ cfg.StrOpt(
+ "schema-path",
+ default=None,
+ required=True,
+ help="Path to the config schema to use for validation.",
+ ),
+ cfg.StrOpt(
+ "config-path",
+ default=None,
+ required=True,
+ help="Path to the config file to validate.",
+ ),
]
do_register_cli_opts(cli_opts)
@@ -65,18 +71,24 @@ def main():
print('Validating config "%s" against schema in "%s"' % (config_path, schema_path))
- with open(schema_path, 'r') as fp:
+ with open(schema_path, "r") as fp:
config_schema = yaml.safe_load(fp.read())
- with open(config_path, 'r') as fp:
+ with open(config_path, "r") as fp:
config_object = yaml.safe_load(fp.read())
try:
- validate_config_against_schema(config_schema=config_schema, config_object=config_object,
- config_path=config_path)
+ validate_config_against_schema(
+ config_schema=config_schema,
+ config_object=config_object,
+ config_path=config_path,
+ )
except Exception as e:
- print('Failed to validate pack config.\n%s' % six.text_type(e))
+ print("Failed to validate pack config.\n%s" % six.text_type(e))
return FAILURE_EXIT_CODE
- print('Config "%s" successfully validated against schema in %s.' % (config_path, schema_path))
+ print(
+ 'Config "%s" successfully validated against schema in %s.'
+ % (config_path, schema_path)
+ )
return SUCCESS_EXIT_CODE
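
The validation path above delegates to `validate_config_against_schema`, which is jsonschema-style validation of one YAML document against another. A minimal sketch of the same idea (schema and config values are made up, and st2's pack-config schema dialect differs in detail):

    import jsonschema
    import yaml

    config_schema = {
        "type": "object",
        "properties": {"api_key": {"type": "string"}},
        "required": ["api_key"],
    }
    config_object = yaml.safe_load("api_key: abc123")
    jsonschema.validate(config_object, config_schema)  # raises ValidationError on mismatch
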
diff --git a/st2common/st2common/config.py b/st2common/st2common/config.py
index 8ad77fa626..e7b30a9a7c 100644
--- a/st2common/st2common/config.py
+++ b/st2common/st2common/config.py
@@ -25,12 +25,7 @@
from st2common.constants.runners import PYTHON_RUNNER_DEFAULT_LOG_LEVEL
from st2common.constants.action import LIVEACTION_COMPLETED_STATES
-__all__ = [
- 'do_register_opts',
- 'do_register_cli_opts',
-
- 'parse_args'
-]
+__all__ = ["do_register_opts", "do_register_cli_opts", "parse_args"]
def do_register_opts(opts, group=None, ignore_errors=False):
@@ -57,447 +52,550 @@ def do_register_cli_opts(opt, ignore_errors=False):
def register_opts(ignore_errors=False):
rbac_opts = [
+ cfg.BoolOpt("enable", default=False, help="Enable RBAC."),
+ cfg.StrOpt("backend", default="noop", help="RBAC backend to use."),
cfg.BoolOpt(
- 'enable', default=False,
- help='Enable RBAC.'),
- cfg.StrOpt(
- 'backend', default='noop',
- help='RBAC backend to use.'),
- cfg.BoolOpt(
- 'sync_remote_groups', default=False,
- help='True to synchronize remote groups returned by the auth backed for each '
- 'StackStorm user with local StackStorm roles based on the group to role '
- 'mapping definition files.'),
+ "sync_remote_groups",
+ default=False,
+ help="True to synchronize remote groups returned by the auth backed for each "
+ "StackStorm user with local StackStorm roles based on the group to role "
+ "mapping definition files.",
+ ),
cfg.BoolOpt(
- 'permission_isolation', default=False,
- help='Isolate resources by user. For now, these resources only include rules and '
- 'executions. All resources can only be viewed or executed by the owning user '
- 'except the admin and system_user who can view or run everything.')
+ "permission_isolation",
+ default=False,
+ help="Isolate resources by user. For now, these resources only include rules and "
+ "executions. All resources can only be viewed or executed by the owning user "
+ "except the admin and system_user who can view or run everything.",
+ ),
]
- do_register_opts(rbac_opts, 'rbac', ignore_errors)
+ do_register_opts(rbac_opts, "rbac", ignore_errors)
system_user_opts = [
+ cfg.StrOpt("user", default="stanley", help="Default system user."),
cfg.StrOpt(
- 'user', default='stanley',
- help='Default system user.'),
- cfg.StrOpt(
- 'ssh_key_file', default='/home/stanley/.ssh/stanley_rsa',
- help='SSH private key for the system user.')
+ "ssh_key_file",
+ default="/home/stanley/.ssh/stanley_rsa",
+ help="SSH private key for the system user.",
+ ),
]
- do_register_opts(system_user_opts, 'system_user', ignore_errors)
+ do_register_opts(system_user_opts, "system_user", ignore_errors)
schema_opts = [
- cfg.IntOpt(
- 'version', default=4,
- help='Version of JSON schema to use.'),
+ cfg.IntOpt("version", default=4, help="Version of JSON schema to use."),
cfg.StrOpt(
- 'draft', default='http://json-schema.org/draft-04/schema#',
- help='URL to the JSON schema draft.')
+ "draft",
+ default="http://json-schema.org/draft-04/schema#",
+ help="URL to the JSON schema draft.",
+ ),
]
- do_register_opts(schema_opts, 'schema', ignore_errors)
+ do_register_opts(schema_opts, "schema", ignore_errors)
system_opts = [
- cfg.BoolOpt(
- 'debug', default=False,
- help='Enable debug mode.'),
+ cfg.BoolOpt("debug", default=False, help="Enable debug mode."),
cfg.StrOpt(
- 'base_path', default='/opt/stackstorm',
- help='Base path to all st2 artifacts.'),
+ "base_path",
+ default="/opt/stackstorm",
+ help="Base path to all st2 artifacts.",
+ ),
cfg.BoolOpt(
- 'validate_trigger_parameters', default=True,
- help='True to validate parameters for non-system trigger types when creating'
- 'a rule. By default, only parameters for system triggers are validated.'),
+ "validate_trigger_parameters",
+ default=True,
+ help="True to validate parameters for non-system trigger types when creating"
+ "a rule. By default, only parameters for system triggers are validated.",
+ ),
cfg.BoolOpt(
- 'validate_trigger_payload', default=True,
- help='True to validate payload for non-system trigger types when dispatching a trigger '
- 'inside the sensor. By default, only payload for system triggers is validated.'),
+ "validate_trigger_payload",
+ default=True,
+ help="True to validate payload for non-system trigger types when dispatching a trigger "
+ "inside the sensor. By default, only payload for system triggers is validated.",
+ ),
cfg.BoolOpt(
- 'validate_output_schema', default=False,
- help='True to validate action and runner output against schema.')
+ "validate_output_schema",
+ default=False,
+ help="True to validate action and runner output against schema.",
+ ),
]
- do_register_opts(system_opts, 'system', ignore_errors)
+ do_register_opts(system_opts, "system", ignore_errors)
- system_packs_base_path = os.path.join(cfg.CONF.system.base_path, 'packs')
- system_runners_base_path = os.path.join(cfg.CONF.system.base_path, 'runners')
+ system_packs_base_path = os.path.join(cfg.CONF.system.base_path, "packs")
+ system_runners_base_path = os.path.join(cfg.CONF.system.base_path, "runners")
content_opts = [
cfg.StrOpt(
- 'pack_group', default='st2packs',
- help='User group that can write to packs directory.'),
- cfg.StrOpt(
- 'system_packs_base_path', default=system_packs_base_path,
- help='Path to the directory which contains system packs.'),
- cfg.StrOpt(
- 'system_runners_base_path', default=system_runners_base_path,
- help='Path to the directory which contains system runners. '
- 'NOTE: This option has been deprecated and it\'s unused since StackStorm v3.0.0'),
- cfg.StrOpt(
- 'packs_base_paths', default=None,
- help='Paths which will be searched for integration packs.'),
- cfg.StrOpt(
- 'runners_base_paths', default=None,
- help='Paths which will be searched for runners. '
- 'NOTE: This option has been deprecated and it\'s unused since StackStorm v3.0.0'),
+ "pack_group",
+ default="st2packs",
+ help="User group that can write to packs directory.",
+ ),
+ cfg.StrOpt(
+ "system_packs_base_path",
+ default=system_packs_base_path,
+ help="Path to the directory which contains system packs.",
+ ),
+ cfg.StrOpt(
+ "system_runners_base_path",
+ default=system_runners_base_path,
+ help="Path to the directory which contains system runners. "
+ "NOTE: This option has been deprecated and it's unused since StackStorm v3.0.0",
+ ),
+ cfg.StrOpt(
+ "packs_base_paths",
+ default=None,
+ help="Paths which will be searched for integration packs.",
+ ),
+ cfg.StrOpt(
+ "runners_base_paths",
+ default=None,
+ help="Paths which will be searched for runners. "
+ "NOTE: This option has been deprecated and it's unused since StackStorm v3.0.0",
+ ),
cfg.ListOpt(
- 'index_url', default=['https://index.stackstorm.org/v1/index.json'],
- help='A URL pointing to the pack index. StackStorm Exchange is used by '
- 'default. Use a comma-separated list for multiple indexes if you '
- 'want to get other packs discovered with "st2 pack search".'),
+ "index_url",
+ default=["https://index.stackstorm.org/v1/index.json"],
+ help="A URL pointing to the pack index. StackStorm Exchange is used by "
+ "default. Use a comma-separated list for multiple indexes if you "
+ 'want to get other packs discovered with "st2 pack search".',
+ ),
]
- do_register_opts(content_opts, 'content', ignore_errors)
+ do_register_opts(content_opts, "content", ignore_errors)
webui_opts = [
cfg.StrOpt(
- 'webui_base_url', default='https://%s' % socket.getfqdn(),
- help='Base https URL to access st2 Web UI. This is used to construct history URLs '
- 'that are sent out when chatops is used to kick off executions.')
+ "webui_base_url",
+ default="https://%s" % socket.getfqdn(),
+ help="Base https URL to access st2 Web UI. This is used to construct history URLs "
+ "that are sent out when chatops is used to kick off executions.",
+ )
]
- do_register_opts(webui_opts, 'webui', ignore_errors)
+ do_register_opts(webui_opts, "webui", ignore_errors)
db_opts = [
- cfg.StrOpt(
- 'host', default='127.0.0.1',
- help='host of db server'),
+ cfg.StrOpt("host", default="127.0.0.1", help="host of db server"),
+ cfg.IntOpt("port", default=27017, help="port of db server"),
+ cfg.StrOpt("db_name", default="st2", help="name of database"),
+ cfg.StrOpt("username", help="username for db login"),
+ cfg.StrOpt("password", help="password for db login"),
cfg.IntOpt(
- 'port', default=27017,
- help='port of db server'),
- cfg.StrOpt(
- 'db_name', default='st2',
- help='name of database'),
- cfg.StrOpt(
- 'username',
- help='username for db login'),
- cfg.StrOpt(
- 'password',
- help='password for db login'),
+ "connection_timeout",
+ default=3 * 1000,
+ help="Connection and server selection timeout (in ms).",
+ ),
cfg.IntOpt(
- 'connection_timeout', default=3 * 1000,
- help='Connection and server selection timeout (in ms).'),
+ "connection_retry_max_delay_m",
+ default=3,
+ help="Connection retry total time (minutes).",
+ ),
cfg.IntOpt(
- 'connection_retry_max_delay_m', default=3,
- help='Connection retry total time (minutes).'),
+ "connection_retry_backoff_max_s",
+ default=10,
+ help="Connection retry backoff max (seconds).",
+ ),
cfg.IntOpt(
- 'connection_retry_backoff_max_s', default=10,
- help='Connection retry backoff max (seconds).'),
- cfg.IntOpt(
- 'connection_retry_backoff_mul', default=1,
- help='Backoff multiplier (seconds).'),
+ "connection_retry_backoff_mul",
+ default=1,
+ help="Backoff multiplier (seconds).",
+ ),
cfg.BoolOpt(
- 'ssl', default=False,
- help='Create the connection to mongodb using SSL'),
- cfg.StrOpt(
- 'ssl_keyfile', default=None,
- help='Private keyfile used to identify the local connection against MongoDB.'),
- cfg.StrOpt(
- 'ssl_certfile', default=None,
- help='Certificate file used to identify the localconnection'),
- cfg.StrOpt(
- 'ssl_cert_reqs', default=None, choices='none, optional, required',
- help='Specifies whether a certificate is required from the other side of the '
- 'connection, and whether it will be validated if provided'),
- cfg.StrOpt(
- 'ssl_ca_certs', default=None,
- help='ca_certs file contains a set of concatenated CA certificates, which are '
- 'used to validate certificates passed from MongoDB.'),
+ "ssl", default=False, help="Create the connection to mongodb using SSL"
+ ),
+ cfg.StrOpt(
+ "ssl_keyfile",
+ default=None,
+ help="Private keyfile used to identify the local connection against MongoDB.",
+ ),
+ cfg.StrOpt(
+ "ssl_certfile",
+ default=None,
+ help="Certificate file used to identify the localconnection",
+ ),
+ cfg.StrOpt(
+ "ssl_cert_reqs",
+ default=None,
+ choices="none, optional, required",
+ help="Specifies whether a certificate is required from the other side of the "
+ "connection, and whether it will be validated if provided",
+ ),
+ cfg.StrOpt(
+ "ssl_ca_certs",
+ default=None,
+ help="ca_certs file contains a set of concatenated CA certificates, which are "
+ "used to validate certificates passed from MongoDB.",
+ ),
cfg.BoolOpt(
- 'ssl_match_hostname', default=True,
- help='If True and `ssl_cert_reqs` is not None, enables hostname verification'),
- cfg.StrOpt(
- 'authentication_mechanism', default=None,
- help='Specifies database authentication mechanisms. '
- 'By default, it use SCRAM-SHA-1 with MongoDB 3.0 and later, '
- 'MONGODB-CR (MongoDB Challenge Response protocol) for older servers.')
+ "ssl_match_hostname",
+ default=True,
+ help="If True and `ssl_cert_reqs` is not None, enables hostname verification",
+ ),
+ cfg.StrOpt(
+ "authentication_mechanism",
+ default=None,
+ help="Specifies database authentication mechanisms. "
+ "By default, it use SCRAM-SHA-1 with MongoDB 3.0 and later, "
+ "MONGODB-CR (MongoDB Challenge Response protocol) for older servers.",
+ ),
]
- do_register_opts(db_opts, 'database', ignore_errors)
+ do_register_opts(db_opts, "database", ignore_errors)
messaging_opts = [
# It would be nice to be able to deprecate url and completely switch to using
# url. However, this will be a breaking change and will have impact so allowing both.
cfg.StrOpt(
- 'url', default='amqp://guest:guest@127.0.0.1:5672//',
- help='URL of the messaging server.'),
+ "url",
+ default="amqp://guest:guest@127.0.0.1:5672//",
+ help="URL of the messaging server.",
+ ),
cfg.ListOpt(
- 'cluster_urls', default=[],
- help='URL of all the nodes in a messaging service cluster.'),
+ "cluster_urls",
+ default=[],
+ help="URL of all the nodes in a messaging service cluster.",
+ ),
cfg.IntOpt(
- 'connection_retries', default=10,
- help='How many times should we retry connection before failing.'),
+ "connection_retries",
+ default=10,
+ help="How many times should we retry connection before failing.",
+ ),
cfg.IntOpt(
- 'connection_retry_wait', default=10000,
- help='How long should we wait between connection retries.'),
+ "connection_retry_wait",
+ default=10000,
+ help="How long should we wait between connection retries.",
+ ),
cfg.BoolOpt(
- 'ssl', default=False,
- help='Use SSL / TLS to connect to the messaging server. Same as '
- 'appending "?ssl=true" at the end of the connection URL string.'),
- cfg.StrOpt(
- 'ssl_keyfile', default=None,
- help='Private keyfile used to identify the local connection against RabbitMQ.'),
- cfg.StrOpt(
- 'ssl_certfile', default=None,
- help='Certificate file used to identify the local connection (client).'),
- cfg.StrOpt(
- 'ssl_cert_reqs', default=None, choices='none, optional, required',
- help='Specifies whether a certificate is required from the other side of the '
- 'connection, and whether it will be validated if provided.'),
- cfg.StrOpt(
- 'ssl_ca_certs', default=None,
- help='ca_certs file contains a set of concatenated CA certificates, which are '
- 'used to validate certificates passed from RabbitMQ.'),
- cfg.StrOpt(
- 'login_method', default=None,
- help='Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).')
+ "ssl",
+ default=False,
+ help="Use SSL / TLS to connect to the messaging server. Same as "
+ 'appending "?ssl=true" at the end of the connection URL string.',
+ ),
+ cfg.StrOpt(
+ "ssl_keyfile",
+ default=None,
+ help="Private keyfile used to identify the local connection against RabbitMQ.",
+ ),
+ cfg.StrOpt(
+ "ssl_certfile",
+ default=None,
+ help="Certificate file used to identify the local connection (client).",
+ ),
+ cfg.StrOpt(
+ "ssl_cert_reqs",
+ default=None,
+ choices="none, optional, required",
+ help="Specifies whether a certificate is required from the other side of the "
+ "connection, and whether it will be validated if provided.",
+ ),
+ cfg.StrOpt(
+ "ssl_ca_certs",
+ default=None,
+ help="ca_certs file contains a set of concatenated CA certificates, which are "
+ "used to validate certificates passed from RabbitMQ.",
+ ),
+ cfg.StrOpt(
+ "login_method",
+ default=None,
+ help="Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).",
+ ),
]
- do_register_opts(messaging_opts, 'messaging', ignore_errors)
+ do_register_opts(messaging_opts, "messaging", ignore_errors)
syslog_opts = [
+ cfg.StrOpt("host", default="127.0.0.1", help="Host for the syslog server."),
+ cfg.IntOpt("port", default=514, help="Port for the syslog server."),
+ cfg.StrOpt("facility", default="local7", help="Syslog facility level."),
cfg.StrOpt(
- 'host', default='127.0.0.1',
- help='Host for the syslog server.'),
- cfg.IntOpt(
- 'port', default=514,
- help='Port for the syslog server.'),
- cfg.StrOpt(
- 'facility', default='local7',
- help='Syslog facility level.'),
- cfg.StrOpt(
- 'protocol', default='udp',
- help='Transport protocol to use (udp / tcp).')
+ "protocol", default="udp", help="Transport protocol to use (udp / tcp)."
+ ),
]
- do_register_opts(syslog_opts, 'syslog', ignore_errors)
+ do_register_opts(syslog_opts, "syslog", ignore_errors)
log_opts = [
- cfg.ListOpt(
- 'excludes', default='',
- help='Exclusion list of loggers to omit.'),
+ cfg.ListOpt("excludes", default="", help="Exclusion list of loggers to omit."),
cfg.BoolOpt(
- 'redirect_stderr', default=False,
- help='Controls if stderr should be redirected to the logs.'),
+ "redirect_stderr",
+ default=False,
+ help="Controls if stderr should be redirected to the logs.",
+ ),
cfg.BoolOpt(
- 'mask_secrets', default=True,
- help='True to mask secrets in the log files.'),
+ "mask_secrets", default=True, help="True to mask secrets in the log files."
+ ),
cfg.ListOpt(
- 'mask_secrets_blacklist', default=[],
- help='Blacklist of additional attribute names to mask in the log messages.')
+ "mask_secrets_blacklist",
+ default=[],
+ help="Blacklist of additional attribute names to mask in the log messages.",
+ ),
]
- do_register_opts(log_opts, 'log', ignore_errors)
+ do_register_opts(log_opts, "log", ignore_errors)
# Common API options
api_opts = [
- cfg.StrOpt(
- 'host', default='127.0.0.1',
- help='StackStorm API server host'),
- cfg.IntOpt(
- 'port', default=9101,
- help='StackStorm API server port'),
+ cfg.StrOpt("host", default="127.0.0.1", help="StackStorm API server host"),
+ cfg.IntOpt("port", default=9101, help="StackStorm API server port"),
cfg.ListOpt(
- 'allow_origin', default=['http://127.0.0.1:3000'],
- help='List of origins allowed for api, auth and stream'),
+ "allow_origin",
+ default=["http://127.0.0.1:3000"],
+ help="List of origins allowed for api, auth and stream",
+ ),
cfg.BoolOpt(
- 'mask_secrets', default=True,
- help='True to mask secrets in the API responses')
+ "mask_secrets",
+ default=True,
+ help="True to mask secrets in the API responses",
+ ),
]
- do_register_opts(api_opts, 'api', ignore_errors)
+ do_register_opts(api_opts, "api", ignore_errors)
# Key Value store options
keyvalue_opts = [
cfg.BoolOpt(
- 'enable_encryption', default=True,
- help='Allow encryption of values in key value stored qualified as "secret".'),
- cfg.StrOpt(
- 'encryption_key_path', default='',
- help='Location of the symmetric encryption key for encrypting values in kvstore. '
- 'This key should be in JSON and should\'ve been generated using '
- 'st2-generate-symmetric-crypto-key tool.')
+ "enable_encryption",
+ default=True,
+ help='Allow encryption of values in the key value store qualified as "secret".',
+ ),
+ cfg.StrOpt(
+ "encryption_key_path",
+ default="",
+ help="Location of the symmetric encryption key for encrypting values in kvstore. "
+ "This key should be in JSON and should've been generated using "
+ "st2-generate-symmetric-crypto-key tool.",
+ ),
]
- do_register_opts(keyvalue_opts, group='keyvalue')
+ do_register_opts(keyvalue_opts, group="keyvalue")
# Common auth options
auth_opts = [
cfg.StrOpt(
- 'api_url', default=None,
- help='Base URL to the API endpoint excluding the version'),
- cfg.BoolOpt(
- 'enable', default=True,
- help='Enable authentication middleware.'),
+ "api_url",
+ default=None,
+ help="Base URL to the API endpoint excluding the version",
+ ),
+ cfg.BoolOpt("enable", default=True, help="Enable authentication middleware."),
cfg.IntOpt(
- 'token_ttl', default=(24 * 60 * 60),
- help='Access token ttl in seconds.'),
+ "token_ttl", default=(24 * 60 * 60), help="Access token ttl in seconds."
+ ),
# This TTL is used for tokens which belong to StackStorm services
cfg.IntOpt(
- 'service_token_ttl', default=(24 * 60 * 60),
- help='Service token ttl in seconds.')
+ "service_token_ttl",
+ default=(24 * 60 * 60),
+ help="Service token ttl in seconds.",
+ ),
]
- do_register_opts(auth_opts, 'auth', ignore_errors)
+ do_register_opts(auth_opts, "auth", ignore_errors)
# Runner options
default_python_bin_path = sys.executable
base_dir = os.path.dirname(os.path.realpath(default_python_bin_path))
- default_virtualenv_bin_path = os.path.join(base_dir, 'virtualenv')
+ default_virtualenv_bin_path = os.path.join(base_dir, "virtualenv")
action_runner_opts = [
# Common runner options
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.actionrunner.conf',
- help='location of the logging.conf file'),
-
+ "logging",
+ default="/etc/st2/logging.actionrunner.conf",
+ help="location of the logging.conf file",
+ ),
# Python runner options
cfg.StrOpt(
- 'python_binary', default=default_python_bin_path,
- help='Python binary which will be used by Python actions.'),
- cfg.StrOpt(
- 'virtualenv_binary', default=default_virtualenv_bin_path,
- help='Virtualenv binary which should be used to create pack virtualenvs.'),
- cfg.StrOpt(
- 'python_runner_log_level', default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL,
- help='Default log level to use for Python runner actions. Can be overriden on '
- 'invocation basis using "log_level" runner parameter.'),
+ "python_binary",
+ default=default_python_bin_path,
+ help="Python binary which will be used by Python actions.",
+ ),
+ cfg.StrOpt(
+ "virtualenv_binary",
+ default=default_virtualenv_bin_path,
+ help="Virtualenv binary which should be used to create pack virtualenvs.",
+ ),
+ cfg.StrOpt(
+ "python_runner_log_level",
+ default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL,
+ help="Default log level to use for Python runner actions. Can be overriden on "
+ 'invocation basis using "log_level" runner parameter.',
+ ),
cfg.ListOpt(
- 'virtualenv_opts', default=['--system-site-packages'],
+ "virtualenv_opts",
+ default=["--system-site-packages"],
help='List of virtualenv options to be passsed to "virtualenv" command that '
- 'creates pack virtualenv.'),
+ "creates pack virtualenv.",
+ ),
cfg.ListOpt(
- 'pip_opts', default=[],
+ "pip_opts",
+ default=[],
help='List of pip options to be passed to "pip install" command when installing pack '
- 'dependencies into pack virtual environment.'),
+ "dependencies into pack virtual environment.",
+ ),
cfg.BoolOpt(
- 'stream_output', default=True,
- help='True to store and stream action output (stdout and stderr) in real-time.'),
+ "stream_output",
+ default=True,
+ help="True to store and stream action output (stdout and stderr) in real-time.",
+ ),
cfg.IntOpt(
- 'stream_output_buffer_size', default=-1,
- help=('Buffer size to use for real time action output streaming. 0 means unbuffered '
- '1 means line buffered, -1 means system default, which usually means fully '
- 'buffered and any other positive value means use a buffer of (approximately) '
- 'that size'))
+ "stream_output_buffer_size",
+ default=-1,
+ help=(
+ "Buffer size to use for real time action output streaming. 0 means unbuffered "
+ "1 means line buffered, -1 means system default, which usually means fully "
+ "buffered and any other positive value means use a buffer of (approximately) "
+ "that size"
+ ),
+ ),
]
- do_register_opts(action_runner_opts, group='actionrunner')
+ do_register_opts(action_runner_opts, group="actionrunner")
dispatcher_pool_opts = [
cfg.IntOpt(
- 'workflows_pool_size', default=40,
- help='Internal pool size for dispatcher used by workflow actions.'),
+ "workflows_pool_size",
+ default=40,
+ help="Internal pool size for dispatcher used by workflow actions.",
+ ),
cfg.IntOpt(
- 'actions_pool_size', default=60,
- help='Internal pool size for dispatcher used by regular actions.')
+ "actions_pool_size",
+ default=60,
+ help="Internal pool size for dispatcher used by regular actions.",
+ ),
]
- do_register_opts(dispatcher_pool_opts, group='actionrunner')
+ do_register_opts(dispatcher_pool_opts, group="actionrunner")
ssh_runner_opts = [
cfg.StrOpt(
- 'remote_dir', default='/tmp',
- help='Location of the script on the remote filesystem.'),
+ "remote_dir",
+ default="/tmp",
+ help="Location of the script on the remote filesystem.",
+ ),
cfg.BoolOpt(
- 'allow_partial_failure', default=False,
- help='How partial success of actions run on multiple nodes should be treated.'),
+ "allow_partial_failure",
+ default=False,
+ help="How partial success of actions run on multiple nodes should be treated.",
+ ),
cfg.IntOpt(
- 'max_parallel_actions', default=50,
- help='Max number of parallel remote SSH actions that should be run. '
- 'Works only with Paramiko SSH runner.'),
+ "max_parallel_actions",
+ default=50,
+ help="Max number of parallel remote SSH actions that should be run. "
+ "Works only with Paramiko SSH runner.",
+ ),
cfg.BoolOpt(
- 'use_ssh_config', default=False,
- help='Use the .ssh/config file. Useful to override ports etc.'),
- cfg.StrOpt(
- 'ssh_config_file_path', default='~/.ssh/config',
- help='Path to the ssh config file.'),
+ "use_ssh_config",
+ default=False,
+ help="Use the .ssh/config file. Useful to override ports etc.",
+ ),
+ cfg.StrOpt(
+ "ssh_config_file_path",
+ default="~/.ssh/config",
+ help="Path to the ssh config file.",
+ ),
cfg.IntOpt(
- 'ssh_connect_timeout', default=60,
- help='Max time in seconds to establish the SSH connection.')
+ "ssh_connect_timeout",
+ default=60,
+ help="Max time in seconds to establish the SSH connection.",
+ ),
]
- do_register_opts(ssh_runner_opts, group='ssh_runner')
+ do_register_opts(ssh_runner_opts, group="ssh_runner")
# Common options (used by action runner and sensor container)
action_sensor_opts = [
cfg.BoolOpt(
- 'enable', default=True,
- help='Whether to enable or disable the ability to post a trigger on action.'),
+ "enable",
+ default=True,
+ help="Whether to enable or disable the ability to post a trigger on action.",
+ ),
cfg.ListOpt(
- 'emit_when', default=LIVEACTION_COMPLETED_STATES,
- help='List of execution statuses for which a trigger will be emitted. ')
+ "emit_when",
+ default=LIVEACTION_COMPLETED_STATES,
+ help="List of execution statuses for which a trigger will be emitted. ",
+ ),
]
- do_register_opts(action_sensor_opts, group='action_sensor')
+ do_register_opts(action_sensor_opts, group="action_sensor")
# Common options for content
pack_lib_opts = [
cfg.BoolOpt(
- 'enable_common_libs', default=False,
- help='Enable/Disable support for pack common libs. '
- 'Setting this config to ``True`` would allow you to '
- 'place common library code for sensors and actions in lib/ folder '
- 'in packs and use them in python sensors and actions. '
- 'See https://docs.stackstorm.com/reference/'
- 'sharing_code_sensors_actions.html '
- 'for details.')
+ "enable_common_libs",
+ default=False,
+ help="Enable/Disable support for pack common libs. "
+ "Setting this config to ``True`` would allow you to "
+ "place common library code for sensors and actions in lib/ folder "
+ "in packs and use them in python sensors and actions. "
+ "See https://docs.stackstorm.com/reference/"
+ "sharing_code_sensors_actions.html "
+ "for details.",
+ )
]
- do_register_opts(pack_lib_opts, group='packs')
+ do_register_opts(pack_lib_opts, group="packs")
# Coordination options
coord_opts = [
- cfg.StrOpt(
- 'url', default=None,
- help='Endpoint for the coordination server.'),
+ cfg.StrOpt("url", default=None, help="Endpoint for the coordination server."),
cfg.IntOpt(
- 'lock_timeout', default=60,
- help='TTL for the lock if backend suports it.'),
+ "lock_timeout", default=60, help="TTL for the lock if backend suports it."
+ ),
cfg.BoolOpt(
- 'service_registry', default=False,
- help='True to register StackStorm services in a service registry.'),
+ "service_registry",
+ default=False,
+ help="True to register StackStorm services in a service registry.",
+ ),
]
- do_register_opts(coord_opts, 'coordination', ignore_errors)
+ do_register_opts(coord_opts, "coordination", ignore_errors)
# XXX: This is required for us to support deprecated config group results_tracker
query_opts = [
cfg.IntOpt(
- 'thread_pool_size',
- help='Number of threads to use to query external workflow systems.'),
+ "thread_pool_size",
+ help="Number of threads to use to query external workflow systems.",
+ ),
cfg.FloatOpt(
- 'query_interval',
- help='Time interval between subsequent queries for a context '
- 'to external workflow system.')
+ "query_interval",
+ help="Time interval between subsequent queries for a context "
+ "to external workflow system.",
+ ),
]
- do_register_opts(query_opts, group='results_tracker', ignore_errors=ignore_errors)
+ do_register_opts(query_opts, group="results_tracker", ignore_errors=ignore_errors)
# Common stream options
stream_opts = [
cfg.IntOpt(
- 'heartbeat', default=25,
- help='Send empty message every N seconds to keep connection open')
+ "heartbeat",
+ default=25,
+ help="Send empty message every N seconds to keep connection open",
+ )
]
- do_register_opts(stream_opts, group='stream', ignore_errors=ignore_errors)
+ do_register_opts(stream_opts, group="stream", ignore_errors=ignore_errors)
# Common CLI options
cli_opts = [
cfg.BoolOpt(
- 'debug', default=False,
- help='Enable debug mode. By default this will set all log levels to DEBUG.'),
+ "debug",
+ default=False,
+ help="Enable debug mode. By default this will set all log levels to DEBUG.",
+ ),
cfg.BoolOpt(
- 'profile', default=False,
- help='Enable profile mode. In the profile mode all the MongoDB queries and '
- 'related profile data are logged.'),
+ "profile",
+ default=False,
+ help="Enable profile mode. In the profile mode all the MongoDB queries and "
+ "related profile data are logged.",
+ ),
cfg.BoolOpt(
- 'use-debugger', default=True,
- help='Enables debugger. Note that using this option changes how the '
- 'eventlet library is used to support async IO. This could result in '
- 'failures that do not occur under normal operation.')
+ "use-debugger",
+ default=True,
+ help="Enables debugger. Note that using this option changes how the "
+ "eventlet library is used to support async IO. This could result in "
+ "failures that do not occur under normal operation.",
+ ),
]
do_register_cli_opts(cli_opts, ignore_errors=ignore_errors)
@@ -505,92 +603,121 @@ def register_opts(ignore_errors=False):
# Metrics Options stream options
metrics_opts = [
cfg.StrOpt(
- 'driver', default='noop',
- help='Driver type for metrics collection.'),
+ "driver", default="noop", help="Driver type for metrics collection."
+ ),
cfg.StrOpt(
- 'host', default='127.0.0.1',
- help='Destination server to connect to if driver requires connection.'),
+ "host",
+ default="127.0.0.1",
+ help="Destination server to connect to if driver requires connection.",
+ ),
cfg.IntOpt(
- 'port', default=8125,
- help='Destination port to connect to if driver requires connection.'),
- cfg.StrOpt(
- 'prefix', default=None,
- help='Optional prefix which is prepended to all the metric names. Comes handy when '
- 'you want to submit metrics from various environment to the same metric '
- 'backend instance.'),
+ "port",
+ default=8125,
+ help="Destination port to connect to if driver requires connection.",
+ ),
+ cfg.StrOpt(
+ "prefix",
+ default=None,
+ help="Optional prefix which is prepended to all the metric names. Comes handy when "
+ "you want to submit metrics from various environment to the same metric "
+ "backend instance.",
+ ),
cfg.FloatOpt(
- 'sample_rate', default=1,
- help='Randomly sample and only send metrics for X% of metric operations to the '
- 'backend. Default value of 1 means no sampling is done and all the metrics are '
- 'sent to the backend. E.g. 0.1 would mean 10% of operations are sampled.')
-
+ "sample_rate",
+ default=1,
+ help="Randomly sample and only send metrics for X% of metric operations to the "
+ "backend. Default value of 1 means no sampling is done and all the metrics are "
+ "sent to the backend. E.g. 0.1 would mean 10% of operations are sampled.",
+ ),
]
- do_register_opts(metrics_opts, group='metrics', ignore_errors=ignore_errors)
+ do_register_opts(metrics_opts, group="metrics", ignore_errors=ignore_errors)
# Common timers engine options
timer_logging_opts = [
cfg.StrOpt(
- 'logging', default=None,
- help='Location of the logging configuration file. '
- 'NOTE: Deprecated in favor of timersengine.logging'),
+ "logging",
+ default=None,
+ help="Location of the logging configuration file. "
+ "NOTE: Deprecated in favor of timersengine.logging",
+ ),
]
timers_engine_logging_opts = [
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.timersengine.conf',
- help='Location of the logging configuration file.')
+ "logging",
+ default="/etc/st2/logging.timersengine.conf",
+ help="Location of the logging configuration file.",
+ )
]
- do_register_opts(timer_logging_opts, group='timer', ignore_errors=ignore_errors)
- do_register_opts(timers_engine_logging_opts, group='timersengine', ignore_errors=ignore_errors)
+ do_register_opts(timer_logging_opts, group="timer", ignore_errors=ignore_errors)
+ do_register_opts(
+ timers_engine_logging_opts, group="timersengine", ignore_errors=ignore_errors
+ )
# NOTE: We default old style deprecated "timer" options to None so our code
# works correclty and "timersengine" has precedence over "timers"
# NOTE: "timer" section will be removed in v3.1
timer_opts = [
cfg.StrOpt(
- 'local_timezone', default=None,
- help='Timezone pertaining to the location where st2 is run. '
- 'NOTE: Deprecated in favor of timersengine.local_timezone'),
+ "local_timezone",
+ default=None,
+ help="Timezone pertaining to the location where st2 is run. "
+ "NOTE: Deprecated in favor of timersengine.local_timezone",
+ ),
cfg.BoolOpt(
- 'enable', default=None,
- help='Specify to enable timer service. '
- 'NOTE: Deprecated in favor of timersengine.enable'),
+ "enable",
+ default=None,
+ help="Specify to enable timer service. "
+ "NOTE: Deprecated in favor of timersengine.enable",
+ ),
]
timers_engine_opts = [
cfg.StrOpt(
- 'local_timezone', default='America/Los_Angeles',
- help='Timezone pertaining to the location where st2 is run.'),
- cfg.BoolOpt(
- 'enable', default=True,
- help='Specify to enable timer service.')
+ "local_timezone",
+ default="America/Los_Angeles",
+ help="Timezone pertaining to the location where st2 is run.",
+ ),
+ cfg.BoolOpt("enable", default=True, help="Specify to enable timer service."),
]
- do_register_opts(timer_opts, group='timer', ignore_errors=ignore_errors)
- do_register_opts(timers_engine_opts, group='timersengine', ignore_errors=ignore_errors)
+ do_register_opts(timer_opts, group="timer", ignore_errors=ignore_errors)
+ do_register_opts(
+ timers_engine_opts, group="timersengine", ignore_errors=ignore_errors
+ )
# Workflow engine options
workflow_engine_opts = [
cfg.IntOpt(
- 'retry_stop_max_msec', default=60000,
- help='Max time to stop retrying.'),
+ "retry_stop_max_msec", default=60000, help="Max time to stop retrying."
+ ),
cfg.IntOpt(
- 'retry_wait_fixed_msec', default=1000,
- help='Interval inbetween retries.'),
+ "retry_wait_fixed_msec", default=1000, help="Interval inbetween retries."
+ ),
cfg.FloatOpt(
- 'retry_max_jitter_msec', default=1000,
- help='Max jitter interval to smooth out retries.'),
+ "retry_max_jitter_msec",
+ default=1000,
+ help="Max jitter interval to smooth out retries.",
+ ),
cfg.IntOpt(
- 'gc_max_idle_sec', default=0,
- help='Max seconds to allow workflow execution be idled before it is identified as '
- 'orphaned and cancelled by the garbage collector. A value of zero means the '
- 'feature is disabled. This is disabled by default.')
+ "gc_max_idle_sec",
+ default=0,
+ help="Max seconds to allow workflow execution be idled before it is identified as "
+ "orphaned and cancelled by the garbage collector. A value of zero means the "
+ "feature is disabled. This is disabled by default.",
+ ),
]
- do_register_opts(workflow_engine_opts, group='workflow_engine', ignore_errors=ignore_errors)
+ do_register_opts(
+ workflow_engine_opts, group="workflow_engine", ignore_errors=ignore_errors
+ )
def parse_args(args=None):
register_opts()
- cfg.CONF(args=args, version=VERSION_STRING, default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
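
A sketch of the registration pattern this whole file repeats: build a list of `cfg.*Opt` objects, register them under a group, then read them back as `cfg.CONF.<group>.<opt>`. The module's `do_register_opts` wrapper additionally swallows duplicate-registration errors when `ignore_errors=True`. Standalone example with a private `ConfigOpts` instance and illustrative values:

    from oslo_config import cfg

    conf = cfg.ConfigOpts()
    db_opts = [
        cfg.StrOpt("host", default="127.0.0.1", help="host of db server"),
        cfg.IntOpt("port", default=27017, help="port of db server"),
    ]
    conf.register_opts(db_opts, group="database")
    conf(args=[])  # parse (empty) CLI args / config files
    print(conf.database.host, conf.database.port)  # 127.0.0.1 27017
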
diff --git a/st2common/st2common/constants/action.py b/st2common/st2common/constants/action.py
index c28725f225..5587b0be91 100644
--- a/st2common/st2common/constants/action.py
+++ b/st2common/st2common/constants/action.py
@@ -14,61 +14,56 @@
# limitations under the License.
__all__ = [
- 'ACTION_NAME',
- 'ACTION_ID',
-
- 'LIBS_DIR',
-
- 'LIVEACTION_STATUS_REQUESTED',
- 'LIVEACTION_STATUS_SCHEDULED',
- 'LIVEACTION_STATUS_DELAYED',
- 'LIVEACTION_STATUS_RUNNING',
- 'LIVEACTION_STATUS_SUCCEEDED',
- 'LIVEACTION_STATUS_FAILED',
- 'LIVEACTION_STATUS_TIMED_OUT',
- 'LIVEACTION_STATUS_CANCELING',
- 'LIVEACTION_STATUS_CANCELED',
- 'LIVEACTION_STATUS_PENDING',
- 'LIVEACTION_STATUS_PAUSING',
- 'LIVEACTION_STATUS_PAUSED',
- 'LIVEACTION_STATUS_RESUMING',
-
- 'LIVEACTION_STATUSES',
- 'LIVEACTION_RUNNABLE_STATES',
- 'LIVEACTION_DELAYED_STATES',
- 'LIVEACTION_CANCELABLE_STATES',
- 'LIVEACTION_FAILED_STATES',
- 'LIVEACTION_COMPLETED_STATES',
-
- 'ACTION_OUTPUT_RESULT_DELIMITER',
- 'ACTION_CONTEXT_KV_PREFIX',
- 'ACTION_PARAMETERS_KV_PREFIX',
- 'ACTION_RESULTS_KV_PREFIX',
-
- 'WORKFLOW_RUNNER_TYPES'
+ "ACTION_NAME",
+ "ACTION_ID",
+ "LIBS_DIR",
+ "LIVEACTION_STATUS_REQUESTED",
+ "LIVEACTION_STATUS_SCHEDULED",
+ "LIVEACTION_STATUS_DELAYED",
+ "LIVEACTION_STATUS_RUNNING",
+ "LIVEACTION_STATUS_SUCCEEDED",
+ "LIVEACTION_STATUS_FAILED",
+ "LIVEACTION_STATUS_TIMED_OUT",
+ "LIVEACTION_STATUS_CANCELING",
+ "LIVEACTION_STATUS_CANCELED",
+ "LIVEACTION_STATUS_PENDING",
+ "LIVEACTION_STATUS_PAUSING",
+ "LIVEACTION_STATUS_PAUSED",
+ "LIVEACTION_STATUS_RESUMING",
+ "LIVEACTION_STATUSES",
+ "LIVEACTION_RUNNABLE_STATES",
+ "LIVEACTION_DELAYED_STATES",
+ "LIVEACTION_CANCELABLE_STATES",
+ "LIVEACTION_FAILED_STATES",
+ "LIVEACTION_COMPLETED_STATES",
+ "ACTION_OUTPUT_RESULT_DELIMITER",
+ "ACTION_CONTEXT_KV_PREFIX",
+ "ACTION_PARAMETERS_KV_PREFIX",
+ "ACTION_RESULTS_KV_PREFIX",
+ "WORKFLOW_RUNNER_TYPES",
]
-ACTION_NAME = 'name'
-ACTION_ID = 'id'
-ACTION_PACK = 'pack'
+ACTION_NAME = "name"
+ACTION_ID = "id"
+ACTION_PACK = "pack"
-LIBS_DIR = 'lib'
+LIBS_DIR = "lib"
-LIVEACTION_STATUS_REQUESTED = 'requested'
-LIVEACTION_STATUS_SCHEDULED = 'scheduled'
-LIVEACTION_STATUS_DELAYED = 'delayed'
-LIVEACTION_STATUS_RUNNING = 'running'
-LIVEACTION_STATUS_SUCCEEDED = 'succeeded'
-LIVEACTION_STATUS_FAILED = 'failed'
-LIVEACTION_STATUS_TIMED_OUT = 'timeout'
-LIVEACTION_STATUS_ABANDONED = 'abandoned'
-LIVEACTION_STATUS_CANCELING = 'canceling'
-LIVEACTION_STATUS_CANCELED = 'canceled'
-LIVEACTION_STATUS_PENDING = 'pending'
-LIVEACTION_STATUS_PAUSING = 'pausing'
-LIVEACTION_STATUS_PAUSED = 'paused'
-LIVEACTION_STATUS_RESUMING = 'resuming'
+LIVEACTION_STATUS_REQUESTED = "requested"
+LIVEACTION_STATUS_SCHEDULED = "scheduled"
+LIVEACTION_STATUS_DELAYED = "delayed"
+LIVEACTION_STATUS_RUNNING = "running"
+LIVEACTION_STATUS_SUCCEEDED = "succeeded"
+LIVEACTION_STATUS_FAILED = "failed"
+LIVEACTION_STATUS_TIMED_OUT = "timeout"
+LIVEACTION_STATUS_ABANDONED = "abandoned"
+LIVEACTION_STATUS_CANCELING = "canceling"
+LIVEACTION_STATUS_CANCELED = "canceled"
+LIVEACTION_STATUS_PENDING = "pending"
+LIVEACTION_STATUS_PAUSING = "pausing"
+LIVEACTION_STATUS_PAUSED = "paused"
+LIVEACTION_STATUS_RESUMING = "resuming"
LIVEACTION_STATUSES = [
LIVEACTION_STATUS_REQUESTED,
@@ -84,25 +79,23 @@
LIVEACTION_STATUS_PENDING,
LIVEACTION_STATUS_PAUSING,
LIVEACTION_STATUS_PAUSED,
- LIVEACTION_STATUS_RESUMING
+ LIVEACTION_STATUS_RESUMING,
]
-ACTION_OUTPUT_RESULT_DELIMITER = '%%%%%~=~=~=************=~=~=~%%%%'
-ACTION_CONTEXT_KV_PREFIX = 'action_context'
-ACTION_PARAMETERS_KV_PREFIX = 'action_parameters'
-ACTION_RESULTS_KV_PREFIX = 'action_results'
+ACTION_OUTPUT_RESULT_DELIMITER = "%%%%%~=~=~=************=~=~=~%%%%"
+ACTION_CONTEXT_KV_PREFIX = "action_context"
+ACTION_PARAMETERS_KV_PREFIX = "action_parameters"
+ACTION_RESULTS_KV_PREFIX = "action_results"
LIVEACTION_RUNNABLE_STATES = [
LIVEACTION_STATUS_REQUESTED,
LIVEACTION_STATUS_SCHEDULED,
LIVEACTION_STATUS_PAUSING,
LIVEACTION_STATUS_PAUSED,
- LIVEACTION_STATUS_RESUMING
+ LIVEACTION_STATUS_RESUMING,
]
-LIVEACTION_DELAYED_STATES = [
- LIVEACTION_STATUS_DELAYED
-]
+LIVEACTION_DELAYED_STATES = [LIVEACTION_STATUS_DELAYED]
LIVEACTION_CANCELABLE_STATES = [
LIVEACTION_STATUS_REQUESTED,
@@ -111,7 +104,7 @@
LIVEACTION_STATUS_RUNNING,
LIVEACTION_STATUS_PAUSING,
LIVEACTION_STATUS_PAUSED,
- LIVEACTION_STATUS_RESUMING
+ LIVEACTION_STATUS_RESUMING,
]
LIVEACTION_COMPLETED_STATES = [
@@ -119,29 +112,20 @@
LIVEACTION_STATUS_FAILED,
LIVEACTION_STATUS_TIMED_OUT,
LIVEACTION_STATUS_CANCELED,
- LIVEACTION_STATUS_ABANDONED
+ LIVEACTION_STATUS_ABANDONED,
]
LIVEACTION_FAILED_STATES = [
LIVEACTION_STATUS_FAILED,
LIVEACTION_STATUS_TIMED_OUT,
- LIVEACTION_STATUS_ABANDONED
+ LIVEACTION_STATUS_ABANDONED,
]
-LIVEACTION_PAUSE_STATES = [
- LIVEACTION_STATUS_PAUSING,
- LIVEACTION_STATUS_PAUSED
-]
+LIVEACTION_PAUSE_STATES = [LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED]
-LIVEACTION_CANCEL_STATES = [
- LIVEACTION_STATUS_CANCELING,
- LIVEACTION_STATUS_CANCELED
-]
+LIVEACTION_CANCEL_STATES = [LIVEACTION_STATUS_CANCELING, LIVEACTION_STATUS_CANCELED]
-WORKFLOW_RUNNER_TYPES = [
- 'action-chain',
- 'orquesta'
-]
+WORKFLOW_RUNNER_TYPES = ["action-chain", "orquesta"]
# Linux's limit for param size
_LINUX_PARAM_LIMIT = 131072
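
The grouped status lists above let callers test life-cycle membership instead of comparing individual strings; config.py imports `LIVEACTION_COMPLETED_STATES` for the default `emit_when` value, for example. A trivial illustration of the intended usage:

    from st2common.constants.action import (
        LIVEACTION_COMPLETED_STATES,
        LIVEACTION_STATUS_RUNNING,
        LIVEACTION_STATUS_SUCCEEDED,
    )

    print(LIVEACTION_STATUS_SUCCEEDED in LIVEACTION_COMPLETED_STATES)  # True
    print(LIVEACTION_STATUS_RUNNING in LIVEACTION_COMPLETED_STATES)   # False
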
diff --git a/st2common/st2common/constants/api.py b/st2common/st2common/constants/api.py
index c1df81fb0d..2690133314 100644
--- a/st2common/st2common/constants/api.py
+++ b/st2common/st2common/constants/api.py
@@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'DEFAULT_API_VERSION'
-]
+__all__ = ["DEFAULT_API_VERSION"]
-DEFAULT_API_VERSION = 'v1'
+DEFAULT_API_VERSION = "v1"
-REQUEST_ID_HEADER = 'X-Request-ID'
+REQUEST_ID_HEADER = "X-Request-ID"
diff --git a/st2common/st2common/constants/auth.py b/st2common/st2common/constants/auth.py
index f0664739ce..7b4003c0ef 100644
--- a/st2common/st2common/constants/auth.py
+++ b/st2common/st2common/constants/auth.py
@@ -14,26 +14,22 @@
# limitations under the License.
__all__ = [
- 'VALID_MODES',
- 'DEFAULT_MODE',
- 'DEFAULT_BACKEND',
-
- 'HEADER_ATTRIBUTE_NAME',
- 'QUERY_PARAM_ATTRIBUTE_NAME'
+ "VALID_MODES",
+ "DEFAULT_MODE",
+ "DEFAULT_BACKEND",
+ "HEADER_ATTRIBUTE_NAME",
+ "QUERY_PARAM_ATTRIBUTE_NAME",
]
-VALID_MODES = [
- 'proxy',
- 'standalone'
-]
+VALID_MODES = ["proxy", "standalone"]
-HEADER_ATTRIBUTE_NAME = 'X-Auth-Token'
-QUERY_PARAM_ATTRIBUTE_NAME = 'x-auth-token'
+HEADER_ATTRIBUTE_NAME = "X-Auth-Token"
+QUERY_PARAM_ATTRIBUTE_NAME = "x-auth-token"
-HEADER_API_KEY_ATTRIBUTE_NAME = 'St2-Api-Key'
-QUERY_PARAM_API_KEY_ATTRIBUTE_NAME = 'st2-api-key'
+HEADER_API_KEY_ATTRIBUTE_NAME = "St2-Api-Key"
+QUERY_PARAM_API_KEY_ATTRIBUTE_NAME = "st2-api-key"
-DEFAULT_MODE = 'standalone'
+DEFAULT_MODE = "standalone"
-DEFAULT_BACKEND = 'flat_file'
-DEFAULT_SSO_BACKEND = 'noop'
+DEFAULT_BACKEND = "flat_file"
+DEFAULT_SSO_BACKEND = "noop"
diff --git a/st2common/st2common/constants/error_messages.py b/st2common/st2common/constants/error_messages.py
index 7aa56c4025..7c70377721 100644
--- a/st2common/st2common/constants/error_messages.py
+++ b/st2common/st2common/constants/error_messages.py
@@ -13,21 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'PACK_VIRTUALENV_DOESNT_EXIST',
- 'PYTHON2_DEPRECATION'
-]
+__all__ = ["PACK_VIRTUALENV_DOESNT_EXIST", "PYTHON2_DEPRECATION"]
-PACK_VIRTUALENV_DOESNT_EXIST = '''
+PACK_VIRTUALENV_DOESNT_EXIST = """
The virtual environment (%(virtualenv_path)s) for pack "%(pack)s" does not exist. Normally this is
created when you install a pack using "st2 pack install". If you installed your pack by some other
means, you can create a new virtual environment using the command:
"st2 run packs.setup_virtualenv packs=%(pack)s"
-'''
+"""
-PYTHON2_DEPRECATION = 'DEPRECATION WARNING. Support for python 2 will be removed in future ' \
- 'StackStorm releases. Please ensure that all packs used are python ' \
- '3 compatible. Your StackStorm installation may be upgraded from ' \
- 'python 2 to python 3 in future platform releases. It is recommended ' \
- 'to plan the manual migration to a python 3 native platform, e.g. ' \
- 'Ubuntu 18.04 LTS or CentOS/RHEL 8.'
+PYTHON2_DEPRECATION = (
+ "DEPRECATION WARNING. Support for python 2 will be removed in future "
+ "StackStorm releases. Please ensure that all packs used are python "
+ "3 compatible. Your StackStorm installation may be upgraded from "
+ "python 2 to python 3 in future platform releases. It is recommended "
+ "to plan the manual migration to a python 3 native platform, e.g. "
+ "Ubuntu 18.04 LTS or CentOS/RHEL 8."
+)
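
Both messages above are %-style templates. `PACK_VIRTUALENV_DOESNT_EXIST` expects a mapping with `virtualenv_path` and `pack` keys; a quick illustration with made-up values:

    from st2common.constants.error_messages import PACK_VIRTUALENV_DOESNT_EXIST

    print(
        PACK_VIRTUALENV_DOESNT_EXIST
        % {
            "virtualenv_path": "/opt/stackstorm/virtualenvs/libcloud",
            "pack": "libcloud",
        }
    )
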
diff --git a/st2common/st2common/constants/exit_codes.py b/st2common/st2common/constants/exit_codes.py
index 8fd1efd9a7..1b32e89e26 100644
--- a/st2common/st2common/constants/exit_codes.py
+++ b/st2common/st2common/constants/exit_codes.py
@@ -14,10 +14,10 @@
# limitations under the License.
__all__ = [
- 'SUCCESS_EXIT_CODE',
- 'FAILURE_EXIT_CODE',
- 'SIGKILL_EXIT_CODE',
- 'SIGTERM_EXIT_CODE'
+ "SUCCESS_EXIT_CODE",
+ "FAILURE_EXIT_CODE",
+ "SIGKILL_EXIT_CODE",
+ "SIGTERM_EXIT_CODE",
]
SUCCESS_EXIT_CODE = 0
diff --git a/st2common/st2common/constants/garbage_collection.py b/st2common/st2common/constants/garbage_collection.py
index dad3121896..ac8a2aac5f 100644
--- a/st2common/st2common/constants/garbage_collection.py
+++ b/st2common/st2common/constants/garbage_collection.py
@@ -14,10 +14,10 @@
# limitations under the License.
__all__ = [
- 'DEFAULT_COLLECTION_INTERVAL',
- 'DEFAULT_SLEEP_DELAY',
- 'MINIMUM_TTL_DAYS',
- 'MINIMUM_TTL_DAYS_EXECUTION_OUTPUT'
+ "DEFAULT_COLLECTION_INTERVAL",
+ "DEFAULT_SLEEP_DELAY",
+ "MINIMUM_TTL_DAYS",
+ "MINIMUM_TTL_DAYS_EXECUTION_OUTPUT",
]
diff --git a/st2common/st2common/constants/keyvalue.py b/st2common/st2common/constants/keyvalue.py
index 2897f1e32d..7a21eab8ec 100644
--- a/st2common/st2common/constants/keyvalue.py
+++ b/st2common/st2common/constants/keyvalue.py
@@ -14,46 +14,49 @@
# limitations under the License.
__all__ = [
- 'ALLOWED_SCOPES',
- 'SYSTEM_SCOPE',
- 'FULL_SYSTEM_SCOPE',
- 'SYSTEM_SCOPES',
- 'USER_SCOPE',
- 'FULL_USER_SCOPE',
- 'USER_SCOPES',
- 'USER_SEPARATOR',
-
- 'DATASTORE_SCOPE_SEPARATOR',
- 'DATASTORE_KEY_SEPARATOR'
+ "ALLOWED_SCOPES",
+ "SYSTEM_SCOPE",
+ "FULL_SYSTEM_SCOPE",
+ "SYSTEM_SCOPES",
+ "USER_SCOPE",
+ "FULL_USER_SCOPE",
+ "USER_SCOPES",
+ "USER_SEPARATOR",
+ "DATASTORE_SCOPE_SEPARATOR",
+ "DATASTORE_KEY_SEPARATOR",
]
-ALL_SCOPE = 'all'
+ALL_SCOPE = "all"
# Parent namespace for all items in key-value store
-DATASTORE_PARENT_SCOPE = 'st2kv'
-DATASTORE_SCOPE_SEPARATOR = '.' # To separate scope from datastore namespace. E.g. st2kv.system
+DATASTORE_PARENT_SCOPE = "st2kv"
+DATASTORE_SCOPE_SEPARATOR = (
+ "." # To separate scope from datastore namespace. E.g. st2kv.system
+)
# Namespace to contain all system/global scoped variables in key-value store.
-SYSTEM_SCOPE = 'system'
-FULL_SYSTEM_SCOPE = '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, SYSTEM_SCOPE)
+SYSTEM_SCOPE = "system"
+FULL_SYSTEM_SCOPE = "%s%s%s" % (
+ DATASTORE_PARENT_SCOPE,
+ DATASTORE_SCOPE_SEPARATOR,
+ SYSTEM_SCOPE,
+)
SYSTEM_SCOPES = [SYSTEM_SCOPE]
# Namespace to contain all user scoped variables in key-value store.
-USER_SCOPE = 'user'
-FULL_USER_SCOPE = '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, USER_SCOPE)
+USER_SCOPE = "user"
+FULL_USER_SCOPE = "%s%s%s" % (
+ DATASTORE_PARENT_SCOPE,
+ DATASTORE_SCOPE_SEPARATOR,
+ USER_SCOPE,
+)
USER_SCOPES = [USER_SCOPE]
-USER_SEPARATOR = ':'
+USER_SEPARATOR = ":"
# Separator for keys in the datastore
-DATASTORE_KEY_SEPARATOR = ':'
-
-ALLOWED_SCOPES = [
- SYSTEM_SCOPE,
- USER_SCOPE,
+DATASTORE_KEY_SEPARATOR = ":"
- FULL_SYSTEM_SCOPE,
- FULL_USER_SCOPE
-]
+ALLOWED_SCOPES = [SYSTEM_SCOPE, USER_SCOPE, FULL_SYSTEM_SCOPE, FULL_USER_SCOPE]
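
Putting the separators together: `FULL_SYSTEM_SCOPE` evaluates to `st2kv.system`, and a fully qualified datastore key joins that scope to the key name with `DATASTORE_KEY_SEPARATOR`. A small sketch with a hypothetical key name:

    from st2common.constants.keyvalue import (
        DATASTORE_KEY_SEPARATOR,
        FULL_SYSTEM_SCOPE,
    )

    key_name = "api_token"  # hypothetical datastore key
    full_key = FULL_SYSTEM_SCOPE + DATASTORE_KEY_SEPARATOR + key_name
    print(full_key)  # st2kv.system:api_token
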
diff --git a/st2common/st2common/constants/logging.py b/st2common/st2common/constants/logging.py
index b62a59bd00..0985a03947 100644
--- a/st2common/st2common/constants/logging.py
+++ b/st2common/st2common/constants/logging.py
@@ -16,11 +16,9 @@
from __future__ import absolute_import
import os
-__all__ = [
- 'DEFAULT_LOGGING_CONF_PATH'
-]
+__all__ = ["DEFAULT_LOGGING_CONF_PATH"]
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
-DEFAULT_LOGGING_CONF_PATH = os.path.join(BASE_PATH, '../conf/base.logging.conf')
+DEFAULT_LOGGING_CONF_PATH = os.path.join(BASE_PATH, "../conf/base.logging.conf")
DEFAULT_LOGGING_CONF_PATH = os.path.abspath(DEFAULT_LOGGING_CONF_PATH)
diff --git a/st2common/st2common/constants/meta.py b/st2common/st2common/constants/meta.py
index ac4859b5e1..acd348a355 100644
--- a/st2common/st2common/constants/meta.py
+++ b/st2common/st2common/constants/meta.py
@@ -16,10 +16,7 @@
from __future__ import absolute_import
import yaml
-__all__ = [
- 'ALLOWED_EXTS',
- 'PARSER_FUNCS'
-]
+__all__ = ["ALLOWED_EXTS", "PARSER_FUNCS"]
-ALLOWED_EXTS = ['.yaml', '.yml']
-PARSER_FUNCS = {'.yml': yaml.safe_load, '.yaml': yaml.safe_load}
+ALLOWED_EXTS = [".yaml", ".yml"]
+PARSER_FUNCS = {".yml": yaml.safe_load, ".yaml": yaml.safe_load}
diff --git a/st2common/st2common/constants/pack.py b/st2common/st2common/constants/pack.py
index 91ae5a5e2c..f782a6920c 100644
--- a/st2common/st2common/constants/pack.py
+++ b/st2common/st2common/constants/pack.py
@@ -14,81 +14,74 @@
# limitations under the License.
__all__ = [
- 'PACKS_PACK_NAME',
- 'PACK_REF_WHITELIST_REGEX',
- 'PACK_RESERVED_CHARACTERS',
- 'PACK_VERSION_SEPARATOR',
- 'PACK_VERSION_REGEX',
- 'ST2_VERSION_REGEX',
- 'SYSTEM_PACK_NAME',
- 'PACKS_PACK_NAME',
- 'LINUX_PACK_NAME',
- 'SYSTEM_PACK_NAMES',
- 'CHATOPS_PACK_NAME',
- 'USER_PACK_NAME_BLACKLIST',
- 'BASE_PACK_REQUIREMENTS',
- 'MANIFEST_FILE_NAME',
- 'CONFIG_SCHEMA_FILE_NAME'
+ "PACKS_PACK_NAME",
+ "PACK_REF_WHITELIST_REGEX",
+ "PACK_RESERVED_CHARACTERS",
+ "PACK_VERSION_SEPARATOR",
+ "PACK_VERSION_REGEX",
+ "ST2_VERSION_REGEX",
+ "SYSTEM_PACK_NAME",
+ "PACKS_PACK_NAME",
+ "LINUX_PACK_NAME",
+ "SYSTEM_PACK_NAMES",
+ "CHATOPS_PACK_NAME",
+ "USER_PACK_NAME_BLACKLIST",
+ "BASE_PACK_REQUIREMENTS",
+ "MANIFEST_FILE_NAME",
+ "CONFIG_SCHEMA_FILE_NAME",
]
# Prefix for render context w/ config
-PACK_CONFIG_CONTEXT_KV_PREFIX = 'config_context'
+PACK_CONFIG_CONTEXT_KV_PREFIX = "config_context"
# A list of allowed characters for the pack name
-PACK_REF_WHITELIST_REGEX = r'^[a-z0-9_]+$'
+PACK_REF_WHITELIST_REGEX = r"^[a-z0-9_]+$"
# Check for a valid semver string
-PACK_VERSION_REGEX = r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(?:-[\da-z\-]+(?:\.[\da-z\-]+)*)?(?:\+[\da-z\-]+(?:\.[\da-z\-]+)*)?$' # noqa
+PACK_VERSION_REGEX = r"^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(?:-[\da-z\-]+(?:\.[\da-z\-]+)*)?(?:\+[\da-z\-]+(?:\.[\da-z\-]+)*)?$" # noqa
# Special characters which can't be used in pack names
-PACK_RESERVED_CHARACTERS = [
- '.'
-]
+PACK_RESERVED_CHARACTERS = ["."]
# Version sperator when version is supplied in pack name
# Example: libcloud@1.0.1
-PACK_VERSION_SEPARATOR = '='
+PACK_VERSION_SEPARATOR = "="
# Check for st2 version in engines
-ST2_VERSION_REGEX = r'^((>?>|>=|=|<=|<)\s*[0-9]+\.[0-9]+\.[0-9]+?(\s*,)?\s*)+$'
+ST2_VERSION_REGEX = r"^((>?>|>=|=|<=|<)\s*[0-9]+\.[0-9]+\.[0-9]+?(\s*,)?\s*)+$"
# Name used for system pack
-SYSTEM_PACK_NAME = 'core'
+SYSTEM_PACK_NAME = "core"
# Name used for pack management pack
-PACKS_PACK_NAME = 'packs'
+PACKS_PACK_NAME = "packs"
# Name used for linux pack
-LINUX_PACK_NAME = 'linux'
+LINUX_PACK_NAME = "linux"
# Name of the default pack
-DEFAULT_PACK_NAME = 'default'
+DEFAULT_PACK_NAME = "default"
# Name of the chatops pack
-CHATOPS_PACK_NAME = 'chatops'
+CHATOPS_PACK_NAME = "chatops"
# A list of system pack names
SYSTEM_PACK_NAMES = [
CHATOPS_PACK_NAME,
SYSTEM_PACK_NAME,
PACKS_PACK_NAME,
- LINUX_PACK_NAME
+ LINUX_PACK_NAME,
]
# A list of pack names which can't be used by user-supplied packs
-USER_PACK_NAME_BLACKLIST = [
- SYSTEM_PACK_NAME,
- PACKS_PACK_NAME
-]
+USER_PACK_NAME_BLACKLIST = [SYSTEM_PACK_NAME, PACKS_PACK_NAME]
# Python requirements which are common to all the packs and are installed into the Python pack
# sandbox (virtualenv)
-BASE_PACK_REQUIREMENTS = [
- 'six>=1.9.0,<2.0'
-]
+BASE_PACK_REQUIREMENTS = ["six>=1.9.0,<2.0"]
# Name of the pack manifest file
-MANIFEST_FILE_NAME = 'pack.yaml'
+MANIFEST_FILE_NAME = "pack.yaml"
# File name for the config schema file
-CONFIG_SCHEMA_FILE_NAME = 'config.schema.yaml'
+CONFIG_SCHEMA_FILE_NAME = "config.schema.yaml"
diff --git a/st2common/st2common/constants/policy.py b/st2common/st2common/constants/policy.py
index e36ce8fc12..7ce7093ed5 100644
--- a/st2common/st2common/constants/policy.py
+++ b/st2common/st2common/constants/policy.py
@@ -13,13 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'POLICY_TYPES_REQUIRING_LOCK'
-]
+__all__ = ["POLICY_TYPES_REQUIRING_LOCK"]
# Concurrency policies require scheduler to acquire a distributed lock to prevent race
# in scheduling when there are multiple scheduler instances.
-POLICY_TYPES_REQUIRING_LOCK = [
- 'action.concurrency',
- 'action.concurrency.attr'
-]
+POLICY_TYPES_REQUIRING_LOCK = ["action.concurrency", "action.concurrency.attr"]
diff --git a/st2common/st2common/constants/rule_enforcement.py b/st2common/st2common/constants/rule_enforcement.py
index fced450304..ceece2d6e1 100644
--- a/st2common/st2common/constants/rule_enforcement.py
+++ b/st2common/st2common/constants/rule_enforcement.py
@@ -14,16 +14,15 @@
# limitations under the License.
__all__ = [
- 'RULE_ENFORCEMENT_STATUS_SUCCEEDED',
- 'RULE_ENFORCEMENT_STATUS_FAILED',
-
- 'RULE_ENFORCEMENT_STATUSES'
+ "RULE_ENFORCEMENT_STATUS_SUCCEEDED",
+ "RULE_ENFORCEMENT_STATUS_FAILED",
+ "RULE_ENFORCEMENT_STATUSES",
]
-RULE_ENFORCEMENT_STATUS_SUCCEEDED = 'succeeded'
-RULE_ENFORCEMENT_STATUS_FAILED = 'failed'
+RULE_ENFORCEMENT_STATUS_SUCCEEDED = "succeeded"
+RULE_ENFORCEMENT_STATUS_FAILED = "failed"
RULE_ENFORCEMENT_STATUSES = [
RULE_ENFORCEMENT_STATUS_SUCCEEDED,
- RULE_ENFORCEMENT_STATUS_FAILED
+ RULE_ENFORCEMENT_STATUS_FAILED,
]
diff --git a/st2common/st2common/constants/rules.py b/st2common/st2common/constants/rules.py
index 393e94aebb..929e4b5e92 100644
--- a/st2common/st2common/constants/rules.py
+++ b/st2common/st2common/constants/rules.py
@@ -13,10 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-TRIGGER_PAYLOAD_PREFIX = 'trigger'
-TRIGGER_ITEM_PAYLOAD_PREFIX = 'item'
+TRIGGER_PAYLOAD_PREFIX = "trigger"
+TRIGGER_ITEM_PAYLOAD_PREFIX = "item"
-RULE_TYPE_STANDARD = 'standard'
-RULE_TYPE_BACKSTOP = 'backstop'
+RULE_TYPE_STANDARD = "standard"
+RULE_TYPE_BACKSTOP = "backstop"
-MATCH_CRITERIA = r'({{)\s*(.*)\s*(}})'
+MATCH_CRITERIA = r"({{)\s*(.*)\s*(}})"
diff --git a/st2common/st2common/constants/runners.py b/st2common/st2common/constants/runners.py
index fe78a6497f..52ec738384 100644
--- a/st2common/st2common/constants/runners.py
+++ b/st2common/st2common/constants/runners.py
@@ -17,36 +17,28 @@
from oslo_config import cfg
__all__ = [
- 'RUNNER_NAME_WHITELIST',
-
- 'MANIFEST_FILE_NAME',
-
- 'LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT',
-
- 'REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT',
- 'REMOTE_RUNNER_DEFAULT_REMOTE_DIR',
- 'REMOTE_RUNNER_PRIVATE_KEY_HEADER',
-
- 'PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT',
- 'PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE',
-
- 'WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT',
-
- 'COMMON_ACTION_ENV_VARIABLE_PREFIX',
- 'COMMON_ACTION_ENV_VARIABLES',
-
- 'DEFAULT_SSH_PORT',
-
- 'RUNNERS_NAMESPACE'
+ "RUNNER_NAME_WHITELIST",
+ "MANIFEST_FILE_NAME",
+ "LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT",
+ "REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT",
+ "REMOTE_RUNNER_DEFAULT_REMOTE_DIR",
+ "REMOTE_RUNNER_PRIVATE_KEY_HEADER",
+ "PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT",
+ "PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE",
+ "WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT",
+ "COMMON_ACTION_ENV_VARIABLE_PREFIX",
+ "COMMON_ACTION_ENV_VARIABLES",
+ "DEFAULT_SSH_PORT",
+ "RUNNERS_NAMESPACE",
]
DEFAULT_SSH_PORT = 22
# A list of allowed characters for the runner name
-RUNNER_NAME_WHITELIST = r'^[A-Za-z0-9_-]+'
+RUNNER_NAME_WHITELIST = r"^[A-Za-z0-9_-]+"
# Manifest file name for runners
-MANIFEST_FILE_NAME = 'runner.yaml'
+MANIFEST_FILE_NAME = "runner.yaml"
# Local runner
LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT = 60
@@ -57,9 +49,9 @@
try:
REMOTE_RUNNER_DEFAULT_REMOTE_DIR = cfg.CONF.ssh_runner.remote_dir
except:
- REMOTE_RUNNER_DEFAULT_REMOTE_DIR = '/tmp'
+ REMOTE_RUNNER_DEFAULT_REMOTE_DIR = "/tmp"
-REMOTE_RUNNER_PRIVATE_KEY_HEADER = 'PRIVATE KEY-----'.lower()
+REMOTE_RUNNER_PRIVATE_KEY_HEADER = "PRIVATE KEY-----".lower()
# Python runner
# Default timeout (in seconds) for actions executed by Python runner
@@ -69,20 +61,20 @@
# action returns invalid status from the run() method
PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE = 220
-PYTHON_RUNNER_DEFAULT_LOG_LEVEL = 'DEBUG'
+PYTHON_RUNNER_DEFAULT_LOG_LEVEL = "DEBUG"
# Windows runner
WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT = 10 * 60
# Prefix for common st2 environment variables which are available to the actions
-COMMON_ACTION_ENV_VARIABLE_PREFIX = 'ST2_ACTION_'
+COMMON_ACTION_ENV_VARIABLE_PREFIX = "ST2_ACTION_"
# Common st2 environment variables which are available to the actions
COMMON_ACTION_ENV_VARIABLES = [
- 'ST2_ACTION_PACK_NAME',
- 'ST2_ACTION_EXECUTION_ID',
- 'ST2_ACTION_API_URL',
- 'ST2_ACTION_AUTH_TOKEN'
+ "ST2_ACTION_PACK_NAME",
+ "ST2_ACTION_EXECUTION_ID",
+ "ST2_ACTION_API_URL",
+ "ST2_ACTION_AUTH_TOKEN",
]
# Namespaces for dynamically loaded runner modules
diff --git a/st2common/st2common/constants/scheduler.py b/st2common/st2common/constants/scheduler.py
index d825d2aed0..fb97971a3c 100644
--- a/st2common/st2common/constants/scheduler.py
+++ b/st2common/st2common/constants/scheduler.py
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'SCHEDULER_ENABLED_LOG_LINE',
- 'SCHEDULER_DISABLED_LOG_LINE'
-]
+__all__ = ["SCHEDULER_ENABLED_LOG_LINE", "SCHEDULER_DISABLED_LOG_LINE"]
# Integration tests look for these loglines to validate scheduler enable/disable
-SCHEDULER_ENABLED_LOG_LINE = 'Scheduler is enabled.'
-SCHEDULER_DISABLED_LOG_LINE = 'Scheduler is disabled.'
+SCHEDULER_ENABLED_LOG_LINE = "Scheduler is enabled."
+SCHEDULER_DISABLED_LOG_LINE = "Scheduler is disabled."
diff --git a/st2common/st2common/constants/secrets.py b/st2common/st2common/constants/secrets.py
index d3f9e53b9e..ef9a02d5ee 100644
--- a/st2common/st2common/constants/secrets.py
+++ b/st2common/st2common/constants/secrets.py
@@ -13,22 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'MASKED_ATTRIBUTES_BLACKLIST',
- 'MASKED_ATTRIBUTE_VALUE'
-]
+__all__ = ["MASKED_ATTRIBUTES_BLACKLIST", "MASKED_ATTRIBUTE_VALUE"]
# A blacklist of attributes which should be masked in the log messages by default.
# Note: If an attribute is an object or a dict, we try to recursively process it and mask the
# values.
MASKED_ATTRIBUTES_BLACKLIST = [
- 'password',
- 'auth_token',
- 'token',
- 'secret',
- 'credentials',
- 'st2_auth_token'
+ "password",
+ "auth_token",
+ "token",
+ "secret",
+ "credentials",
+ "st2_auth_token",
]
# Value with which the masked attribute values are replaced
-MASKED_ATTRIBUTE_VALUE = '********'
+MASKED_ATTRIBUTE_VALUE = "********"
diff --git a/st2common/st2common/constants/sensors.py b/st2common/st2common/constants/sensors.py
index 3ba4f9487d..a2d7903d18 100644
--- a/st2common/st2common/constants/sensors.py
+++ b/st2common/st2common/constants/sensors.py
@@ -17,7 +17,7 @@
MINIMUM_POLL_INTERVAL = 4
# keys for PARTITION loaders
-DEFAULT_PARTITION_LOADER = 'default'
-KVSTORE_PARTITION_LOADER = 'kvstore'
-FILE_PARTITION_LOADER = 'file'
-HASH_PARTITION_LOADER = 'hash'
+DEFAULT_PARTITION_LOADER = "default"
+KVSTORE_PARTITION_LOADER = "kvstore"
+FILE_PARTITION_LOADER = "file"
+HASH_PARTITION_LOADER = "hash"
diff --git a/st2common/st2common/constants/system.py b/st2common/st2common/constants/system.py
index dcb8ee699c..9736527171 100644
--- a/st2common/st2common/constants/system.py
+++ b/st2common/st2common/constants/system.py
@@ -20,15 +20,14 @@
from st2common import __version__
__all__ = [
- 'VERSION_STRING',
- 'DEFAULT_CONFIG_FILE_PATH',
-
- 'API_URL_ENV_VARIABLE_NAME',
- 'AUTH_TOKEN_ENV_VARIABLE_NAME',
+ "VERSION_STRING",
+ "DEFAULT_CONFIG_FILE_PATH",
+ "API_URL_ENV_VARIABLE_NAME",
+ "AUTH_TOKEN_ENV_VARIABLE_NAME",
]
-VERSION_STRING = 'StackStorm v%s' % (__version__)
-DEFAULT_CONFIG_FILE_PATH = os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')
+VERSION_STRING = "StackStorm v%s" % (__version__)
+DEFAULT_CONFIG_FILE_PATH = os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf")
-API_URL_ENV_VARIABLE_NAME = 'ST2_API_URL'
-AUTH_TOKEN_ENV_VARIABLE_NAME = 'ST2_AUTH_TOKEN'
+API_URL_ENV_VARIABLE_NAME = "ST2_API_URL"
+AUTH_TOKEN_ENV_VARIABLE_NAME = "ST2_AUTH_TOKEN"
diff --git a/st2common/st2common/constants/timer.py b/st2common/st2common/constants/timer.py
index 0f191a8027..9772743792 100644
--- a/st2common/st2common/constants/timer.py
+++ b/st2common/st2common/constants/timer.py
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'TIMER_ENABLED_LOG_LINE',
- 'TIMER_DISABLED_LOG_LINE'
-]
+__all__ = ["TIMER_ENABLED_LOG_LINE", "TIMER_DISABLED_LOG_LINE"]
# Integration tests look for these loglines to validate timer enable/disable
-TIMER_ENABLED_LOG_LINE = 'Timer is enabled.'
-TIMER_DISABLED_LOG_LINE = 'Timer is disabled.'
+TIMER_ENABLED_LOG_LINE = "Timer is enabled."
+TIMER_DISABLED_LOG_LINE = "Timer is disabled."
diff --git a/st2common/st2common/constants/trace.py b/st2common/st2common/constants/trace.py
index d900912c60..f7e4242da1 100644
--- a/st2common/st2common/constants/trace.py
+++ b/st2common/st2common/constants/trace.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = ['TRACE_CONTEXT', 'TRACE_ID']
+__all__ = ["TRACE_CONTEXT", "TRACE_ID"]
-TRACE_CONTEXT = 'trace_context'
-TRACE_ID = 'trace_tag'
+TRACE_CONTEXT = "trace_context"
+TRACE_ID = "trace_tag"
diff --git a/st2common/st2common/constants/triggers.py b/st2common/st2common/constants/triggers.py
index 4a0ccc8e4e..14ab861fd5 100644
--- a/st2common/st2common/constants/triggers.py
+++ b/st2common/st2common/constants/triggers.py
@@ -18,244 +18,200 @@
from st2common.models.system.common import ResourceReference
__all__ = [
- 'WEBHOOKS_PARAMETERS_SCHEMA',
- 'WEBHOOKS_PAYLOAD_SCHEMA',
- 'INTERVAL_PARAMETERS_SCHEMA',
- 'DATE_PARAMETERS_SCHEMA',
- 'CRON_PARAMETERS_SCHEMA',
- 'TIMER_PAYLOAD_SCHEMA',
-
- 'ACTION_SENSOR_TRIGGER',
- 'NOTIFY_TRIGGER',
- 'ACTION_FILE_WRITTEN_TRIGGER',
- 'INQUIRY_TRIGGER',
-
- 'TIMER_TRIGGER_TYPES',
- 'WEBHOOK_TRIGGER_TYPES',
- 'WEBHOOK_TRIGGER_TYPE',
- 'INTERNAL_TRIGGER_TYPES',
- 'SYSTEM_TRIGGER_TYPES',
-
- 'INTERVAL_TIMER_TRIGGER_REF',
- 'DATE_TIMER_TRIGGER_REF',
- 'CRON_TIMER_TRIGGER_REF',
-
- 'TRIGGER_INSTANCE_STATUSES',
- 'TRIGGER_INSTANCE_PENDING',
- 'TRIGGER_INSTANCE_PROCESSING',
- 'TRIGGER_INSTANCE_PROCESSED',
- 'TRIGGER_INSTANCE_PROCESSING_FAILED'
+ "WEBHOOKS_PARAMETERS_SCHEMA",
+ "WEBHOOKS_PAYLOAD_SCHEMA",
+ "INTERVAL_PARAMETERS_SCHEMA",
+ "DATE_PARAMETERS_SCHEMA",
+ "CRON_PARAMETERS_SCHEMA",
+ "TIMER_PAYLOAD_SCHEMA",
+ "ACTION_SENSOR_TRIGGER",
+ "NOTIFY_TRIGGER",
+ "ACTION_FILE_WRITTEN_TRIGGER",
+ "INQUIRY_TRIGGER",
+ "TIMER_TRIGGER_TYPES",
+ "WEBHOOK_TRIGGER_TYPES",
+ "WEBHOOK_TRIGGER_TYPE",
+ "INTERNAL_TRIGGER_TYPES",
+ "SYSTEM_TRIGGER_TYPES",
+ "INTERVAL_TIMER_TRIGGER_REF",
+ "DATE_TIMER_TRIGGER_REF",
+ "CRON_TIMER_TRIGGER_REF",
+ "TRIGGER_INSTANCE_STATUSES",
+ "TRIGGER_INSTANCE_PENDING",
+ "TRIGGER_INSTANCE_PROCESSING",
+ "TRIGGER_INSTANCE_PROCESSED",
+ "TRIGGER_INSTANCE_PROCESSING_FAILED",
]
# Action resource triggers
ACTION_SENSOR_TRIGGER = {
- 'name': 'st2.generic.actiontrigger',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger encapsulating the completion of an action execution.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'execution_id': {},
- 'status': {},
- 'start_timestamp': {},
- 'action_name': {},
- 'action_ref': {},
- 'runner_ref': {},
- 'parameters': {},
- 'result': {}
- }
- }
+ "name": "st2.generic.actiontrigger",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Trigger encapsulating the completion of an action execution.",
+ "payload_schema": {
+ "type": "object",
+ "properties": {
+ "execution_id": {},
+ "status": {},
+ "start_timestamp": {},
+ "action_name": {},
+ "action_ref": {},
+ "runner_ref": {},
+ "parameters": {},
+ "result": {},
+ },
+ },
}
ACTION_FILE_WRITTEN_TRIGGER = {
- 'name': 'st2.action.file_written',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger encapsulating action file being written on disk.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'ref': {},
- 'file_path': {},
- 'host_info': {}
- }
- }
+ "name": "st2.action.file_written",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Trigger encapsulating action file being written on disk.",
+ "payload_schema": {
+ "type": "object",
+ "properties": {"ref": {}, "file_path": {}, "host_info": {}},
+ },
}
NOTIFY_TRIGGER = {
- 'name': 'st2.generic.notifytrigger',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Notification trigger.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'execution_id': {},
- 'status': {},
- 'start_timestamp': {},
- 'end_timestamp': {},
- 'action_ref': {},
- 'runner_ref': {},
- 'channel': {},
- 'route': {},
- 'message': {},
- 'data': {}
- }
- }
+ "name": "st2.generic.notifytrigger",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Notification trigger.",
+ "payload_schema": {
+ "type": "object",
+ "properties": {
+ "execution_id": {},
+ "status": {},
+ "start_timestamp": {},
+ "end_timestamp": {},
+ "action_ref": {},
+ "runner_ref": {},
+ "channel": {},
+ "route": {},
+ "message": {},
+ "data": {},
+ },
+ },
}
INQUIRY_TRIGGER = {
- 'name': 'st2.generic.inquiry',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger indicating a new "inquiry" has entered "pending" status',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'description': 'ID of the new inquiry.',
- 'required': True
+ "name": "st2.generic.inquiry",
+ "pack": SYSTEM_PACK_NAME,
+ "description": 'Trigger indicating a new "inquiry" has entered "pending" status',
+ "payload_schema": {
+ "type": "object",
+ "properties": {
+ "id": {
+ "type": "string",
+ "description": "ID of the new inquiry.",
+ "required": True,
+ },
+ "route": {
+ "type": "string",
+ "description": "An arbitrary value for allowing rules "
+ "to route to proper notification channel.",
+ "required": True,
},
- 'route': {
- 'type': 'string',
- 'description': 'An arbitrary value for allowing rules '
- 'to route to proper notification channel.',
- 'required': True
- }
},
- "additionalProperties": False
- }
+ "additionalProperties": False,
+ },
}
# Sensor spawn/exit triggers.
SENSOR_SPAWN_TRIGGER = {
- 'name': 'st2.sensor.process_spawn',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger indicating sensor process is started up.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'object': {}
- }
- }
+ "name": "st2.sensor.process_spawn",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Trigger indicating sensor process is started up.",
+ "payload_schema": {"type": "object", "properties": {"object": {}}},
}
SENSOR_EXIT_TRIGGER = {
- 'name': 'st2.sensor.process_exit',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger indicating sensor process is stopped.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'object': {}
- }
- }
+ "name": "st2.sensor.process_exit",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Trigger indicating sensor process is stopped.",
+ "payload_schema": {"type": "object", "properties": {"object": {}}},
}
# KeyValuePair resource triggers
KEY_VALUE_PAIR_CREATE_TRIGGER = {
- 'name': 'st2.key_value_pair.create',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger encapsulating datastore item creation.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'object': {}
- }
- }
+ "name": "st2.key_value_pair.create",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Trigger encapsulating datastore item creation.",
+ "payload_schema": {"type": "object", "properties": {"object": {}}},
}
KEY_VALUE_PAIR_UPDATE_TRIGGER = {
- 'name': 'st2.key_value_pair.update',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger encapsulating datastore set action.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'object': {}
- }
- }
+ "name": "st2.key_value_pair.update",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Trigger encapsulating datastore set action.",
+ "payload_schema": {"type": "object", "properties": {"object": {}}},
}
KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER = {
- 'name': 'st2.key_value_pair.value_change',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger encapsulating a change of datastore item value.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'old_object': {},
- 'new_object': {}
- }
- }
+ "name": "st2.key_value_pair.value_change",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Trigger encapsulating a change of datastore item value.",
+ "payload_schema": {
+ "type": "object",
+ "properties": {"old_object": {}, "new_object": {}},
+ },
}
KEY_VALUE_PAIR_DELETE_TRIGGER = {
- 'name': 'st2.key_value_pair.delete',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Trigger encapsulating datastore item deletion.',
- 'payload_schema': {
- 'type': 'object',
- 'properties': {
- 'object': {}
- }
- }
+ "name": "st2.key_value_pair.delete",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Trigger encapsulating datastore item deletion.",
+ "payload_schema": {"type": "object", "properties": {"object": {}}},
}
# Internal system triggers which are available for each resource
INTERNAL_TRIGGER_TYPES = {
- 'action': [
+ "action": [
ACTION_SENSOR_TRIGGER,
NOTIFY_TRIGGER,
ACTION_FILE_WRITTEN_TRIGGER,
- INQUIRY_TRIGGER
- ],
- 'sensor': [
- SENSOR_SPAWN_TRIGGER,
- SENSOR_EXIT_TRIGGER
+ INQUIRY_TRIGGER,
],
- 'key_value_pair': [
+ "sensor": [SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER],
+ "key_value_pair": [
KEY_VALUE_PAIR_CREATE_TRIGGER,
KEY_VALUE_PAIR_UPDATE_TRIGGER,
KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER,
- KEY_VALUE_PAIR_DELETE_TRIGGER
- ]
+ KEY_VALUE_PAIR_DELETE_TRIGGER,
+ ],
}
WEBHOOKS_PARAMETERS_SCHEMA = {
- 'type': 'object',
- 'properties': {
- 'url': {
- 'type': 'string',
- 'required': True
- }
- },
- 'additionalProperties': False
+ "type": "object",
+ "properties": {"url": {"type": "string", "required": True}},
+ "additionalProperties": False,
}
WEBHOOKS_PAYLOAD_SCHEMA = {
- 'type': 'object',
- 'properties': {
- 'headers': {
- 'type': 'object'
- },
- 'body': {
- 'anyOf': [
- {'type': 'array'},
- {'type': 'object'},
+ "type": "object",
+ "properties": {
+ "headers": {"type": "object"},
+ "body": {
+ "anyOf": [
+ {"type": "array"},
+ {"type": "object"},
]
- }
- }
+ },
+ },
}
WEBHOOK_TRIGGER_TYPES = {
- ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.webhook'): {
- 'name': 'st2.webhook',
- 'pack': SYSTEM_PACK_NAME,
- 'description': ('Trigger type for registering webhooks that can consume'
- ' arbitrary payload.'),
- 'parameters_schema': WEBHOOKS_PARAMETERS_SCHEMA,
- 'payload_schema': WEBHOOKS_PAYLOAD_SCHEMA
+ ResourceReference.to_string_reference(SYSTEM_PACK_NAME, "st2.webhook"): {
+ "name": "st2.webhook",
+ "pack": SYSTEM_PACK_NAME,
+ "description": (
+ "Trigger type for registering webhooks that can consume"
+ " arbitrary payload."
+ ),
+ "parameters_schema": WEBHOOKS_PARAMETERS_SCHEMA,
+ "payload_schema": WEBHOOKS_PAYLOAD_SCHEMA,
}
}
WEBHOOK_TRIGGER_TYPE = list(WEBHOOK_TRIGGER_TYPES.keys())[0]
@@ -265,107 +221,69 @@
INTERVAL_PARAMETERS_SCHEMA = {
"type": "object",
"properties": {
- "timezone": {
- "type": "string"
- },
+ "timezone": {"type": "string"},
"unit": {
"enum": ["weeks", "days", "hours", "minutes", "seconds"],
- "required": True
+ "required": True,
},
- "delta": {
- "type": "integer",
- "required": True
-
- }
+ "delta": {"type": "integer", "required": True},
},
- "additionalProperties": False
+ "additionalProperties": False,
}
DATE_PARAMETERS_SCHEMA = {
"type": "object",
"properties": {
- "timezone": {
- "type": "string"
- },
- "date": {
- "type": "string",
- "format": "date-time",
- "required": True
- }
+ "timezone": {"type": "string"},
+ "date": {"type": "string", "format": "date-time", "required": True},
},
- "additionalProperties": False
+ "additionalProperties": False,
}
CRON_PARAMETERS_SCHEMA = {
"type": "object",
"properties": {
- "timezone": {
- "type": "string"
- },
+ "timezone": {"type": "string"},
"year": {
- "anyOf": [
- {"type": "string"},
- {"type": "integer"}
- ],
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
},
"month": {
- "anyOf": [
- {"type": "string"},
- {"type": "integer"}
- ],
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
"minimum": 1,
- "maximum": 12
+ "maximum": 12,
},
"day": {
- "anyOf": [
- {"type": "string"},
- {"type": "integer"}
- ],
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
"minimum": 1,
- "maximum": 31
+ "maximum": 31,
},
"week": {
- "anyOf": [
- {"type": "string"},
- {"type": "integer"}
- ],
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
"minimum": 1,
- "maximum": 53
+ "maximum": 53,
},
"day_of_week": {
- "anyOf": [
- {"type": "string"},
- {"type": "integer"}
- ],
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
"minimum": 0,
- "maximum": 6
+ "maximum": 6,
},
"hour": {
- "anyOf": [
- {"type": "string"},
- {"type": "integer"}
- ],
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
"minimum": 0,
- "maximum": 23
+ "maximum": 23,
},
"minute": {
- "anyOf": [
- {"type": "string"},
- {"type": "integer"}
- ],
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
"minimum": 0,
- "maximum": 59
+ "maximum": 59,
},
"second": {
- "anyOf": [
- {"type": "string"},
- {"type": "integer"}
- ],
+ "anyOf": [{"type": "string"}, {"type": "integer"}],
"minimum": 0,
- "maximum": 59
- }
+ "maximum": 59,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
TIMER_PAYLOAD_SCHEMA = {
@@ -374,61 +292,62 @@
"executed_at": {
"type": "string",
"format": "date-time",
- "default": "2014-07-30 05:04:24.578325"
+ "default": "2014-07-30 05:04:24.578325",
},
- "schedule": {
- "type": "object",
- "default": {
- "delta": 30,
- "units": "seconds"
- }
- }
- }
+ "schedule": {"type": "object", "default": {"delta": 30, "units": "seconds"}},
+ },
}
-INTERVAL_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME,
- 'st2.IntervalTimer')
-DATE_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.DateTimer')
-CRON_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.CronTimer')
+INTERVAL_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(
+ SYSTEM_PACK_NAME, "st2.IntervalTimer"
+)
+DATE_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(
+ SYSTEM_PACK_NAME, "st2.DateTimer"
+)
+CRON_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(
+ SYSTEM_PACK_NAME, "st2.CronTimer"
+)
TIMER_TRIGGER_TYPES = {
INTERVAL_TIMER_TRIGGER_REF: {
- 'name': 'st2.IntervalTimer',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Triggers on specified intervals. e.g. every 30s, 1week etc.',
- 'payload_schema': TIMER_PAYLOAD_SCHEMA,
- 'parameters_schema': INTERVAL_PARAMETERS_SCHEMA
+ "name": "st2.IntervalTimer",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Triggers on specified intervals. e.g. every 30s, 1week etc.",
+ "payload_schema": TIMER_PAYLOAD_SCHEMA,
+ "parameters_schema": INTERVAL_PARAMETERS_SCHEMA,
},
DATE_TIMER_TRIGGER_REF: {
- 'name': 'st2.DateTimer',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Triggers exactly once when the current time matches the specified time. '
- 'e.g. timezone:UTC date:2014-12-31 23:59:59.',
- 'payload_schema': TIMER_PAYLOAD_SCHEMA,
- 'parameters_schema': DATE_PARAMETERS_SCHEMA
+ "name": "st2.DateTimer",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Triggers exactly once when the current time matches the specified time. "
+ "e.g. timezone:UTC date:2014-12-31 23:59:59.",
+ "payload_schema": TIMER_PAYLOAD_SCHEMA,
+ "parameters_schema": DATE_PARAMETERS_SCHEMA,
},
CRON_TIMER_TRIGGER_REF: {
- 'name': 'st2.CronTimer',
- 'pack': SYSTEM_PACK_NAME,
- 'description': 'Triggers whenever current time matches the specified time constaints like '
- 'a UNIX cron scheduler.',
- 'payload_schema': TIMER_PAYLOAD_SCHEMA,
- 'parameters_schema': CRON_PARAMETERS_SCHEMA
- }
+ "name": "st2.CronTimer",
+ "pack": SYSTEM_PACK_NAME,
+ "description": "Triggers whenever current time matches the specified time constaints like "
+ "a UNIX cron scheduler.",
+ "payload_schema": TIMER_PAYLOAD_SCHEMA,
+ "parameters_schema": CRON_PARAMETERS_SCHEMA,
+ },
}
-SYSTEM_TRIGGER_TYPES = dict(list(WEBHOOK_TRIGGER_TYPES.items()) + list(TIMER_TRIGGER_TYPES.items()))
+SYSTEM_TRIGGER_TYPES = dict(
+ list(WEBHOOK_TRIGGER_TYPES.items()) + list(TIMER_TRIGGER_TYPES.items())
+)
# various status to record lifecycle of a TriggerInstance
-TRIGGER_INSTANCE_PENDING = 'pending'
-TRIGGER_INSTANCE_PROCESSING = 'processing'
-TRIGGER_INSTANCE_PROCESSED = 'processed'
-TRIGGER_INSTANCE_PROCESSING_FAILED = 'processing_failed'
+TRIGGER_INSTANCE_PENDING = "pending"
+TRIGGER_INSTANCE_PROCESSING = "processing"
+TRIGGER_INSTANCE_PROCESSED = "processed"
+TRIGGER_INSTANCE_PROCESSING_FAILED = "processing_failed"
TRIGGER_INSTANCE_STATUSES = [
TRIGGER_INSTANCE_PENDING,
TRIGGER_INSTANCE_PROCESSING,
TRIGGER_INSTANCE_PROCESSED,
- TRIGGER_INSTANCE_PROCESSING_FAILED
+ TRIGGER_INSTANCE_PROCESSING_FAILED,
]
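The timer parameter schemas above use the draft-3 style boolean "required" inside each property, so a draft-3 validator is needed to enforce them; a hedged sketch with the jsonschema library (illustrative wiring, not necessarily how st2 validates trigger parameters internally):

    from jsonschema import Draft3Validator

    INTERVAL_PARAMETERS_SCHEMA = {
        "type": "object",
        "properties": {
            "timezone": {"type": "string"},
            "unit": {
                "enum": ["weeks", "days", "hours", "minutes", "seconds"],
                "required": True,
            },
            "delta": {"type": "integer", "required": True},
        },
        "additionalProperties": False,
    }

    validator = Draft3Validator(INTERVAL_PARAMETERS_SCHEMA)
    validator.validate({"unit": "seconds", "delta": 30})  # passes silently

    # A bad enum value plus the missing required "delta" yield two errors.
    errors = list(validator.iter_errors({"unit": "fortnights"}))
    print(len(errors))  # 2
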
diff --git a/st2common/st2common/constants/types.py b/st2common/st2common/constants/types.py
index 7873d5b665..01ec79605f 100644
--- a/st2common/st2common/constants/types.py
+++ b/st2common/st2common/constants/types.py
@@ -16,9 +16,7 @@
from __future__ import absolute_import
from st2common.util.enum import Enum
-__all__ = [
- 'ResourceType'
-]
+__all__ = ["ResourceType"]
class ResourceType(Enum):
@@ -27,37 +25,37 @@ class ResourceType(Enum):
"""
# System resources
- RUNNER_TYPE = 'runner_type'
+ RUNNER_TYPE = "runner_type"
# Pack resources
- PACK = 'pack'
- ACTION = 'action'
- ACTION_ALIAS = 'action_alias'
- SENSOR_TYPE = 'sensor_type'
- TRIGGER_TYPE = 'trigger_type'
- TRIGGER = 'trigger'
- TRIGGER_INSTANCE = 'trigger_instance'
- RULE = 'rule'
- RULE_ENFORCEMENT = 'rule_enforcement'
+ PACK = "pack"
+ ACTION = "action"
+ ACTION_ALIAS = "action_alias"
+ SENSOR_TYPE = "sensor_type"
+ TRIGGER_TYPE = "trigger_type"
+ TRIGGER = "trigger"
+ TRIGGER_INSTANCE = "trigger_instance"
+ RULE = "rule"
+ RULE_ENFORCEMENT = "rule_enforcement"
# Note: Policy type is a global resource and policy belong to a pack
- POLICY_TYPE = 'policy_type'
- POLICY = 'policy'
+ POLICY_TYPE = "policy_type"
+ POLICY = "policy"
# Other resources
- EXECUTION = 'execution'
- EXECUTION_REQUEST = 'execution_request'
- KEY_VALUE_PAIR = 'key_value_pair'
+ EXECUTION = "execution"
+ EXECUTION_REQUEST = "execution_request"
+ KEY_VALUE_PAIR = "key_value_pair"
- WEBHOOK = 'webhook'
- TIMER = 'timer'
- API_KEY = 'api_key'
- TRACE = 'trace'
- TIMER = 'timer'
+ WEBHOOK = "webhook"
+ TIMER = "timer"
+ API_KEY = "api_key"
+ TRACE = "trace"
+ TIMER = "timer"
# Special resource type for stream related stuff
- STREAM = 'stream'
+ STREAM = "stream"
- INQUIRY = 'inquiry'
+ INQUIRY = "inquiry"
- UNKNOWN = 'unknown'
+ UNKNOWN = "unknown"
diff --git a/st2common/st2common/content/bootstrap.py b/st2common/st2common/content/bootstrap.py
index 7690e2dd01..da76123023 100644
--- a/st2common/st2common/content/bootstrap.py
+++ b/st2common/st2common/content/bootstrap.py
@@ -38,49 +38,64 @@
from st2common.metrics.base import Timer
from st2common.util.virtualenvs import setup_pack_virtualenv
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
-LOG = logging.getLogger('st2common.content.bootstrap')
+LOG = logging.getLogger("st2common.content.bootstrap")
-cfg.CONF.register_cli_opt(cfg.BoolOpt('experimental', default=False))
+cfg.CONF.register_cli_opt(cfg.BoolOpt("experimental", default=False))
def register_opts():
content_opts = [
- cfg.BoolOpt('all', default=False, help='Register sensors, actions and rules.'),
- cfg.BoolOpt('triggers', default=False, help='Register triggers.'),
- cfg.BoolOpt('sensors', default=False, help='Register sensors.'),
- cfg.BoolOpt('actions', default=False, help='Register actions.'),
- cfg.BoolOpt('runners', default=False, help='Register runners.'),
- cfg.BoolOpt('rules', default=False, help='Register rules.'),
- cfg.BoolOpt('aliases', default=False, help='Register aliases.'),
- cfg.BoolOpt('policies', default=False, help='Register policies.'),
- cfg.BoolOpt('configs', default=False, help='Register and load pack configs.'),
-
- cfg.StrOpt('pack', default=None, help='Directory to the pack to register content from.'),
- cfg.StrOpt('runner-dir', default=None, help='Directory to load runners from.'),
- cfg.BoolOpt('setup-virtualenvs', default=False, help=('Setup Python virtual environments '
- 'all the Python runner actions.')),
- cfg.BoolOpt('recreate-virtualenvs', default=False, help=('Recreate Python virtual '
- 'environments for all the Python '
- 'Python runner actions.')),
-
+ cfg.BoolOpt("all", default=False, help="Register sensors, actions and rules."),
+ cfg.BoolOpt("triggers", default=False, help="Register triggers."),
+ cfg.BoolOpt("sensors", default=False, help="Register sensors."),
+ cfg.BoolOpt("actions", default=False, help="Register actions."),
+ cfg.BoolOpt("runners", default=False, help="Register runners."),
+ cfg.BoolOpt("rules", default=False, help="Register rules."),
+ cfg.BoolOpt("aliases", default=False, help="Register aliases."),
+ cfg.BoolOpt("policies", default=False, help="Register policies."),
+ cfg.BoolOpt("configs", default=False, help="Register and load pack configs."),
+ cfg.StrOpt(
+ "pack", default=None, help="Directory to the pack to register content from."
+ ),
+ cfg.StrOpt("runner-dir", default=None, help="Directory to load runners from."),
+ cfg.BoolOpt(
+ "setup-virtualenvs",
+ default=False,
+ help=(
+                "Setup Python virtual environments for all the Python runner actions."
+ ),
+ ),
+ cfg.BoolOpt(
+ "recreate-virtualenvs",
+ default=False,
+ help=(
+ "Recreate Python virtual "
+                "environments for all the "
+                "Python runner actions."
+ ),
+ ),
# General options
# Note: This value should default to False since we want fail on failure behavior by
# default.
- cfg.BoolOpt('no-fail-on-failure', default=False,
- help=('Don\'t exit with non-zero if some resource registration fails.')),
+ cfg.BoolOpt(
+ "no-fail-on-failure",
+ default=False,
+ help=("Don't exit with non-zero if some resource registration fails."),
+ ),
# Note: Fail on failure is now a default behavior. This flag is only left here for backward
# compatibility reasons, but it's not actually used.
- cfg.BoolOpt('fail-on-failure', default=True,
- help=('Exit with non-zero if some resource registration fails.'))
+ cfg.BoolOpt(
+ "fail-on-failure",
+ default=True,
+ help=("Exit with non-zero if some resource registration fails."),
+ ),
]
try:
- cfg.CONF.register_cli_opts(content_opts, group='register')
+ cfg.CONF.register_cli_opts(content_opts, group="register")
except:
- sys.stderr.write('Failed registering opts.\n')
+ sys.stderr.write("Failed registering opts.\n")
register_opts()
@@ -91,9 +106,9 @@ def setup_virtualenvs(recreate_virtualenvs=False):
Setup Python virtual environments for all the registered or the provided pack.
"""
- LOG.info('=========================================================')
- LOG.info('########### Setting up virtual environments #############')
- LOG.info('=========================================================')
+ LOG.info("=========================================================")
+ LOG.info("########### Setting up virtual environments #############")
+ LOG.info("=========================================================")
pack_dir = cfg.CONF.register.pack
fail_on_failure = not cfg.CONF.register.no_fail_on_failure
@@ -134,15 +149,19 @@ def setup_virtualenvs(recreate_virtualenvs=False):
setup_pack_virtualenv(pack_name=pack_name, update=update, logger=LOG)
except Exception as e:
exc_info = not fail_on_failure
- LOG.warning('Failed to setup virtualenv for pack "%s": %s', pack_name, e,
- exc_info=exc_info)
+ LOG.warning(
+ 'Failed to setup virtualenv for pack "%s": %s',
+ pack_name,
+ e,
+ exc_info=exc_info,
+ )
if fail_on_failure:
raise e
else:
setup_count += 1
- LOG.info('Setup virtualenv for %s pack(s).' % (setup_count))
+ LOG.info("Setup virtualenv for %s pack(s)." % (setup_count))
def register_triggers():
@@ -152,22 +171,21 @@ def register_triggers():
registered_count = 0
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering triggers #####################')
- LOG.info('=========================================================')
- with Timer(key='st2.register.triggers'):
+ LOG.info("=========================================================")
+ LOG.info("############## Registering triggers #####################")
+ LOG.info("=========================================================")
+ with Timer(key="st2.register.triggers"):
registered_count = triggers_registrar.register_triggers(
- pack_dir=pack_dir,
- fail_on_failure=fail_on_failure
+ pack_dir=pack_dir, fail_on_failure=fail_on_failure
)
except Exception as e:
exc_info = not fail_on_failure
- LOG.warning('Failed to register sensors: %s', e, exc_info=exc_info)
+ LOG.warning("Failed to register sensors: %s", e, exc_info=exc_info)
if fail_on_failure:
raise e
- LOG.info('Registered %s triggers.' % (registered_count))
+ LOG.info("Registered %s triggers." % (registered_count))
def register_sensors():
@@ -177,22 +195,21 @@ def register_sensors():
registered_count = 0
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering sensors ######################')
- LOG.info('=========================================================')
- with Timer(key='st2.register.sensors'):
+ LOG.info("=========================================================")
+ LOG.info("############## Registering sensors ######################")
+ LOG.info("=========================================================")
+ with Timer(key="st2.register.sensors"):
registered_count = sensors_registrar.register_sensors(
- pack_dir=pack_dir,
- fail_on_failure=fail_on_failure
+ pack_dir=pack_dir, fail_on_failure=fail_on_failure
)
except Exception as e:
exc_info = not fail_on_failure
- LOG.warning('Failed to register sensors: %s', e, exc_info=exc_info)
+ LOG.warning("Failed to register sensors: %s", e, exc_info=exc_info)
if fail_on_failure:
raise e
- LOG.info('Registered %s sensors.' % (registered_count))
+ LOG.info("Registered %s sensors." % (registered_count))
def register_runners():
@@ -202,24 +219,23 @@ def register_runners():
# 1. Register runner types
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering runners ######################')
- LOG.info('=========================================================')
- with Timer(key='st2.register.runners'):
+ LOG.info("=========================================================")
+ LOG.info("############## Registering runners ######################")
+ LOG.info("=========================================================")
+ with Timer(key="st2.register.runners"):
registered_count = runners_registrar.register_runners(
- fail_on_failure=fail_on_failure,
- experimental=False
+ fail_on_failure=fail_on_failure, experimental=False
)
except Exception as error:
exc_info = not fail_on_failure
# TODO: Narrow exception window
- LOG.warning('Failed to register runners: %s', error, exc_info=exc_info)
+ LOG.warning("Failed to register runners: %s", error, exc_info=exc_info)
if fail_on_failure:
raise error
- LOG.info('Registered %s runners.', registered_count)
+ LOG.info("Registered %s runners.", registered_count)
def register_actions():
@@ -231,22 +247,21 @@ def register_actions():
registered_count = 0
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering actions ######################')
- LOG.info('=========================================================')
- with Timer(key='st2.register.actions'):
+ LOG.info("=========================================================")
+ LOG.info("############## Registering actions ######################")
+ LOG.info("=========================================================")
+ with Timer(key="st2.register.actions"):
registered_count = actions_registrar.register_actions(
- pack_dir=pack_dir,
- fail_on_failure=fail_on_failure
+ pack_dir=pack_dir, fail_on_failure=fail_on_failure
)
except Exception as e:
exc_info = not fail_on_failure
- LOG.warning('Failed to register actions: %s', e, exc_info=exc_info)
+ LOG.warning("Failed to register actions: %s", e, exc_info=exc_info)
if fail_on_failure:
raise e
- LOG.info('Registered %s actions.' % (registered_count))
+ LOG.info("Registered %s actions." % (registered_count))
def register_rules():
@@ -257,28 +272,27 @@ def register_rules():
registered_count = 0
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering rules ########################')
- LOG.info('=========================================================')
+ LOG.info("=========================================================")
+ LOG.info("############## Registering rules ########################")
+ LOG.info("=========================================================")
rule_types_registrar.register_rule_types()
except Exception as e:
- LOG.warning('Failed to register rule types: %s', e, exc_info=True)
+ LOG.warning("Failed to register rule types: %s", e, exc_info=True)
return
try:
- with Timer(key='st2.register.rules'):
+ with Timer(key="st2.register.rules"):
registered_count = rules_registrar.register_rules(
- pack_dir=pack_dir,
- fail_on_failure=fail_on_failure
+ pack_dir=pack_dir, fail_on_failure=fail_on_failure
)
except Exception as e:
exc_info = not fail_on_failure
- LOG.warning('Failed to register rules: %s', e, exc_info=exc_info)
+ LOG.warning("Failed to register rules: %s", e, exc_info=exc_info)
if fail_on_failure:
raise e
- LOG.info('Registered %s rules.', registered_count)
+ LOG.info("Registered %s rules.", registered_count)
def register_aliases():
@@ -288,21 +302,20 @@ def register_aliases():
registered_count = 0
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering aliases ######################')
- LOG.info('=========================================================')
- with Timer(key='st2.register.aliases'):
+ LOG.info("=========================================================")
+ LOG.info("############## Registering aliases ######################")
+ LOG.info("=========================================================")
+ with Timer(key="st2.register.aliases"):
registered_count = aliases_registrar.register_aliases(
- pack_dir=pack_dir,
- fail_on_failure=fail_on_failure
+ pack_dir=pack_dir, fail_on_failure=fail_on_failure
)
except Exception as e:
if fail_on_failure:
raise e
- LOG.warning('Failed to register aliases.', exc_info=True)
+ LOG.warning("Failed to register aliases.", exc_info=True)
- LOG.info('Registered %s aliases.', registered_count)
+ LOG.info("Registered %s aliases.", registered_count)
def register_policies():
@@ -313,31 +326,32 @@ def register_policies():
registered_type_count = 0
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering policy types #################')
- LOG.info('=========================================================')
- with Timer(key='st2.register.policies'):
+ LOG.info("=========================================================")
+ LOG.info("############## Registering policy types #################")
+ LOG.info("=========================================================")
+ with Timer(key="st2.register.policies"):
registered_type_count = policies_registrar.register_policy_types(st2common)
except Exception:
- LOG.warning('Failed to register policy types.', exc_info=True)
+ LOG.warning("Failed to register policy types.", exc_info=True)
- LOG.info('Registered %s policy types.', registered_type_count)
+ LOG.info("Registered %s policy types.", registered_type_count)
registered_count = 0
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering policies #####################')
- LOG.info('=========================================================')
- registered_count = policies_registrar.register_policies(pack_dir=pack_dir,
- fail_on_failure=fail_on_failure)
+ LOG.info("=========================================================")
+ LOG.info("############## Registering policies #####################")
+ LOG.info("=========================================================")
+ registered_count = policies_registrar.register_policies(
+ pack_dir=pack_dir, fail_on_failure=fail_on_failure
+ )
except Exception as e:
exc_info = not fail_on_failure
- LOG.warning('Failed to register policies: %s', e, exc_info=exc_info)
+ LOG.warning("Failed to register policies: %s", e, exc_info=exc_info)
if fail_on_failure:
raise e
- LOG.info('Registered %s policies.', registered_count)
+ LOG.info("Registered %s policies.", registered_count)
def register_configs():
@@ -347,23 +361,23 @@ def register_configs():
registered_count = 0
try:
- LOG.info('=========================================================')
- LOG.info('############## Registering configs ######################')
- LOG.info('=========================================================')
- with Timer(key='st2.register.configs'):
+ LOG.info("=========================================================")
+ LOG.info("############## Registering configs ######################")
+ LOG.info("=========================================================")
+ with Timer(key="st2.register.configs"):
registered_count = configs_registrar.register_configs(
pack_dir=pack_dir,
fail_on_failure=fail_on_failure,
- validate_configs=True
+ validate_configs=True,
)
except Exception as e:
exc_info = not fail_on_failure
- LOG.warning('Failed to register configs: %s', e, exc_info=exc_info)
+ LOG.warning("Failed to register configs: %s", e, exc_info=exc_info)
if fail_on_failure:
raise e
- LOG.info('Registered %s configs.' % (registered_count))
+ LOG.info("Registered %s configs." % (registered_count))
def register_content():
@@ -416,8 +430,12 @@ def register_content():
def setup(argv):
- common_setup(config=config, setup_db=True, register_mq_exchanges=True,
- register_internal_trigger_types=True)
+ common_setup(
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_internal_trigger_types=True,
+ )
def teardown():
@@ -431,5 +449,5 @@ def main(argv):
# This script registers actions and rules from content-packs.
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv[1:])
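bootstrap.py registers all of its flags with oslo.config under the register group and reads them back as cfg.CONF.register.<name>; grouped CLI options surface as --register-<name> on the command line. A tiny standalone sketch of the pattern:

    import sys

    from oslo_config import cfg

    opts = [
        cfg.BoolOpt("all", default=False, help="Register sensors, actions and rules."),
        cfg.StrOpt("pack", default=None, help="Pack directory to register content from."),
    ]
    cfg.CONF.register_cli_opts(opts, group="register")

    # Parse argv; the options arrive as --register-all / --register-pack=...
    cfg.CONF(sys.argv[1:], project="st2")
    print(cfg.CONF.register.all, cfg.CONF.register.pack)
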
diff --git a/st2common/st2common/content/loader.py b/st2common/st2common/content/loader.py
index 0dfae4c0b6..420323fd76 100644
--- a/st2common/st2common/content/loader.py
+++ b/st2common/st2common/content/loader.py
@@ -28,10 +28,7 @@
if six.PY2:
from io import open
-__all__ = [
- 'ContentPackLoader',
- 'MetaLoader'
-]
+__all__ = ["ContentPackLoader", "MetaLoader"]
LOG = logging.getLogger(__name__)
@@ -45,12 +42,12 @@ class ContentPackLoader(object):
# content - they just return a path
ALLOWED_CONTENT_TYPES = [
- 'triggers',
- 'sensors',
- 'actions',
- 'rules',
- 'aliases',
- 'policies'
+ "triggers",
+ "sensors",
+ "actions",
+ "rules",
+ "aliases",
+ "policies",
]
def get_packs(self, base_dirs):
@@ -91,7 +88,7 @@ def get_content(self, base_dirs, content_type):
assert isinstance(base_dirs, list)
if content_type not in self.ALLOWED_CONTENT_TYPES:
- raise ValueError('Unsupported content_type: %s' % (content_type))
+ raise ValueError("Unsupported content_type: %s" % (content_type))
content = {}
pack_to_dir_map = {}
@@ -99,14 +96,18 @@ def get_content(self, base_dirs, content_type):
if not os.path.isdir(base_dir):
raise ValueError('Directory "%s" doesn\'t exist' % (base_dir))
- dir_content = self._get_content_from_dir(base_dir=base_dir, content_type=content_type)
+ dir_content = self._get_content_from_dir(
+ base_dir=base_dir, content_type=content_type
+ )
# Check for duplicate packs
for pack_name, pack_content in six.iteritems(dir_content):
if pack_name in content:
pack_dir = pack_to_dir_map[pack_name]
- LOG.warning('Pack "%s" already found in "%s", ignoring content from "%s"' %
- (pack_name, pack_dir, base_dir))
+ LOG.warning(
+ 'Pack "%s" already found in "%s", ignoring content from "%s"'
+ % (pack_name, pack_dir, base_dir)
+ )
else:
content[pack_name] = pack_content
pack_to_dir_map[pack_name] = base_dir
@@ -126,13 +127,14 @@ def get_content_from_pack(self, pack_dir, content_type):
:rtype: ``str``
"""
if content_type not in self.ALLOWED_CONTENT_TYPES:
- raise ValueError('Unsupported content_type: %s' % (content_type))
+ raise ValueError("Unsupported content_type: %s" % (content_type))
if not os.path.isdir(pack_dir):
raise ValueError('Directory "%s" doesn\'t exist' % (pack_dir))
- content = self._get_content_from_pack_dir(pack_dir=pack_dir,
- content_type=content_type)
+ content = self._get_content_from_pack_dir(
+ pack_dir=pack_dir, content_type=content_type
+ )
return content
def _get_packs_from_dir(self, base_dir):
@@ -154,8 +156,9 @@ def _get_content_from_dir(self, base_dir, content_type):
# Ignore missing or non directories
try:
- pack_content = self._get_content_from_pack_dir(pack_dir=pack_dir,
- content_type=content_type)
+ pack_content = self._get_content_from_pack_dir(
+ pack_dir=pack_dir, content_type=content_type
+ )
except ValueError:
continue
else:
@@ -170,13 +173,13 @@ def _get_content_from_pack_dir(self, pack_dir, content_type):
actions=self._get_actions,
rules=self._get_rules,
aliases=self._get_aliases,
- policies=self._get_policies
+ policies=self._get_policies,
)
get_func = content_types.get(content_type)
if get_func is None:
- raise ValueError('Invalid content_type: %s' % (content_type))
+ raise ValueError("Invalid content_type: %s" % (content_type))
if not os.path.isdir(pack_dir):
raise ValueError('Directory "%s" doesn\'t exist' % (pack_dir))
@@ -185,22 +188,22 @@ def _get_content_from_pack_dir(self, pack_dir, content_type):
return pack_content
def _get_triggers(self, pack_dir):
- return self._get_folder(pack_dir=pack_dir, content_type='triggers')
+ return self._get_folder(pack_dir=pack_dir, content_type="triggers")
def _get_sensors(self, pack_dir):
- return self._get_folder(pack_dir=pack_dir, content_type='sensors')
+ return self._get_folder(pack_dir=pack_dir, content_type="sensors")
def _get_actions(self, pack_dir):
- return self._get_folder(pack_dir=pack_dir, content_type='actions')
+ return self._get_folder(pack_dir=pack_dir, content_type="actions")
def _get_rules(self, pack_dir):
- return self._get_folder(pack_dir=pack_dir, content_type='rules')
+ return self._get_folder(pack_dir=pack_dir, content_type="rules")
def _get_aliases(self, pack_dir):
- return self._get_folder(pack_dir=pack_dir, content_type='aliases')
+ return self._get_folder(pack_dir=pack_dir, content_type="aliases")
def _get_policies(self, pack_dir):
- return self._get_folder(pack_dir=pack_dir, content_type='policies')
+ return self._get_folder(pack_dir=pack_dir, content_type="policies")
def _get_folder(self, pack_dir, content_type):
path = os.path.join(pack_dir, content_type)
@@ -233,8 +236,10 @@ def load(self, file_path, expected_type=None):
file_name, file_ext = os.path.splitext(file_path)
if file_ext not in ALLOWED_EXTS:
- raise Exception('Unsupported meta type %s, file %s. Allowed: %s' %
- (file_ext, file_path, ALLOWED_EXTS))
+ raise Exception(
+ "Unsupported meta type %s, file %s. Allowed: %s"
+ % (file_ext, file_path, ALLOWED_EXTS)
+ )
result = self._load(PARSER_FUNCS[file_ext], file_path)
@@ -246,12 +251,12 @@ def load(self, file_path, expected_type=None):
return result
def _load(self, parser_func, file_path):
- with open(file_path, 'r', encoding='utf-8') as fd:
+ with open(file_path, "r", encoding="utf-8") as fd:
try:
return parser_func(fd)
except ValueError:
- LOG.exception('Failed loading content from %s.', file_path)
+ LOG.exception("Failed loading content from %s.", file_path)
raise
except ParserError:
- LOG.exception('Failed loading content from %s.', file_path)
+ LOG.exception("Failed loading content from %s.", file_path)
raise
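In practice MetaLoader is handed a metadata file path and returns the parsed structure, with the parser chosen from the extension via PARSER_FUNCS; a brief usage sketch (the pack path here is made up for illustration):

    from st2common.content.loader import MetaLoader

    loader = MetaLoader()
    # Raises for extensions outside ALLOWED_EXTS; expected_type is an
    # optional guard on the type of the parsed result.
    meta = loader.load("/opt/stackstorm/packs/examples/pack.yaml", expected_type=dict)
    print(meta["name"])
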
diff --git a/st2common/st2common/content/utils.py b/st2common/st2common/content/utils.py
index 3bd5e2b12c..ad9386acf6 100644
--- a/st2common/st2common/content/utils.py
+++ b/st2common/st2common/content/utils.py
@@ -24,22 +24,24 @@
from st2common.util.shell import quote_unix
__all__ = [
- 'get_pack_group',
- 'get_system_packs_base_path',
- 'get_packs_base_paths',
- 'get_pack_base_path',
- 'get_pack_directory',
- 'get_pack_file_abs_path',
- 'get_pack_resource_file_abs_path',
- 'get_relative_path_to_pack_file',
- 'check_pack_directory_exists',
- 'check_pack_content_directory_exists'
+ "get_pack_group",
+ "get_system_packs_base_path",
+ "get_packs_base_paths",
+ "get_pack_base_path",
+ "get_pack_directory",
+ "get_pack_file_abs_path",
+ "get_pack_resource_file_abs_path",
+ "get_relative_path_to_pack_file",
+ "check_pack_directory_exists",
+ "check_pack_content_directory_exists",
]
INVALID_FILE_PATH_ERROR = """
Invalid file path: "%s". File path needs to be relative to the pack%sdirectory (%s).
For example "my_%s.py".
-""".strip().replace('\n', ' ')
+""".strip().replace(
+ "\n", " "
+)
# Cache which stores pack name -> pack base path mappings
PACK_NAME_TO_BASE_PATH_CACHE = {}
@@ -70,10 +72,10 @@ def get_packs_base_paths():
:rtype: ``list``
"""
system_packs_base_path = get_system_packs_base_path()
- packs_base_paths = cfg.CONF.content.packs_base_paths or ''
+ packs_base_paths = cfg.CONF.content.packs_base_paths or ""
# Remove trailing colon (if present)
- if packs_base_paths.endswith(':'):
+ if packs_base_paths.endswith(":"):
packs_base_paths = packs_base_paths[:-1]
result = []
@@ -81,7 +83,7 @@ def get_packs_base_paths():
if system_packs_base_path:
result.append(system_packs_base_path)
- packs_base_paths = packs_base_paths.split(':')
+ packs_base_paths = packs_base_paths.split(":")
result = result + packs_base_paths
result = [path for path in result if path]
@@ -223,22 +225,28 @@ def get_entry_point_abs_path(pack=None, entry_point=None, use_pack_cache=False):
return None
if os.path.isabs(entry_point):
- pack_base_path = get_pack_base_path(pack_name=pack, use_pack_cache=use_pack_cache)
+ pack_base_path = get_pack_base_path(
+ pack_name=pack, use_pack_cache=use_pack_cache
+ )
common_prefix = os.path.commonprefix([pack_base_path, entry_point])
if common_prefix != pack_base_path:
- raise ValueError('Entry point file "%s" is located outside of the pack directory' %
- (entry_point))
+ raise ValueError(
+ 'Entry point file "%s" is located outside of the pack directory'
+ % (entry_point)
+ )
return entry_point
- entry_point_abs_path = get_pack_resource_file_abs_path(pack_ref=pack,
- resource_type='action',
- file_path=entry_point)
+ entry_point_abs_path = get_pack_resource_file_abs_path(
+ pack_ref=pack, resource_type="action", file_path=entry_point
+ )
return entry_point_abs_path
-def get_pack_file_abs_path(pack_ref, file_path, resource_type=None, use_pack_cache=False):
+def get_pack_file_abs_path(
+ pack_ref, file_path, resource_type=None, use_pack_cache=False
+):
"""
Retrieve full absolute path to the pack file.
@@ -258,36 +266,46 @@ def get_pack_file_abs_path(pack_ref, file_path, resource_type=None, use_pack_cac
:rtype: ``str``
"""
- pack_base_path = get_pack_base_path(pack_name=pack_ref, use_pack_cache=use_pack_cache)
+ pack_base_path = get_pack_base_path(
+ pack_name=pack_ref, use_pack_cache=use_pack_cache
+ )
if resource_type:
- resource_type_plural = ' %ss ' % (resource_type)
- resource_base_path = os.path.join(pack_base_path, '%ss/' % (resource_type))
+ resource_type_plural = " %ss " % (resource_type)
+ resource_base_path = os.path.join(pack_base_path, "%ss/" % (resource_type))
else:
- resource_type_plural = ' '
+ resource_type_plural = " "
resource_base_path = pack_base_path
path_components = []
path_components.append(pack_base_path)
# Normalize the path to prevent directory traversal
- normalized_file_path = os.path.normpath('/' + file_path).lstrip('/')
+ normalized_file_path = os.path.normpath("/" + file_path).lstrip("/")
if normalized_file_path != file_path:
- msg = INVALID_FILE_PATH_ERROR % (file_path, resource_type_plural, resource_base_path,
- resource_type or 'action')
+ msg = INVALID_FILE_PATH_ERROR % (
+ file_path,
+ resource_type_plural,
+ resource_base_path,
+ resource_type or "action",
+ )
raise ValueError(msg)
path_components.append(normalized_file_path)
- result = os.path.join(*path_components) # pylint: disable=E1120
+ result = os.path.join(*path_components) # pylint: disable=E1120
assert normalized_file_path in result
# Final safety check for common prefix to avoid traversal attack
common_prefix = os.path.commonprefix([pack_base_path, result])
if common_prefix != pack_base_path:
- msg = INVALID_FILE_PATH_ERROR % (file_path, resource_type_plural, resource_base_path,
- resource_type or 'action')
+ msg = INVALID_FILE_PATH_ERROR % (
+ file_path,
+ resource_type_plural,
+ resource_base_path,
+ resource_type or "action",
+ )
raise ValueError(msg)
return result
@@ -313,19 +331,20 @@ def get_pack_resource_file_abs_path(pack_ref, resource_type, file_path):
:rtype: ``str``
"""
path_components = []
- if resource_type == 'action':
- path_components.append('actions/')
- elif resource_type == 'sensor':
- path_components.append('sensors/')
- elif resource_type == 'rule':
- path_components.append('rules/')
+ if resource_type == "action":
+ path_components.append("actions/")
+ elif resource_type == "sensor":
+ path_components.append("sensors/")
+ elif resource_type == "rule":
+ path_components.append("rules/")
else:
- raise ValueError('Invalid resource type: %s' % (resource_type))
+ raise ValueError("Invalid resource type: %s" % (resource_type))
path_components.append(file_path)
file_path = os.path.join(*path_components) # pylint: disable=E1120
- result = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path,
- resource_type=resource_type)
+ result = get_pack_file_abs_path(
+ pack_ref=pack_ref, file_path=file_path, resource_type=resource_type
+ )
return result
@@ -341,7 +360,9 @@ def get_relative_path_to_pack_file(pack_ref, file_path, use_pack_cache=False):
:rtype: ``str``
"""
- pack_base_path = get_pack_base_path(pack_name=pack_ref, use_pack_cache=use_pack_cache)
+ pack_base_path = get_pack_base_path(
+ pack_name=pack_ref, use_pack_cache=use_pack_cache
+ )
if not os.path.isabs(file_path):
return file_path
@@ -350,8 +371,10 @@ def get_relative_path_to_pack_file(pack_ref, file_path, use_pack_cache=False):
common_prefix = os.path.commonprefix([pack_base_path, file_path])
if common_prefix != pack_base_path:
- raise ValueError('file_path (%s) is not located inside the pack directory (%s)' %
- (file_path, pack_base_path))
+ raise ValueError(
+ "file_path (%s) is not located inside the pack directory (%s)"
+ % (file_path, pack_base_path)
+ )
relative_path = os.path.relpath(file_path, common_prefix)
return relative_path
@@ -381,15 +404,15 @@ def get_aliases_base_paths():
:rtype: ``list``
"""
- aliases_base_paths = cfg.CONF.content.aliases_base_paths or ''
+ aliases_base_paths = cfg.CONF.content.aliases_base_paths or ""
# Remove trailing colon (if present)
- if aliases_base_paths.endswith(':'):
+ if aliases_base_paths.endswith(":"):
aliases_base_paths = aliases_base_paths[:-1]
result = []
- aliases_base_paths = aliases_base_paths.split(':')
+ aliases_base_paths = aliases_base_paths.split(":")
result = aliases_base_paths
result = [path for path in result if path]
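The traversal check in get_pack_file_abs_path above hinges on one trick: prefixing the path with "/", running os.path.normpath, and stripping the leading "/" resolves any ".." segments, so a path that tried to escape no longer round-trips to its original spelling. A standalone illustration (POSIX paths assumed):

    import os


    def is_traversal_attempt(file_path):
        # A path that tries to escape the base directory changes under
        # normalization, so inequality flags the attempt.
        normalized = os.path.normpath("/" + file_path).lstrip("/")
        return normalized != file_path


    print(is_traversal_attempt("actions/my_action.py"))  # False
    print(is_traversal_attempt("../../../etc/passwd"))  # True
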
diff --git a/st2common/st2common/content/validators.py b/st2common/st2common/content/validators.py
index bba9c446e3..8b1ab822c0 100644
--- a/st2common/st2common/content/validators.py
+++ b/st2common/st2common/content/validators.py
@@ -19,20 +19,16 @@
from st2common.constants.pack import USER_PACK_NAME_BLACKLIST
-__all__ = [
- 'RequirementsValidator',
- 'validate_pack_name'
-]
+__all__ = ["RequirementsValidator", "validate_pack_name"]
class RequirementsValidator(object):
-
@staticmethod
def validate(requirements_file):
if not os.path.exists(requirements_file):
- raise Exception('Requirements file %s not found.' % requirements_file)
+ raise Exception("Requirements file %s not found." % requirements_file)
missing = []
- with open(requirements_file, 'r') as f:
+ with open(requirements_file, "r") as f:
for line in f:
rqmnt = line.strip()
try:
@@ -54,10 +50,9 @@ def validate_pack_name(name):
:rtype: ``str``
"""
if not name:
- raise ValueError('Content pack name cannot be empty')
+ raise ValueError("Content pack name cannot be empty")
if name.lower() in USER_PACK_NAME_BLACKLIST:
- raise ValueError('Name "%s" is blacklisted and can\'t be used' %
- (name.lower()))
+ raise ValueError('Name "%s" is blacklisted and can\'t be used' % (name.lower()))
return name
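
A short usage sketch for validate_pack_name, assuming st2common is importable (the pack
name is illustrative):

    from st2common.content.validators import validate_pack_name

    validate_pack_name("my_pack")  # returns "my_pack" unchanged
    validate_pack_name("")         # raises ValueError: name cannot be empty
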
diff --git a/st2common/st2common/database_setup.py b/st2common/st2common/database_setup.py
index 2678ecbf2e..2e2e7d2a17 100644
--- a/st2common/st2common/database_setup.py
+++ b/st2common/st2common/database_setup.py
@@ -23,29 +23,27 @@
from st2common.models import db
from st2common.persistence import db_init
-__all__ = [
- 'db_config',
- 'db_setup',
- 'db_teardown'
-]
+__all__ = ["db_config", "db_setup", "db_teardown"]
def db_config():
- username = getattr(cfg.CONF.database, 'username', None)
- password = getattr(cfg.CONF.database, 'password', None)
-
- return {'db_name': cfg.CONF.database.db_name,
- 'db_host': cfg.CONF.database.host,
- 'db_port': cfg.CONF.database.port,
- 'username': username,
- 'password': password,
- 'ssl': cfg.CONF.database.ssl,
- 'ssl_keyfile': cfg.CONF.database.ssl_keyfile,
- 'ssl_certfile': cfg.CONF.database.ssl_certfile,
- 'ssl_cert_reqs': cfg.CONF.database.ssl_cert_reqs,
- 'ssl_ca_certs': cfg.CONF.database.ssl_ca_certs,
- 'authentication_mechanism': cfg.CONF.database.authentication_mechanism,
- 'ssl_match_hostname': cfg.CONF.database.ssl_match_hostname}
+ username = getattr(cfg.CONF.database, "username", None)
+ password = getattr(cfg.CONF.database, "password", None)
+
+ return {
+ "db_name": cfg.CONF.database.db_name,
+ "db_host": cfg.CONF.database.host,
+ "db_port": cfg.CONF.database.port,
+ "username": username,
+ "password": password,
+ "ssl": cfg.CONF.database.ssl,
+ "ssl_keyfile": cfg.CONF.database.ssl_keyfile,
+ "ssl_certfile": cfg.CONF.database.ssl_certfile,
+ "ssl_cert_reqs": cfg.CONF.database.ssl_cert_reqs,
+ "ssl_ca_certs": cfg.CONF.database.ssl_ca_certs,
+ "authentication_mechanism": cfg.CONF.database.authentication_mechanism,
+ "ssl_match_hostname": cfg.CONF.database.ssl_match_hostname,
+ }
def db_setup(ensure_indexes=True):
@@ -53,7 +51,7 @@ def db_setup(ensure_indexes=True):
Creates the database and indexes (optional).
"""
db_cfg = db_config()
- db_cfg['ensure_indexes'] = ensure_indexes
+ db_cfg["ensure_indexes"] = ensure_indexes
connection = db_init.db_setup_with_retry(**db_cfg)
return connection
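
A minimal sketch of inspecting the assembled connection settings without opening a real
MongoDB connection; it assumes the stock option defaults are acceptable (register_opts is
the same helper the metrics module imports further below):

    from oslo_config import cfg
    from st2common.config import register_opts
    from st2common.database_setup import db_config

    register_opts()
    cfg.CONF(args=[])               # parse with defaults only
    print(db_config()["db_host"])   # e.g. "127.0.0.1" with stock defaults
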
diff --git a/st2common/st2common/exceptions/__init__.py b/st2common/st2common/exceptions/__init__.py
index ec4e9430e9..065d3ff0fe 100644
--- a/st2common/st2common/exceptions/__init__.py
+++ b/st2common/st2common/exceptions/__init__.py
@@ -16,24 +16,26 @@
class StackStormBaseException(Exception):
"""
- The root of the exception class hierarchy for all
- StackStorm server exceptions.
+ The root of the exception class hierarchy for all
+ StackStorm server exceptions.
- For exceptions raised by plug-ins, see StackStormPluginException
- class.
+ For exceptions raised by plug-ins, see StackStormPluginException
+ class.
"""
+
pass
class StackStormPluginException(StackStormBaseException):
"""
- The root of the exception class hierarchy for all
- exceptions that are defined as part of a StackStorm
- plug-in API.
-
- It is recommended that each API define a root exception
- class for the API. This root exception class for the
- API should inherit from the StackStormPluginException
- class.
+ The root of the exception class hierarchy for all
+ exceptions that are defined as part of a StackStorm
+ plug-in API.
+
+ It is recommended that each API define a root exception
+ class for the API. This root exception class for the
+ API should inherit from the StackStormPluginException
+ class.
"""
+
pass
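
Following the guidance in the docstring above, a plug-in API would define its own root
exception class; MyPluginException and MyPluginConfigError are hypothetical names:

    from st2common.exceptions import StackStormPluginException

    class MyPluginException(StackStormPluginException):
        """Root exception for this (hypothetical) plug-in API."""

    class MyPluginConfigError(MyPluginException):
        pass
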
diff --git a/st2common/st2common/exceptions/action.py b/st2common/st2common/exceptions/action.py
index f7ed430266..f4bba2ee75 100644
--- a/st2common/st2common/exceptions/action.py
+++ b/st2common/st2common/exceptions/action.py
@@ -17,9 +17,9 @@
from st2common.exceptions import StackStormBaseException
__all__ = [
- 'ParameterRenderingFailedException',
- 'InvalidActionReferencedException',
- 'InvalidActionParameterException'
+ "ParameterRenderingFailedException",
+ "InvalidActionReferencedException",
+ "InvalidActionParameterException",
]
diff --git a/st2common/st2common/exceptions/actionalias.py b/st2common/st2common/exceptions/actionalias.py
index 1c01cd5736..3172a72dc6 100644
--- a/st2common/st2common/exceptions/actionalias.py
+++ b/st2common/st2common/exceptions/actionalias.py
@@ -16,9 +16,7 @@
from __future__ import absolute_import
from st2common.exceptions import StackStormBaseException
-__all__ = [
- 'ActionAliasAmbiguityException'
-]
+__all__ = ["ActionAliasAmbiguityException"]
class ActionAliasAmbiguityException(ValueError, StackStormBaseException):
diff --git a/st2common/st2common/exceptions/api.py b/st2common/st2common/exceptions/api.py
index f5aee1c1c0..054eb1bcf1 100644
--- a/st2common/st2common/exceptions/api.py
+++ b/st2common/st2common/exceptions/api.py
@@ -16,8 +16,7 @@
from __future__ import absolute_import
from st2common.exceptions import StackStormBaseException
-__all__ = [
-]
+__all__ = []
class InternalServerErrorException(StackStormBaseException):
diff --git a/st2common/st2common/exceptions/auth.py b/st2common/st2common/exceptions/auth.py
index 429d597abd..5eab1915f5 100644
--- a/st2common/st2common/exceptions/auth.py
+++ b/st2common/st2common/exceptions/auth.py
@@ -18,19 +18,19 @@
from st2common.exceptions.db import StackStormDBObjectNotFoundError
__all__ = [
- 'TokenNotProvidedError',
- 'TokenNotFoundError',
- 'TokenExpiredError',
- 'TTLTooLargeException',
- 'ApiKeyNotProvidedError',
- 'ApiKeyNotFoundError',
- 'MultipleAuthSourcesError',
- 'NoAuthSourceProvidedError',
- 'NoNicknameOriginProvidedError',
- 'UserNotFoundError',
- 'AmbiguousUserError',
- 'NotServiceUserError',
- 'SSOVerificationError'
+ "TokenNotProvidedError",
+ "TokenNotFoundError",
+ "TokenExpiredError",
+ "TTLTooLargeException",
+ "ApiKeyNotProvidedError",
+ "ApiKeyNotFoundError",
+ "MultipleAuthSourcesError",
+ "NoAuthSourceProvidedError",
+ "NoNicknameOriginProvidedError",
+ "UserNotFoundError",
+ "AmbiguousUserError",
+ "NotServiceUserError",
+ "SSOVerificationError",
]
diff --git a/st2common/st2common/exceptions/connection.py b/st2common/st2common/exceptions/connection.py
index 8cb9681b41..806d6e1046 100644
--- a/st2common/st2common/exceptions/connection.py
+++ b/st2common/st2common/exceptions/connection.py
@@ -16,14 +16,17 @@
class UnknownHostException(Exception):
"""Raised when a host is unknown (dns failure)"""
+
pass
class ConnectionErrorException(Exception):
"""Raised on error connecting (connection refused/timed out)"""
+
pass
class AuthenticationException(Exception):
"""Raised on authentication error (user/password/ssh key error)"""
+
pass
diff --git a/st2common/st2common/exceptions/db.py b/st2common/st2common/exceptions/db.py
index fcd607e964..776d927e0f 100644
--- a/st2common/st2common/exceptions/db.py
+++ b/st2common/st2common/exceptions/db.py
@@ -29,6 +29,7 @@ class StackStormDBObjectConflictError(StackStormBaseException):
"""
Exception that captures a DB object conflict error.
"""
+
def __init__(self, message, conflict_id, model_object):
super(StackStormDBObjectConflictError, self).__init__(message)
self.conflict_id = conflict_id
@@ -36,7 +37,9 @@ def __init__(self, message, conflict_id, model_object):
class StackStormDBObjectWriteConflictError(StackStormBaseException):
-
def __init__(self, instance):
- msg = 'Conflict saving DB object with id "%s" and rev "%s".' % (instance.id, instance.rev)
+ msg = 'Conflict saving DB object with id "%s" and rev "%s".' % (
+ instance.id,
+ instance.rev,
+ )
super(StackStormDBObjectWriteConflictError, self).__init__(msg)
diff --git a/st2common/st2common/exceptions/inquiry.py b/st2common/st2common/exceptions/inquiry.py
index 0636d0f985..b5c3f30646 100644
--- a/st2common/st2common/exceptions/inquiry.py
+++ b/st2common/st2common/exceptions/inquiry.py
@@ -23,32 +23,33 @@
class InvalidInquiryInstance(st2_exc.StackStormBaseException):
-
def __init__(self, inquiry_id):
- Exception.__init__(self, 'Action execution "%s" is not an inquiry.' % inquiry_id)
+ Exception.__init__(
+ self, 'Action execution "%s" is not an inquiry.' % inquiry_id
+ )
class InquiryTimedOut(st2_exc.StackStormBaseException):
-
def __init__(self, inquiry_id):
- Exception.__init__(self, 'Inquiry "%s" timed out and cannot be responded to.' % inquiry_id)
+ Exception.__init__(
+ self, 'Inquiry "%s" timed out and cannot be responded to.' % inquiry_id
+ )
class InquiryAlreadyResponded(st2_exc.StackStormBaseException):
-
def __init__(self, inquiry_id):
- Exception.__init__(self, 'Inquiry "%s" has already been responded to.' % inquiry_id)
+ Exception.__init__(
+ self, 'Inquiry "%s" has already been responded to.' % inquiry_id
+ )
class InquiryResponseUnauthorized(st2_exc.StackStormBaseException):
-
def __init__(self, inquiry_id, user):
msg = 'User "%s" does not have permission to respond to inquiry "%s".'
Exception.__init__(self, msg % (user, inquiry_id))
class InvalidInquiryResponse(st2_exc.StackStormBaseException):
-
def __init__(self, inquiry_id, error):
msg = 'Response for inquiry "%s" did not pass schema validation. %s'
Exception.__init__(self, msg % (inquiry_id, error))
diff --git a/st2common/st2common/exceptions/keyvalue.py b/st2common/st2common/exceptions/keyvalue.py
index 6ef2702fe8..7fccb8b819 100644
--- a/st2common/st2common/exceptions/keyvalue.py
+++ b/st2common/st2common/exceptions/keyvalue.py
@@ -18,9 +18,9 @@
from st2common.exceptions.db import StackStormDBObjectNotFoundError
__all__ = [
- 'CryptoKeyNotSetupException',
- 'DataStoreKeyNotFoundError',
- 'InvalidScopeException'
+ "CryptoKeyNotSetupException",
+ "DataStoreKeyNotFoundError",
+ "InvalidScopeException",
]
diff --git a/st2common/st2common/exceptions/rbac.py b/st2common/st2common/exceptions/rbac.py
index 308110c267..957b0fe5be 100644
--- a/st2common/st2common/exceptions/rbac.py
+++ b/st2common/st2common/exceptions/rbac.py
@@ -18,10 +18,10 @@
from st2common.rbac.types import GLOBAL_PERMISSION_TYPES
__all__ = [
- 'AccessDeniedError',
- 'ResourceTypeAccessDeniedError',
- 'ResourceAccessDeniedError',
- 'ResourceAccessDeniedPermissionIsolationError'
+ "AccessDeniedError",
+ "ResourceTypeAccessDeniedError",
+ "ResourceAccessDeniedError",
+ "ResourceAccessDeniedPermissionIsolationError",
]
@@ -45,9 +45,13 @@ class ResourceTypeAccessDeniedError(AccessDeniedError):
def __init__(self, user_db, permission_type):
self.permission_type = permission_type
- message = ('User "%s" doesn\'t have required permission "%s"' % (user_db.name,
- permission_type))
- super(ResourceTypeAccessDeniedError, self).__init__(message=message, user_db=user_db)
+ message = 'User "%s" doesn\'t have required permission "%s"' % (
+ user_db.name,
+ permission_type,
+ )
+ super(ResourceTypeAccessDeniedError, self).__init__(
+ message=message, user_db=user_db
+ )
class ResourceAccessDeniedError(AccessDeniedError):
@@ -59,15 +63,25 @@ def __init__(self, user_db, resource_api_or_db, permission_type):
self.resource_api_db = resource_api_or_db
self.permission_type = permission_type
- resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else 'unknown'
+ resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else "unknown"
if resource_api_or_db and permission_type not in GLOBAL_PERMISSION_TYPES:
- message = ('User "%s" doesn\'t have required permission "%s" on resource "%s"' %
- (user_db.name, permission_type, resource_uid))
+ message = (
+ 'User "%s" doesn\'t have required permission "%s" on resource "%s"'
+ % (
+ user_db.name,
+ permission_type,
+ resource_uid,
+ )
+ )
else:
- message = ('User "%s" doesn\'t have required permission "%s"' %
- (user_db.name, permission_type))
- super(ResourceAccessDeniedError, self).__init__(message=message, user_db=user_db)
+ message = 'User "%s" doesn\'t have required permission "%s"' % (
+ user_db.name,
+ permission_type,
+ )
+ super(ResourceAccessDeniedError, self).__init__(
+ message=message, user_db=user_db
+ )
class ResourceAccessDeniedPermissionIsolationError(AccessDeniedError):
@@ -80,9 +94,12 @@ def __init__(self, user_db, resource_api_or_db, permission_type):
self.resource_api_db = resource_api_or_db
self.permission_type = permission_type
- resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else 'unknown'
+ resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else "unknown"
- message = ('User "%s" doesn\'t have access to resource "%s" due to resource permission '
- 'isolation.' % (user_db.name, resource_uid))
- super(ResourceAccessDeniedPermissionIsolationError, self).__init__(message=message,
- user_db=user_db)
+ message = (
+ 'User "%s" doesn\'t have access to resource "%s" due to resource permission '
+ "isolation." % (user_db.name, resource_uid)
+ )
+ super(ResourceAccessDeniedPermissionIsolationError, self).__init__(
+ message=message, user_db=user_db
+ )
diff --git a/st2common/st2common/exceptions/ssh.py b/st2common/st2common/exceptions/ssh.py
index f720e54b8a..7a4e1ee516 100644
--- a/st2common/st2common/exceptions/ssh.py
+++ b/st2common/st2common/exceptions/ssh.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'InvalidCredentialsException'
-]
+__all__ = ["InvalidCredentialsException"]
class InvalidCredentialsException(Exception):
diff --git a/st2common/st2common/exceptions/workflow.py b/st2common/st2common/exceptions/workflow.py
index dd787417c2..2a346819be 100644
--- a/st2common/st2common/exceptions/workflow.py
+++ b/st2common/st2common/exceptions/workflow.py
@@ -27,28 +27,25 @@
def retry_on_connection_errors(exc):
- LOG.warning('Determining if exception %s should be retried.', type(exc))
+ LOG.warning("Determining if exception %s should be retried.", type(exc))
- retrying = (
- isinstance(exc, tooz.coordination.ToozConnectionError) or
- isinstance(exc, mongoengine.connection.MongoEngineConnectionError)
+ retrying = isinstance(exc, tooz.coordination.ToozConnectionError) or isinstance(
+ exc, mongoengine.connection.MongoEngineConnectionError
)
if retrying:
- LOG.warning('Retrying operation due to connection error: %s', type(exc))
+ LOG.warning("Retrying operation due to connection error: %s", type(exc))
return retrying
def retry_on_transient_db_errors(exc):
- LOG.warning('Determining if exception %s should be retried.', type(exc))
+ LOG.warning("Determining if exception %s should be retried.", type(exc))
- retrying = (
- isinstance(exc, db_exc.StackStormDBObjectWriteConflictError)
- )
+ retrying = isinstance(exc, db_exc.StackStormDBObjectWriteConflictError)
if retrying:
- LOG.warning('Retrying operation due to transient database error: %s', type(exc))
+ LOG.warning("Retrying operation due to transient database error: %s", type(exc))
return retrying
@@ -62,38 +59,37 @@ class WorkflowExecutionException(st2_exc.StackStormBaseException):
class WorkflowExecutionNotFoundException(st2_exc.StackStormBaseException):
-
def __init__(self, ac_ex_id):
Exception.__init__(
self,
- 'Unable to identify any workflow execution that is '
- 'associated to action execution "%s".' % ac_ex_id
+ "Unable to identify any workflow execution that is "
+ 'associated to action execution "%s".' % ac_ex_id,
)
class AmbiguousWorkflowExecutionException(st2_exc.StackStormBaseException):
-
def __init__(self, ac_ex_id):
Exception.__init__(
self,
- 'More than one workflow execution is associated '
- 'to action execution "%s".' % ac_ex_id
+ "More than one workflow execution is associated "
+ 'to action execution "%s".' % ac_ex_id,
)
class WorkflowExecutionIsCompletedException(st2_exc.StackStormBaseException):
-
def __init__(self, wf_ex_id):
- Exception.__init__(self, 'Workflow execution "%s" is already completed.' % wf_ex_id)
+ Exception.__init__(
+ self, 'Workflow execution "%s" is already completed.' % wf_ex_id
+ )
class WorkflowExecutionIsRunningException(st2_exc.StackStormBaseException):
-
def __init__(self, wf_ex_id):
- Exception.__init__(self, 'Workflow execution "%s" is already active.' % wf_ex_id)
+ Exception.__init__(
+ self, 'Workflow execution "%s" is already active.' % wf_ex_id
+ )
class WorkflowExecutionRerunException(st2_exc.StackStormBaseException):
-
def __init__(self, msg):
Exception.__init__(self, msg)
diff --git a/st2common/st2common/expressions/functions/data.py b/st2common/st2common/expressions/functions/data.py
index d3783e652e..b240cb7238 100644
--- a/st2common/st2common/expressions/functions/data.py
+++ b/st2common/st2common/expressions/functions/data.py
@@ -24,13 +24,13 @@
__all__ = [
- 'from_json_string',
- 'from_yaml_string',
- 'json_escape',
- 'jsonpath_query',
- 'to_complex',
- 'to_json_string',
- 'to_yaml_string',
+ "from_json_string",
+ "from_yaml_string",
+ "json_escape",
+ "jsonpath_query",
+ "to_complex",
+ "to_json_string",
+ "to_yaml_string",
]
@@ -42,19 +42,19 @@ def from_yaml_string(value):
return yaml.safe_load(six.text_type(value))
-def to_json_string(value, indent=None, sort_keys=False, separators=(',', ': ')):
+def to_json_string(value, indent=None, sort_keys=False, separators=(",", ": ")):
value = db_util.mongodb_to_python_types(value)
options = {}
if indent is not None:
- options['indent'] = indent
+ options["indent"] = indent
if sort_keys is not None:
- options['sort_keys'] = sort_keys
+ options["sort_keys"] = sort_keys
if separators is not None:
- options['separators'] = separators
+ options["separators"] = separators
return json.dumps(value, **options)
@@ -62,19 +62,19 @@ def to_json_string(value, indent=None, sort_keys=False, separators=(',', ': ')):
def to_yaml_string(value, indent=None, allow_unicode=True):
value = db_util.mongodb_to_python_types(value)
- options = {'default_flow_style': False}
+ options = {"default_flow_style": False}
if indent is not None:
- options['indent'] = indent
+ options["indent"] = indent
if allow_unicode is not None:
- options['allow_unicode'] = allow_unicode
+ options["allow_unicode"] = allow_unicode
return yaml.safe_dump(value, **options)
def json_escape(value):
- """ Adds escape sequences to problematic characters in the string
+ """Adds escape sequences to problematic characters in the string
This filter simply passes the value to json.dumps
    as a convenient way of escaping characters in it.
However, before returning, we want to strip the double
@@ -110,7 +110,7 @@ def to_complex(value):
# Magic string to which None type is serialized when using use_none filter
-NONE_MAGIC_VALUE = '%*****__%NONE%__*****%'
+NONE_MAGIC_VALUE = "%*****__%NONE%__*****%"
def use_none(value):
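
The option plumbing in to_json_string/to_yaml_string only forwards non-None options to the
underlying serializer; a short usage sketch, assuming st2common is importable:

    from st2common.expressions.functions.data import to_json_string

    # indent and sort_keys are forwarded to json.dumps; separators keeps
    # the (",", ": ") default shown above.
    print(to_json_string({"b": 1, "a": 2}, indent=2, sort_keys=True))
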
diff --git a/st2common/st2common/expressions/functions/datastore.py b/st2common/st2common/expressions/functions/datastore.py
index a8e903c377..bd0e5fbb09 100644
--- a/st2common/st2common/expressions/functions/datastore.py
+++ b/st2common/st2common/expressions/functions/datastore.py
@@ -22,9 +22,7 @@
from st2common.util.crypto import read_crypto_key
from st2common.util.crypto import symmetric_decrypt
-__all__ = [
- 'decrypt_kv'
-]
+__all__ = ["decrypt_kv"]
def decrypt_kv(value):
@@ -41,11 +39,13 @@ def decrypt_kv(value):
    # NOTE: If value is None this indicates the key value item doesn't exist and we throw a more
# user-friendly error
- if is_kv_item and value == '':
+ if is_kv_item and value == "":
# Build original key name
key_name = original_value.get_key_name()
- raise ValueError('Referenced datastore item "%s" doesn\'t exist or it contains an empty '
- 'string' % (key_name))
+ raise ValueError(
+ 'Referenced datastore item "%s" doesn\'t exist or it contains an empty '
+ "string" % (key_name)
+ )
crypto_key_path = cfg.CONF.keyvalue.encryption_key_path
crypto_key = read_crypto_key(key_path=crypto_key_path)
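
decrypt_kv is exposed to Jinja expressions as a filter. A hedged sketch of typical usage
in an action parameter (the key name is hypothetical, and keyvalue.encryption_key_path
must point at a valid crypto key):

    # Jinja expression as it would appear in an action or workflow definition:
    param = "{{ st2kv.system.db_password | decrypt_kv }}"
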
diff --git a/st2common/st2common/expressions/functions/path.py b/st2common/st2common/expressions/functions/path.py
index 6081be895c..d21f301aa1 100644
--- a/st2common/st2common/expressions/functions/path.py
+++ b/st2common/st2common/expressions/functions/path.py
@@ -16,10 +16,7 @@
from __future__ import absolute_import
import os
-__all__ = [
- 'basename',
- 'dirname'
-]
+__all__ = ["basename", "dirname"]
def basename(path):
diff --git a/st2common/st2common/expressions/functions/regex.py b/st2common/st2common/expressions/functions/regex.py
index 4db7fe0f65..4b17b7372f 100644
--- a/st2common/st2common/expressions/functions/regex.py
+++ b/st2common/st2common/expressions/functions/regex.py
@@ -17,12 +17,7 @@
import re
import six
-__all__ = [
- 'regex_match',
- 'regex_replace',
- 'regex_search',
- 'regex_substring'
-]
+__all__ = ["regex_match", "regex_replace", "regex_search", "regex_substring"]
def _get_regex_flags(ignorecase=False):
diff --git a/st2common/st2common/expressions/functions/time.py b/st2common/st2common/expressions/functions/time.py
index 543fc80938..d25b8acecc 100644
--- a/st2common/st2common/expressions/functions/time.py
+++ b/st2common/st2common/expressions/functions/time.py
@@ -19,14 +19,12 @@
import datetime
-__all__ = [
- 'to_human_time_from_seconds'
-]
+__all__ = ["to_human_time_from_seconds"]
if six.PY3:
long_int = int
else:
- long_int = long # noqa # pylint: disable=E0602
+ long_int = long # noqa # pylint: disable=E0602
def to_human_time_from_seconds(seconds):
@@ -39,8 +37,11 @@ def to_human_time_from_seconds(seconds):
:rtype: ``str``
"""
- assert (isinstance(seconds, int) or isinstance(seconds, int) or
- isinstance(seconds, float))
+    assert (
+        isinstance(seconds, int)
+        or isinstance(seconds, long_int)
+        or isinstance(seconds, float)
+    )
return _get_human_time(seconds)
@@ -59,10 +60,10 @@ def _get_human_time(seconds):
return None
if seconds == 0:
- return '0s'
+ return "0s"
if seconds < 1:
- return '%s\u03BCs' % seconds # Microseconds
+ return "%s\u03BCs" % seconds # Microseconds
if isinstance(seconds, float):
seconds = long_int(round(seconds)) # Let's lose microseconds.
@@ -81,17 +82,17 @@ def _get_human_time(seconds):
first_non_zero_pos = next((i for i, x in enumerate(time_parts) if x), None)
if first_non_zero_pos is None:
- return '0s'
+ return "0s"
else:
time_parts = time_parts[first_non_zero_pos:]
if len(time_parts) == 1:
- return '%ss' % tuple(time_parts)
+ return "%ss" % tuple(time_parts)
elif len(time_parts) == 2:
- return '%sm%ss' % tuple(time_parts)
+ return "%sm%ss" % tuple(time_parts)
elif len(time_parts) == 3:
- return '%sh%sm%ss' % tuple(time_parts)
+ return "%sh%sm%ss" % tuple(time_parts)
elif len(time_parts) == 4:
- return '%sd%sh%sm%ss' % tuple(time_parts)
+ return "%sd%sh%sm%ss" % tuple(time_parts)
elif len(time_parts) == 5:
- return '%sy%sd%sh%sm%ss' % tuple(time_parts)
+ return "%sy%sd%sh%sm%ss" % tuple(time_parts)
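
Expected outputs of the branching above, assuming st2common is importable:

    from st2common.expressions.functions.time import to_human_time_from_seconds

    print(to_human_time_from_seconds(0))     # "0s"
    print(to_human_time_from_seconds(90))    # "1m30s"
    print(to_human_time_from_seconds(3661))  # "1h1m1s"
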
diff --git a/st2common/st2common/expressions/functions/version.py b/st2common/st2common/expressions/functions/version.py
index 2dc8d353f1..825d5965e3 100644
--- a/st2common/st2common/expressions/functions/version.py
+++ b/st2common/st2common/expressions/functions/version.py
@@ -17,13 +17,13 @@
import semver
__all__ = [
- 'version_compare',
- 'version_more_than',
- 'version_less_than',
- 'version_equal',
- 'version_match',
- 'version_bump_major',
- 'version_bump_minor'
+ "version_compare",
+ "version_more_than",
+ "version_less_than",
+ "version_equal",
+ "version_match",
+ "version_bump_major",
+ "version_bump_minor",
]
diff --git a/st2common/st2common/fields.py b/st2common/st2common/fields.py
index 7217365874..b968e2fdb7 100644
--- a/st2common/st2common/fields.py
+++ b/st2common/st2common/fields.py
@@ -21,9 +21,7 @@
from st2common.util import date as date_utils
-__all__ = [
- 'ComplexDateTimeField'
-]
+__all__ = ["ComplexDateTimeField"]
SECOND_TO_MICROSECONDS = 1000000
@@ -60,7 +58,7 @@ def _microseconds_since_epoch_to_datetime(self, data):
:type data: ``int``
"""
result = datetime.datetime.utcfromtimestamp(data // SECOND_TO_MICROSECONDS)
- microseconds_reminder = (data % SECOND_TO_MICROSECONDS)
+ microseconds_reminder = data % SECOND_TO_MICROSECONDS
result = result.replace(microsecond=microseconds_reminder)
result = date_utils.add_utc_tz(result)
return result
@@ -77,11 +75,13 @@ def _datetime_to_microseconds_since_epoch(self, value):
# Verify that the value which is passed in contains UTC timezone
# information.
if not value.tzinfo or (value.tzinfo.utcoffset(value) != datetime.timedelta(0)):
- raise ValueError('Value passed to this function needs to be in UTC timezone')
+ raise ValueError(
+ "Value passed to this function needs to be in UTC timezone"
+ )
seconds = calendar.timegm(value.timetuple())
microseconds_reminder = value.time().microsecond
- result = (int(seconds * SECOND_TO_MICROSECONDS) + microseconds_reminder)
+ result = int(seconds * SECOND_TO_MICROSECONDS) + microseconds_reminder
return result
def __get__(self, instance, owner):
@@ -99,8 +99,7 @@ def __set__(self, instance, value):
def validate(self, value):
value = self.to_python(value)
if not isinstance(value, datetime.datetime):
- self.error('Only datetime objects may used in a '
- 'ComplexDateTimeField')
+            self.error("Only datetime objects may be used in a ComplexDateTimeField")
def to_python(self, value):
original_value = value
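
A self-contained sketch of the microseconds-since-epoch round trip that
ComplexDateTimeField performs (the datetime value is illustrative):

    import calendar
    import datetime

    SECOND_TO_MICROSECONDS = 1000000

    dt = datetime.datetime(2021, 3, 2, 12, 0, 0, 123456)  # treated as UTC
    usec = calendar.timegm(dt.timetuple()) * SECOND_TO_MICROSECONDS + dt.microsecond
    back = datetime.datetime.utcfromtimestamp(usec // SECOND_TO_MICROSECONDS)
    back = back.replace(microsecond=usec % SECOND_TO_MICROSECONDS)
    assert back == dt
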
diff --git a/st2common/st2common/garbage_collection/executions.py b/st2common/st2common/garbage_collection/executions.py
index ba924e76f2..ae0f3296f4 100644
--- a/st2common/st2common/garbage_collection/executions.py
+++ b/st2common/st2common/garbage_collection/executions.py
@@ -32,15 +32,14 @@
from st2common.services import action as action_service
from st2common.services import workflows as workflow_service
-__all__ = [
- 'purge_executions',
- 'purge_execution_output_objects'
-]
+__all__ = ["purge_executions", "purge_execution_output_objects"]
-DONE_STATES = [action_constants.LIVEACTION_STATUS_SUCCEEDED,
- action_constants.LIVEACTION_STATUS_FAILED,
- action_constants.LIVEACTION_STATUS_TIMED_OUT,
- action_constants.LIVEACTION_STATUS_CANCELED]
+DONE_STATES = [
+ action_constants.LIVEACTION_STATUS_SUCCEEDED,
+ action_constants.LIVEACTION_STATUS_FAILED,
+ action_constants.LIVEACTION_STATUS_TIMED_OUT,
+ action_constants.LIVEACTION_STATUS_CANCELED,
+]
def purge_executions(logger, timestamp, action_ref=None, purge_incomplete=False):
@@ -57,90 +56,118 @@ def purge_executions(logger, timestamp, action_ref=None, purge_incomplete=False)
:type purge_incomplete: ``bool``
"""
if not timestamp:
- raise ValueError('Specify a valid timestamp to purge.')
+ raise ValueError("Specify a valid timestamp to purge.")
- logger.info('Purging executions older than timestamp: %s' %
- timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
+ logger.info(
+ "Purging executions older than timestamp: %s"
+ % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+ )
filters = {}
if purge_incomplete:
- filters['start_timestamp__lt'] = timestamp
+ filters["start_timestamp__lt"] = timestamp
else:
- filters['end_timestamp__lt'] = timestamp
- filters['start_timestamp__lt'] = timestamp
- filters['status'] = {'$in': DONE_STATES}
+ filters["end_timestamp__lt"] = timestamp
+ filters["start_timestamp__lt"] = timestamp
+ filters["status"] = {"$in": DONE_STATES}
exec_filters = copy.copy(filters)
if action_ref:
- exec_filters['action__ref'] = action_ref
+ exec_filters["action__ref"] = action_ref
liveaction_filters = copy.deepcopy(filters)
if action_ref:
- liveaction_filters['action'] = action_ref
+ liveaction_filters["action"] = action_ref
to_delete_execution_dbs = []
# 1. Delete ActionExecutionDB objects
try:
        # Note: We call list() on the query set object because it's lazily evaluated otherwise
- to_delete_execution_dbs = list(ActionExecution.query(only_fields=['id'],
- no_dereference=True,
- **exec_filters))
+ to_delete_execution_dbs = list(
+ ActionExecution.query(
+ only_fields=["id"], no_dereference=True, **exec_filters
+ )
+ )
deleted_count = ActionExecution.delete_by_query(**exec_filters)
except InvalidQueryError as e:
- msg = ('Bad query (%s) used to delete execution instances: %s'
- 'Please contact support.' % (exec_filters, six.text_type(e)))
+ msg = (
+ "Bad query (%s) used to delete execution instances: %s"
+            " Please contact support."
+ % (
+ exec_filters,
+ six.text_type(e),
+ )
+ )
raise InvalidQueryError(msg)
except:
- logger.exception('Deletion of execution models failed for query with filters: %s.',
- exec_filters)
+ logger.exception(
+ "Deletion of execution models failed for query with filters: %s.",
+ exec_filters,
+ )
else:
- logger.info('Deleted %s action execution objects' % (deleted_count))
+ logger.info("Deleted %s action execution objects" % (deleted_count))
# 2. Delete LiveActionDB objects
try:
deleted_count = LiveAction.delete_by_query(**liveaction_filters)
except InvalidQueryError as e:
- msg = ('Bad query (%s) used to delete liveaction instances: %s'
- 'Please contact support.' % (liveaction_filters, six.text_type(e)))
+ msg = (
+ "Bad query (%s) used to delete liveaction instances: %s"
+            " Please contact support."
+ % (
+ liveaction_filters,
+ six.text_type(e),
+ )
+ )
raise InvalidQueryError(msg)
except:
- logger.exception('Deletion of liveaction models failed for query with filters: %s.',
- liveaction_filters)
+ logger.exception(
+ "Deletion of liveaction models failed for query with filters: %s.",
+ liveaction_filters,
+ )
else:
- logger.info('Deleted %s liveaction objects' % (deleted_count))
+ logger.info("Deleted %s liveaction objects" % (deleted_count))
# 3. Delete ActionExecutionOutputDB objects
- to_delete_exection_ids = [str(execution_db.id) for execution_db in to_delete_execution_dbs]
+ to_delete_exection_ids = [
+ str(execution_db.id) for execution_db in to_delete_execution_dbs
+ ]
output_dbs_filters = {}
- output_dbs_filters['execution_id'] = {'$in': to_delete_exection_ids}
+ output_dbs_filters["execution_id"] = {"$in": to_delete_exection_ids}
try:
deleted_count = ActionExecutionOutput.delete_by_query(**output_dbs_filters)
except InvalidQueryError as e:
- msg = ('Bad query (%s) used to delete execution output instances: %s'
- 'Please contact support.' % (output_dbs_filters, six.text_type(e)))
+ msg = (
+ "Bad query (%s) used to delete execution output instances: %s"
+            " Please contact support." % (output_dbs_filters, six.text_type(e))
+ )
raise InvalidQueryError(msg)
except:
- logger.exception('Deletion of execution output models failed for query with filters: %s.',
- output_dbs_filters)
+ logger.exception(
+ "Deletion of execution output models failed for query with filters: %s.",
+ output_dbs_filters,
+ )
else:
- logger.info('Deleted %s execution output objects' % (deleted_count))
+ logger.info("Deleted %s execution output objects" % (deleted_count))
- zombie_execution_instances = len(ActionExecution.query(only_fields=['id'],
- no_dereference=True,
- **exec_filters))
- zombie_liveaction_instances = len(LiveAction.query(only_fields=['id'],
- no_dereference=True,
- **liveaction_filters))
+ zombie_execution_instances = len(
+ ActionExecution.query(only_fields=["id"], no_dereference=True, **exec_filters)
+ )
+ zombie_liveaction_instances = len(
+ LiveAction.query(only_fields=["id"], no_dereference=True, **liveaction_filters)
+ )
if (zombie_execution_instances > 0) or (zombie_liveaction_instances > 0):
- logger.error('Zombie execution instances left: %d.', zombie_execution_instances)
- logger.error('Zombie liveaction instances left: %s.', zombie_liveaction_instances)
+ logger.error("Zombie execution instances left: %d.", zombie_execution_instances)
+ logger.error(
+ "Zombie liveaction instances left: %s.", zombie_liveaction_instances
+ )
# Print stats
- logger.info('All execution models older than timestamp %s were deleted.', timestamp)
+ logger.info("All execution models older than timestamp %s were deleted.", timestamp)
def purge_execution_output_objects(logger, timestamp, action_ref=None):
@@ -154,28 +181,34 @@ def purge_execution_output_objects(logger, timestamp, action_ref=None):
:type action_ref: ``str``
"""
if not timestamp:
- raise ValueError('Specify a valid timestamp to purge.')
+ raise ValueError("Specify a valid timestamp to purge.")
- logger.info('Purging action execution output objects older than timestamp: %s' %
- timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
+ logger.info(
+ "Purging action execution output objects older than timestamp: %s"
+ % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+ )
filters = {}
- filters['timestamp__lt'] = timestamp
+ filters["timestamp__lt"] = timestamp
if action_ref:
- filters['action_ref'] = action_ref
+ filters["action_ref"] = action_ref
try:
deleted_count = ActionExecutionOutput.delete_by_query(**filters)
except InvalidQueryError as e:
- msg = ('Bad query (%s) used to delete execution output instances: %s'
- 'Please contact support.' % (filters, six.text_type(e)))
+ msg = (
+ "Bad query (%s) used to delete execution output instances: %s"
+            " Please contact support." % (filters, six.text_type(e))
+ )
raise InvalidQueryError(msg)
except:
- logger.exception('Deletion of execution output models failed for query with filters: %s.',
- filters)
+ logger.exception(
+ "Deletion of execution output models failed for query with filters: %s.",
+ filters,
+ )
else:
- logger.info('Deleted %s execution output objects' % (deleted_count))
+ logger.info("Deleted %s execution output objects" % (deleted_count))
def purge_orphaned_workflow_executions(logger):
@@ -190,5 +223,5 @@ def purge_orphaned_workflow_executions(logger):
# as a result of the original failure, the garbage collection routine here cancels
# the workflow execution so it cannot be rerun from failed task(s).
for ac_ex_db in workflow_service.identify_orphaned_workflows():
- lv_ac_db = LiveAction.get(id=ac_ex_db.liveaction['id'])
+ lv_ac_db = LiveAction.get(id=ac_ex_db.liveaction["id"])
action_service.request_cancellation(lv_ac_db, None)
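
For a completed-only purge, the filters assembled above and passed to delete_by_query look
like this (cutoff and states are illustrative stand-ins for the module-level DONE_STATES):

    import datetime

    DONE_STATES = ["succeeded", "failed", "timeout", "canceled"]  # illustrative
    timestamp = datetime.datetime(2021, 1, 1)
    filters = {
        "end_timestamp__lt": timestamp,
        "start_timestamp__lt": timestamp,
        "status": {"$in": DONE_STATES},  # mongoengine "$in" query on status
    }
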
diff --git a/st2common/st2common/garbage_collection/inquiries.py b/st2common/st2common/garbage_collection/inquiries.py
index 724033853f..ad95126b21 100644
--- a/st2common/st2common/garbage_collection/inquiries.py
+++ b/st2common/st2common/garbage_collection/inquiries.py
@@ -27,7 +27,7 @@
from st2common.util.date import get_datetime_utc_now
__all__ = [
- 'purge_inquiries',
+ "purge_inquiries",
]
@@ -44,7 +44,10 @@ def purge_inquiries(logger):
"""
# Get all existing Inquiries
- filters = {'runner__name': 'inquirer', 'status': action_constants.LIVEACTION_STATUS_PENDING}
+ filters = {
+ "runner__name": "inquirer",
+ "status": action_constants.LIVEACTION_STATUS_PENDING,
+ }
inquiries = list(ActionExecution.query(**filters))
gc_count = 0
@@ -52,7 +55,7 @@ def purge_inquiries(logger):
# Inspect each Inquiry, and determine if TTL is expired
for inquiry in inquiries:
- ttl = int(inquiry.result.get('ttl'))
+ ttl = int(inquiry.result.get("ttl"))
if ttl <= 0:
logger.debug("Inquiry %s has a TTL of %s. Skipping." % (inquiry.id, ttl))
continue
@@ -61,17 +64,22 @@ def purge_inquiries(logger):
(get_datetime_utc_now() - inquiry.start_timestamp).total_seconds() / 60
)
- logger.debug("Inquiry %s has a TTL of %s and was started %s minute(s) ago" % (
- inquiry.id, ttl, min_since_creation))
+ logger.debug(
+ "Inquiry %s has a TTL of %s and was started %s minute(s) ago"
+ % (inquiry.id, ttl, min_since_creation)
+ )
if min_since_creation > ttl:
gc_count += 1
- logger.info("TTL expired for Inquiry %s. Marking as timed out." % inquiry.id)
+ logger.info(
+ "TTL expired for Inquiry %s. Marking as timed out." % inquiry.id
+ )
liveaction_db = action_utils.update_liveaction_status(
status=action_constants.LIVEACTION_STATUS_TIMED_OUT,
result=inquiry.result,
- liveaction_id=inquiry.liveaction.get('id'))
+ liveaction_id=inquiry.liveaction.get("id"),
+ )
executions.update_execution(liveaction_db)
# Call Inquiry runner's post_run to trigger callback to workflow
@@ -82,8 +90,7 @@ def purge_inquiries(logger):
# Request that root workflow resumes
root_liveaction = action_service.get_root_liveaction(liveaction_db)
action_service.request_resume(
- root_liveaction,
- UserDB(cfg.CONF.system_user.user)
+ root_liveaction, UserDB(cfg.CONF.system_user.user)
)
logger.info('Marked %s ttl-expired Inquiries as "timed out".' % (gc_count))
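
The TTL comparison in purge_inquiries, reduced to its arithmetic (timestamps illustrative):

    import datetime

    now = datetime.datetime(2021, 3, 2, 12, 30)
    start_timestamp = datetime.datetime(2021, 3, 2, 12, 0)
    ttl = 20  # minutes, from inquiry.result["ttl"]

    min_since_creation = int((now - start_timestamp).total_seconds() / 60)
    print(min_since_creation > ttl)  # True: 30 > 20, so it would be timed out
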
diff --git a/st2common/st2common/garbage_collection/trigger_instances.py b/st2common/st2common/garbage_collection/trigger_instances.py
index 47996614dd..0fbabb5e72 100644
--- a/st2common/st2common/garbage_collection/trigger_instances.py
+++ b/st2common/st2common/garbage_collection/trigger_instances.py
@@ -25,9 +25,7 @@
from st2common.persistence.trigger import TriggerInstance
from st2common.util import isotime
-__all__ = [
- 'purge_trigger_instances'
-]
+__all__ = ["purge_trigger_instances"]
def purge_trigger_instances(logger, timestamp):
@@ -36,23 +34,35 @@ def purge_trigger_instances(logger, timestamp):
    :type timestamp: ``datetime.datetime``
"""
if not timestamp:
- raise ValueError('Specify a valid timestamp to purge.')
+ raise ValueError("Specify a valid timestamp to purge.")
- logger.info('Purging trigger instances older than timestamp: %s' %
- timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
+ logger.info(
+ "Purging trigger instances older than timestamp: %s"
+ % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+ )
- query_filters = {'occurrence_time__lt': isotime.parse(timestamp)}
+ query_filters = {"occurrence_time__lt": isotime.parse(timestamp)}
try:
deleted_count = TriggerInstance.delete_by_query(**query_filters)
except InvalidQueryError as e:
- msg = ('Bad query (%s) used to delete trigger instances: %s'
- 'Please contact support.' % (query_filters, six.text_type(e)))
+ msg = (
+ "Bad query (%s) used to delete trigger instances: %s"
+            " Please contact support."
+ % (
+ query_filters,
+ six.text_type(e),
+ )
+ )
raise InvalidQueryError(msg)
except:
- logger.exception('Deleting instances using query_filters %s failed.', query_filters)
+ logger.exception(
+ "Deleting instances using query_filters %s failed.", query_filters
+ )
else:
- logger.info('Deleted %s trigger instance objects' % (deleted_count))
+ logger.info("Deleted %s trigger instance objects" % (deleted_count))
# Print stats
- logger.info('All trigger instance models older than timestamp %s were deleted.', timestamp)
+ logger.info(
+ "All trigger instance models older than timestamp %s were deleted.", timestamp
+ )
diff --git a/st2common/st2common/log.py b/st2common/st2common/log.py
index 5335af5f53..fbf6205bb9 100644
--- a/st2common/st2common/log.py
+++ b/st2common/st2common/log.py
@@ -35,34 +35,30 @@
from st2common.util.misc import get_normalized_file_path
__all__ = [
- 'getLogger',
- 'setup',
-
- 'FormatNamedFileHandler',
- 'ConfigurableSyslogHandler',
-
- 'LoggingStream',
-
- 'ignore_lib2to3_log_messages',
- 'ignore_statsd_log_messages'
+ "getLogger",
+ "setup",
+ "FormatNamedFileHandler",
+ "ConfigurableSyslogHandler",
+ "LoggingStream",
+ "ignore_lib2to3_log_messages",
+ "ignore_statsd_log_messages",
]
# NOTE: We set AUDIT to the highest log level which means AUDIT log messages will always be
# included (e.g. also if log level is set to INFO). To avoid that, we need to explicitly filter
# out AUDIT log level in service setup code.
logging.AUDIT = logging.CRITICAL + 10
-logging.addLevelName(logging.AUDIT, 'AUDIT')
+logging.addLevelName(logging.AUDIT, "AUDIT")
LOGGER_KEYS = [
- 'debug',
- 'info',
- 'warning',
- 'error',
- 'critical',
- 'exception',
- 'log',
-
- 'audit'
+ "debug",
+ "info",
+ "warning",
+ "error",
+ "critical",
+ "exception",
+ "log",
+ "audit",
]
# Note: This attribute is used by "find_caller" so it can correctly exclude this file when looking
@@ -89,10 +85,10 @@ def find_caller(stack_info=False, stacklevel=1):
    on what runtime we're working in.
"""
if six.PY2:
- rv = '(unknown file)', 0, '(unknown function)'
+ rv = "(unknown file)", 0, "(unknown function)"
else:
# python 3, has extra tuple element at the end for stack information
- rv = '(unknown file)', 0, '(unknown function)', None
+ rv = "(unknown file)", 0, "(unknown function)", None
try:
f = logging.currentframe()
@@ -107,7 +103,7 @@ def find_caller(stack_info=False, stacklevel=1):
if not f:
f = orig_f
- while hasattr(f, 'f_code'):
+ while hasattr(f, "f_code"):
co = f.f_code
filename = os.path.normcase(co.co_filename)
if filename in (_srcfile, logging._srcfile): # This line is modified.
@@ -121,10 +117,10 @@ def find_caller(stack_info=False, stacklevel=1):
sinfo = None
if stack_info:
sio = io.StringIO()
- sio.write('Stack (most recent call last):\n')
+ sio.write("Stack (most recent call last):\n")
traceback.print_stack(f, file=sio)
sinfo = sio.getvalue()
- if sinfo[-1] == '\n':
+ if sinfo[-1] == "\n":
sinfo = sinfo[:-1]
sio.close()
rv = (filename, f.f_lineno, co.co_name, sinfo)
@@ -139,8 +135,8 @@ def decorate_log_method(func):
@wraps(func)
def func_wrapper(*args, **kwargs):
# Prefix extra keys with underscore
- if 'extra' in kwargs:
- kwargs['extra'] = prefix_dict_keys(dictionary=kwargs['extra'], prefix='_')
+ if "extra" in kwargs:
+ kwargs["extra"] = prefix_dict_keys(dictionary=kwargs["extra"], prefix="_")
try:
return func(*args, **kwargs)
@@ -150,10 +146,11 @@ def func_wrapper(*args, **kwargs):
# See:
# - https://docs.python.org/release/2.7.3/library/logging.html#logging.Logger.exception
# - https://docs.python.org/release/2.7.7/library/logging.html#logging.Logger.exception
- if 'got an unexpected keyword argument \'extra\'' in six.text_type(e):
- kwargs.pop('extra', None)
+ if "got an unexpected keyword argument 'extra'" in six.text_type(e):
+ kwargs.pop("extra", None)
return func(*args, **kwargs)
raise e
+
return func_wrapper
@@ -179,11 +176,11 @@ def decorate_logger_methods(logger):
def getLogger(name):
# make sure that prefix isn't appended multiple times to preserve logging name hierarchy
- prefix = 'st2.'
+ prefix = "st2."
if name.startswith(prefix):
logger = logging.getLogger(name)
else:
- logger_name = '{}{}'.format(prefix, name)
+ logger_name = "{}{}".format(prefix, name)
logger = logging.getLogger(logger_name)
logger = decorate_logger_methods(logger=logger)
@@ -191,7 +188,6 @@ def getLogger(name):
class LoggingStream(object):
-
def __init__(self, name, level=logging.ERROR):
self._logger = getLogger(name)
self._level = level
@@ -219,11 +215,16 @@ def _add_exclusion_filters(handlers, excludes=None):
def _redirect_stderr():
# It is ok to redirect stderr as none of the st2 handlers write to stderr.
- sys.stderr = LoggingStream('STDERR')
+ sys.stderr = LoggingStream("STDERR")
-def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_loggers=False,
- st2_conf_path=None):
+def setup(
+ config_file,
+ redirect_stderr=True,
+ excludes=None,
+ disable_existing_loggers=False,
+ st2_conf_path=None,
+):
"""
Configure logging from file.
@@ -232,16 +233,18 @@ def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_log
absolute path relative to st2.conf.
:type st2_conf_path: ``str``
"""
- if st2_conf_path and config_file[:2] == './' and not os.path.isfile(config_file):
+ if st2_conf_path and config_file[:2] == "./" and not os.path.isfile(config_file):
# Logging config path is relative to st2.conf, resolve it to full absolute path
directory = os.path.dirname(st2_conf_path)
config_file_name = os.path.basename(config_file)
config_file = os.path.join(directory, config_file_name)
try:
- logging.config.fileConfig(config_file,
- defaults=None,
- disable_existing_loggers=disable_existing_loggers)
+ logging.config.fileConfig(
+ config_file,
+ defaults=None,
+ disable_existing_loggers=disable_existing_loggers,
+ )
handlers = logging.getLoggerClass().manager.root.handlers
_add_exclusion_filters(handlers=handlers, excludes=excludes)
if redirect_stderr:
@@ -251,13 +254,13 @@ def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_log
tb_msg = traceback.format_exc()
msg = str(exc)
- msg += '\n\n' + tb_msg
+ msg += "\n\n" + tb_msg
# revert stderr redirection since there is no logger in place.
sys.stderr = sys.__stderr__
# No logger yet therefore write to stderr
- sys.stderr.write('ERROR: %s' % (msg))
+ sys.stderr.write("ERROR: %s" % (msg))
raise exc_cls(six.text_type(msg))
@@ -271,10 +274,10 @@ def ignore_lib2to3_log_messages():
class MockLoggingModule(object):
def getLogger(self, *args, **kwargs):
- return logging.getLogger('lib2to3')
+ return logging.getLogger("lib2to3")
lib2to3.pgen2.driver.logging = MockLoggingModule()
- logging.getLogger('lib2to3').setLevel(logging.ERROR)
+ logging.getLogger("lib2to3").setLevel(logging.ERROR)
def ignore_statsd_log_messages():
@@ -288,8 +291,8 @@ def ignore_statsd_log_messages():
class MockLoggingModule(object):
def getLogger(self, *args, **kwargs):
- return logging.getLogger('statsd')
+ return logging.getLogger("statsd")
statsd.connection.logging = MockLoggingModule()
statsd.client.logging = MockLoggingModule()
- logging.getLogger('statsd').setLevel(logging.ERROR)
+ logging.getLogger("statsd").setLevel(logging.ERROR)
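
The prefixing behaviour of the getLogger wrapper shown above, assuming st2common is
importable:

    from st2common.log import getLogger

    print(getLogger("actions").name)      # "st2.actions"
    print(getLogger("st2.actions").name)  # prefix not duplicated: "st2.actions"
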
diff --git a/st2common/st2common/logging/filters.py b/st2common/st2common/logging/filters.py
index d997589a0e..1fef164028 100644
--- a/st2common/st2common/logging/filters.py
+++ b/st2common/st2common/logging/filters.py
@@ -17,9 +17,9 @@
import logging
__all__ = [
- 'LoggerNameExclusionFilter',
- 'LoggerFunctionNameExclusionFilter',
- 'LogLevelFilter',
+ "LoggerNameExclusionFilter",
+ "LoggerFunctionNameExclusionFilter",
+ "LogLevelFilter",
]
@@ -36,8 +36,11 @@ def filter(self, record):
if len(self._exclusions) < 1:
return True
- module_decomposition = record.name.split('.')
- exclude = len(module_decomposition) > 0 and module_decomposition[0] in self._exclusions
+ module_decomposition = record.name.split(".")
+ exclude = (
+ len(module_decomposition) > 0
+ and module_decomposition[0] in self._exclusions
+ )
return not exclude
@@ -54,7 +57,7 @@ def filter(self, record):
if len(self._exclusions) < 1:
return True
- function_name = getattr(record, 'funcName', None)
+ function_name = getattr(record, "funcName", None)
if function_name in self._exclusions:
return False
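
A sketch of the exclusion semantics above: only the first component of the dotted logger
name is compared against the exclusion list (the exclusions constructor argument is
assumed to mirror the LoggerFunctionNameExclusionFilter usage shown further below):

    import logging
    from st2common.logging.filters import LoggerNameExclusionFilter

    log_filter = LoggerNameExclusionFilter(exclusions=["kombu"])
    record = logging.LogRecord(
        "kombu.connection", logging.DEBUG, __file__, 1, "msg", None, None
    )
    print(log_filter.filter(record))  # False -> the record is dropped
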
diff --git a/st2common/st2common/logging/formatters.py b/st2common/st2common/logging/formatters.py
index d20b240a5a..7c30e780a9 100644
--- a/st2common/st2common/logging/formatters.py
+++ b/st2common/st2common/logging/formatters.py
@@ -28,8 +28,8 @@
from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE
__all__ = [
- 'ConsoleLogFormatter',
- 'GelfLogFormatter',
+ "ConsoleLogFormatter",
+ "GelfLogFormatter",
]
SIMPLE_TYPES = (int, float) + six.string_types
@@ -37,16 +37,16 @@
# GELF logger specific constants
HOSTNAME = socket.gethostname()
-GELF_SPEC_VERSION = '1.1'
+GELF_SPEC_VERSION = "1.1"
COMMON_ATTRIBUTE_NAMES = [
- 'name',
- 'process',
- 'processName',
- 'module',
- 'filename',
- 'funcName',
- 'lineno'
+ "name",
+ "process",
+ "processName",
+ "module",
+ "filename",
+ "funcName",
+ "lineno",
]
@@ -60,9 +60,9 @@ def serialize_object(obj):
:rtype: ``str``
"""
# Try to serialize the object
- if getattr(obj, 'to_dict', None):
+ if getattr(obj, "to_dict", None):
value = obj.to_dict()
- elif getattr(obj, 'to_serializable_dict', None):
+ elif getattr(obj, "to_serializable_dict", None):
value = obj.to_serializable_dict(mask_secrets=True)
else:
value = repr(obj)
@@ -77,7 +77,9 @@ def process_attribute_value(key, value):
if not cfg.CONF.log.mask_secrets:
return value
- blacklisted_attribute_names = MASKED_ATTRIBUTES_BLACKLIST + cfg.CONF.log.mask_secrets_blacklist
+ blacklisted_attribute_names = (
+ MASKED_ATTRIBUTES_BLACKLIST + cfg.CONF.log.mask_secrets_blacklist
+ )
# NOTE: This can be expensive when processing large dicts or objects
if isinstance(value, SIMPLE_TYPES):
@@ -121,11 +123,16 @@ class BaseExtraLogFormatter(logging.Formatter):
    dictionary need to be prefixed with an underscore ('_').
"""
- PREFIX = '_' # Prefix for user provided attributes in the extra dict
+ PREFIX = "_" # Prefix for user provided attributes in the extra dict
def _get_extra_attributes(self, record):
- attributes = dict([(k, v) for k, v in six.iteritems(record.__dict__)
- if k.startswith(self.PREFIX)])
+ attributes = dict(
+ [
+ (k, v)
+ for k, v in six.iteritems(record.__dict__)
+ if k.startswith(self.PREFIX)
+ ]
+ )
return attributes
def _get_common_extra_attributes(self, record):
@@ -182,17 +189,17 @@ def format(self, record):
msg = super(ConsoleLogFormatter, self).format(record)
if attributes:
- msg = '%s (%s)' % (msg, attributes)
+ msg = "%s (%s)" % (msg, attributes)
return msg
def _dict_to_str(self, attributes):
result = []
for key, value in six.iteritems(attributes):
- item = '%s=%s' % (key[1:], repr(value))
+ item = "%s=%s" % (key[1:], repr(value))
result.append(item)
- result = ','.join(result)
+ result = ",".join(result)
return result
@@ -245,30 +252,32 @@ def format(self, record):
exc_info = record.exc_info
time_now_float = record.created
time_now_sec = int(time_now_float)
- level = self.PYTHON_TO_GELF_LEVEL_MAP.get(record.levelno, self.DEFAULT_LOG_LEVEL)
+ level = self.PYTHON_TO_GELF_LEVEL_MAP.get(
+ record.levelno, self.DEFAULT_LOG_LEVEL
+ )
common_attributes = self._get_common_extra_attributes(record=record)
full_msg = super(GelfLogFormatter, self).format(record)
data = {
- 'version': GELF_SPEC_VERSION,
- 'host': HOSTNAME,
- 'short_message': msg,
- 'full_message': full_msg,
- 'timestamp': time_now_sec,
- 'timestamp_f': time_now_float,
- 'level': level
+ "version": GELF_SPEC_VERSION,
+ "host": HOSTNAME,
+ "short_message": msg,
+ "full_message": full_msg,
+ "timestamp": time_now_sec,
+ "timestamp_f": time_now_float,
+ "level": level,
}
if exc_info:
# Include exception information
exc_type, exc_value, exc_tb = exc_info
- tb_str = ''.join(traceback.format_tb(exc_tb))
- data['_exception'] = six.text_type(exc_value)
- data['_traceback'] = tb_str
+ tb_str = "".join(traceback.format_tb(exc_tb))
+ data["_exception"] = six.text_type(exc_value)
+ data["_traceback"] = tb_str
# Include common Python log record attributes
- data['_python'] = common_attributes
+ data["_python"] = common_attributes
# Include user extra attributes
data.update(attributes)
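
The shape of the GELF payload assembled by GelfLogFormatter.format above (all values
illustrative):

    data = {
        "version": "1.1",            # GELF_SPEC_VERSION
        "host": "st2-node-1",        # socket.gethostname()
        "short_message": "Action executed",
        "full_message": "Action executed",
        "timestamp": 1614686400,     # int(record.created)
        "timestamp_f": 1614686400.123,
        "level": 6,                  # mapped via PYTHON_TO_GELF_LEVEL_MAP
    }
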
diff --git a/st2common/st2common/logging/handlers.py b/st2common/st2common/logging/handlers.py
index ade4dfbb04..963ac197a0 100644
--- a/st2common/st2common/logging/handlers.py
+++ b/st2common/st2common/logging/handlers.py
@@ -24,26 +24,29 @@
from st2common.util import date as date_utils
__all__ = [
- 'FormatNamedFileHandler',
- 'ConfigurableSyslogHandler',
+ "FormatNamedFileHandler",
+ "ConfigurableSyslogHandler",
]
class FormatNamedFileHandler(logging.handlers.RotatingFileHandler):
- def __init__(self, filename, mode='a', maxBytes=0, backupCount=0, encoding=None, delay=False):
+ def __init__(
+ self, filename, mode="a", maxBytes=0, backupCount=0, encoding=None, delay=False
+ ):
        # We add additional values to the context which can be used in the log filename
timestamp = int(time.time())
- isotime_str = str(date_utils.get_datetime_utc_now()).replace(' ', '_')
+ isotime_str = str(date_utils.get_datetime_utc_now()).replace(" ", "_")
pid = os.getpid()
- format_values = {
- 'timestamp': timestamp,
- 'ts': isotime_str,
- 'pid': pid
- }
+ format_values = {"timestamp": timestamp, "ts": isotime_str, "pid": pid}
filename = filename.format(**format_values)
- super(FormatNamedFileHandler, self).__init__(filename, mode=mode, maxBytes=maxBytes,
- backupCount=backupCount, encoding=encoding,
- delay=delay)
+ super(FormatNamedFileHandler, self).__init__(
+ filename,
+ mode=mode,
+ maxBytes=maxBytes,
+ backupCount=backupCount,
+ encoding=encoding,
+ delay=delay,
+ )
class ConfigurableSyslogHandler(logging.handlers.SysLogHandler):
@@ -55,12 +58,12 @@ def __init__(self, address=None, facility=None, socktype=None):
if not socktype:
protocol = cfg.CONF.syslog.protocol.lower()
- if protocol == 'udp':
+ if protocol == "udp":
socktype = socket.SOCK_DGRAM
- elif protocol == 'tcp':
+ elif protocol == "tcp":
socktype = socket.SOCK_STREAM
else:
- raise ValueError('Unsupported protocol: %s' % (protocol))
+ raise ValueError("Unsupported protocol: %s" % (protocol))
if socktype:
super(ConfigurableSyslogHandler, self).__init__(address, facility, socktype)
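
A hedged usage sketch for the filename placeholders above (the path is hypothetical;
delay=True avoids opening the file immediately):

    from st2common.logging.handlers import FormatNamedFileHandler

    handler = FormatNamedFileHandler("/var/log/st2/st2api.{pid}.log", delay=True)
    # {pid}, {timestamp} and {ts} are substituted from format_values,
    # e.g. /var/log/st2/st2api.12345.log
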
diff --git a/st2common/st2common/logging/misc.py b/st2common/st2common/logging/misc.py
index 36f8b17986..de7f673431 100644
--- a/st2common/st2common/logging/misc.py
+++ b/st2common/st2common/logging/misc.py
@@ -23,32 +23,26 @@
from st2common.logging.filters import LoggerFunctionNameExclusionFilter
__all__ = [
- 'reopen_log_files',
-
- 'set_log_level_for_all_handlers',
- 'set_log_level_for_all_loggers',
-
- 'add_global_filters_for_all_loggers'
+ "reopen_log_files",
+ "set_log_level_for_all_handlers",
+ "set_log_level_for_all_loggers",
+ "add_global_filters_for_all_loggers",
]
LOG = logging.getLogger(__name__)
# Because some loggers are just a waste of attention span
-SPECIAL_LOGGERS = {
- 'swagger_spec_validator.ref_validators': logging.INFO
-}
+SPECIAL_LOGGERS = {"swagger_spec_validator.ref_validators": logging.INFO}
# Log messages for function names which are very spammy and which we want to filter out
# when DEBUG log level is enabled
IGNORED_FUNCTION_NAMES = [
    # Used by pyamqp, logs every heartbeat tick every 2 ms by default
- 'heartbeat_tick'
+ "heartbeat_tick"
]
# List of global filters which apply to all the loggers
-GLOBAL_FILTERS = [
- LoggerFunctionNameExclusionFilter(exclusions=IGNORED_FUNCTION_NAMES)
-]
+GLOBAL_FILTERS = [LoggerFunctionNameExclusionFilter(exclusions=IGNORED_FUNCTION_NAMES)]
def reopen_log_files(handlers):
@@ -65,8 +59,10 @@ def reopen_log_files(handlers):
if not isinstance(handler, logging.FileHandler):
continue
- LOG.info('Re-opening log file "%s" with mode "%s"\n' %
- (handler.baseFilename, handler.mode))
+ LOG.info(
+ 'Re-opening log file "%s" with mode "%s"\n'
+ % (handler.baseFilename, handler.mode)
+ )
try:
handler.acquire()
@@ -76,10 +72,10 @@ def reopen_log_files(handlers):
try:
handler.release()
except RuntimeError as e:
- if 'cannot release' in six.text_type(e):
+ if "cannot release" in six.text_type(e):
# Release failed which most likely indicates that acquire failed
# and lock was never acquired
- LOG.warn('Failed to release lock', exc_info=True)
+ LOG.warn("Failed to release lock", exc_info=True)
else:
raise e
@@ -112,7 +108,9 @@ def set_log_level_for_all_loggers(level=logging.DEBUG):
logger = add_filters_for_logger(logger=logger, filters=GLOBAL_FILTERS)
if logger.name in SPECIAL_LOGGERS:
- set_log_level_for_all_handlers(logger=logger, level=SPECIAL_LOGGERS.get(logger.name))
+ set_log_level_for_all_handlers(
+ logger=logger, level=SPECIAL_LOGGERS.get(logger.name)
+ )
else:
set_log_level_for_all_handlers(logger=logger, level=level)
@@ -152,7 +150,7 @@ def add_filters_for_logger(logger, filters):
if not isinstance(logger, logging.Logger):
return logger
- if not hasattr(logger, 'addFilter'):
+ if not hasattr(logger, "addFilter"):
return logger
for logger_filter in filters:
@@ -170,7 +168,7 @@ def get_logger_name_for_module(module, exclude_module_name=False):
module_file = module.__file__
base_dir = os.path.dirname(os.path.abspath(module_file))
module_name = os.path.basename(module_file)
- module_name = module_name.replace('.pyc', '').replace('.py', '')
+ module_name = module_name.replace(".pyc", "").replace(".py", "")
split = base_dir.split(os.path.sep)
split = [component for component in split if component]
@@ -178,15 +176,15 @@ def get_logger_name_for_module(module, exclude_module_name=False):
# Find first component which starts with st2 and use that as a starting point
start_index = 0
for index, component in enumerate(reversed(split)):
- if component.startswith('st2'):
- start_index = ((len(split) - 1) - index)
+ if component.startswith("st2"):
+ start_index = (len(split) - 1) - index
break
split = split[start_index:]
if exclude_module_name:
- name = '.'.join(split)
+ name = ".".join(split)
else:
- name = '.'.join(split) + '.' + module_name
+ name = ".".join(split) + "." + module_name
return name
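
Behaviour of get_logger_name_for_module for this very module, assuming an installed layout
like .../site-packages/st2common/logging/misc.py:

    from st2common.logging import misc

    print(misc.get_logger_name_for_module(misc))
    # -> "st2common.logging.misc": components from the nearest directory
    #    starting with "st2", plus the module name
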
diff --git a/st2common/st2common/metrics/base.py b/st2common/st2common/metrics/base.py
index 18801c901d..215780b86f 100644
--- a/st2common/st2common/metrics/base.py
+++ b/st2common/st2common/metrics/base.py
@@ -28,23 +28,22 @@
from st2common.exceptions.plugins import PluginLoadError
__all__ = [
- 'BaseMetricsDriver',
-
- 'Timer',
- 'Counter',
- 'CounterWithTimer',
-
- 'metrics_initialize',
- 'get_driver'
+ "BaseMetricsDriver",
+ "Timer",
+ "Counter",
+ "CounterWithTimer",
+ "metrics_initialize",
+ "get_driver",
]
-if not hasattr(cfg.CONF, 'metrics'):
+if not hasattr(cfg.CONF, "metrics"):
from st2common.config import register_opts
+
register_opts()
LOG = logging.getLogger(__name__)
-PLUGIN_NAMESPACE = 'st2common.metrics.driver'
+PLUGIN_NAMESPACE = "st2common.metrics.driver"
# Stores reference to the metrics driver class instance.
# NOTE: This value is populated lazily on the first get_driver() function call
@@ -97,6 +96,7 @@ class Timer(object):
"""
Timer context manager for easily sending timer statistics.
"""
+
def __init__(self, key, include_parameter=False):
check_key(key)
@@ -136,8 +136,9 @@ def __call__(self, func):
def wrapper(*args, **kw):
with self as metrics_timer:
if self._include_parameter:
- kw['metrics_timer'] = metrics_timer
+ kw["metrics_timer"] = metrics_timer
return func(*args, **kw)
+
return wrapper
@@ -145,6 +146,7 @@ class Counter(object):
"""
Counter context manager for easily sending counter statistics.
"""
+
def __init__(self, key):
check_key(key)
self.key = key
@@ -162,6 +164,7 @@ def __call__(self, func):
def wrapper(*args, **kw):
with self:
return func(*args, **kw)
+
return wrapper
@@ -209,8 +212,9 @@ def __call__(self, func):
def wrapper(*args, **kw):
with self as counter_with_timer:
if self._include_parameter:
- kw['metrics_counter_with_timer'] = counter_with_timer
+ kw["metrics_counter_with_timer"] = counter_with_timer
return func(*args, **kw)
+
return wrapper
@@ -223,7 +227,9 @@ def metrics_initialize():
try:
METRICS = get_plugin_instance(PLUGIN_NAMESPACE, cfg.CONF.metrics.driver)
except (NoMatches, MultipleMatches, NoSuchOptError) as error:
- raise PluginLoadError('Error loading metrics driver. Check configuration: %s' % error)
+ raise PluginLoadError(
+ "Error loading metrics driver. Check configuration: %s" % error
+ )
return METRICS
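Timer, Counter and CounterWithTimer above double as context managers and as decorators. A generic, self-contained sketch of that dual-use pattern (the Timed class is illustrative and prints instead of calling a metrics driver):

import functools
import time

class Timed:
    """Context manager *and* decorator that reports elapsed time."""

    def __init__(self, key):
        self.key = key

    def __enter__(self):
        self._start = time.monotonic()
        return self

    def __exit__(self, exc_type, exc, tb):
        duration = time.monotonic() - self._start
        print("[metrics] time(key=%s, time=%.6f)" % (self.key, duration))

    def __call__(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper

@Timed(key="demo.sleepy")
def sleepy():
    time.sleep(0.01)

sleepy()                       # decorator form
with Timed(key="demo.block"):  # context-manager form
    time.sleep(0.01)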
diff --git a/st2common/st2common/metrics/drivers/echo_driver.py b/st2common/st2common/metrics/drivers/echo_driver.py
index 40b2ed3947..7cb115aab6 100644
--- a/st2common/st2common/metrics/drivers/echo_driver.py
+++ b/st2common/st2common/metrics/drivers/echo_driver.py
@@ -16,9 +16,7 @@
from st2common import log as logging
from st2common.metrics.base import BaseMetricsDriver
-__all__ = [
- 'EchoDriver'
-]
+__all__ = ["EchoDriver"]
LOG = logging.getLogger(__name__)
@@ -29,19 +27,19 @@ class EchoDriver(BaseMetricsDriver):
"""
def time(self, key, time):
- LOG.debug('[metrics] time(key=%s, time=%s)' % (key, time))
+ LOG.debug("[metrics] time(key=%s, time=%s)" % (key, time))
def inc_counter(self, key, amount=1):
- LOG.debug('[metrics] counter.incr(%s, %s)' % (key, amount))
+ LOG.debug("[metrics] counter.incr(%s, %s)" % (key, amount))
def dec_counter(self, key, amount=1):
- LOG.debug('[metrics] counter.decr(%s, %s)' % (key, amount))
+ LOG.debug("[metrics] counter.decr(%s, %s)" % (key, amount))
def set_gauge(self, key, value):
- LOG.debug('[metrics] set_gauge(%s, %s)' % (key, value))
+ LOG.debug("[metrics] set_gauge(%s, %s)" % (key, value))
def inc_gauge(self, key, amount=1):
- LOG.debug('[metrics] gauge.incr(%s, %s)' % (key, amount))
+ LOG.debug("[metrics] gauge.incr(%s, %s)" % (key, amount))
def dec_gauge(self, key, amount=1):
- LOG.debug('[metrics] gauge.decr(%s, %s)' % (key, amount))
+ LOG.debug("[metrics] gauge.decr(%s, %s)" % (key, amount))
diff --git a/st2common/st2common/metrics/drivers/noop_driver.py b/st2common/st2common/metrics/drivers/noop_driver.py
index 6f816f2a69..658ee10a40 100644
--- a/st2common/st2common/metrics/drivers/noop_driver.py
+++ b/st2common/st2common/metrics/drivers/noop_driver.py
@@ -15,9 +15,7 @@
from st2common.metrics.base import BaseMetricsDriver
-__all__ = [
- 'NoopDriver'
-]
+__all__ = ["NoopDriver"]
class NoopDriver(BaseMetricsDriver):
diff --git a/st2common/st2common/metrics/drivers/statsd_driver.py b/st2common/st2common/metrics/drivers/statsd_driver.py
index c334837e9b..efbefde601 100644
--- a/st2common/st2common/metrics/drivers/statsd_driver.py
+++ b/st2common/st2common/metrics/drivers/statsd_driver.py
@@ -30,15 +30,9 @@
LOG = logging.getLogger(__name__)
# Which exceptions thrown by statsd library should be considered as non-fatal
-NON_FATAL_EXC_CLASSES = [
- socket.error,
- IOError,
- OSError
-]
+NON_FATAL_EXC_CLASSES = [socket.error, IOError, OSError]
-__all__ = [
- 'StatsdDriver'
-]
+__all__ = ["StatsdDriver"]
class StatsdDriver(BaseMetricsDriver):
@@ -55,11 +49,15 @@ class StatsdDriver(BaseMetricsDriver):
"""
def __init__(self):
- statsd.Connection.set_defaults(host=cfg.CONF.metrics.host, port=cfg.CONF.metrics.port,
- sample_rate=cfg.CONF.metrics.sample_rate)
-
- @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG,
- level=stdlib_logging.WARNING)
+ statsd.Connection.set_defaults(
+ host=cfg.CONF.metrics.host,
+ port=cfg.CONF.metrics.port,
+ sample_rate=cfg.CONF.metrics.sample_rate,
+ )
+
+ @ignore_and_log_exception(
+ exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING
+ )
def time(self, key, time):
"""
Timer metric
@@ -68,11 +66,12 @@ def time(self, key, time):
assert isinstance(time, Number)
key = get_full_key_name(key)
- timer = statsd.Timer('')
+ timer = statsd.Timer("")
timer.send(key, time)
- @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG,
- level=stdlib_logging.WARNING)
+ @ignore_and_log_exception(
+ exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING
+ )
def inc_counter(self, key, amount=1):
"""
Increment counter
@@ -84,8 +83,9 @@ def inc_counter(self, key, amount=1):
counter = statsd.Counter(key)
counter.increment(delta=amount)
- @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG,
- level=stdlib_logging.WARNING)
+ @ignore_and_log_exception(
+ exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING
+ )
def dec_counter(self, key, amount=1):
"""
Decrement metric
@@ -97,8 +97,9 @@ def dec_counter(self, key, amount=1):
counter = statsd.Counter(key)
counter.decrement(delta=amount)
- @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG,
- level=stdlib_logging.WARNING)
+ @ignore_and_log_exception(
+ exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING
+ )
def set_gauge(self, key, value):
"""
Set gauge value.
@@ -110,8 +111,9 @@ def set_gauge(self, key, value):
gauge = statsd.Gauge(key)
gauge.send(None, value)
- @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG,
- level=stdlib_logging.WARNING)
+ @ignore_and_log_exception(
+ exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING
+ )
def inc_gauge(self, key, amount=1):
"""
Increment gauge value.
@@ -123,8 +125,9 @@ def inc_gauge(self, key, amount=1):
gauge = statsd.Gauge(key)
gauge.increment(None, amount)
- @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG,
- level=stdlib_logging.WARNING)
+ @ignore_and_log_exception(
+ exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING
+ )
def dec_gauge(self, key, amount=1):
"""
Decrement gauge value.
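Every statsd method above is wrapped in ignore_and_log_exception so that transient socket errors are logged rather than propagated to the caller. An approximate, stdlib-only sketch of a decorator with that behavior (not the st2common implementation):

import functools
import logging

def ignore_and_log_exception(exc_classes, logger, level=logging.WARNING):
    """Swallow the listed exception classes, logging them instead."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except tuple(exc_classes) as e:
                # The call silently returns None when the exception is swallowed.
                logger.log(level, "Ignoring %s in %s: %s",
                           type(e).__name__, func.__name__, e)
        return wrapper
    return decorator

LOG = logging.getLogger("demo")
logging.basicConfig(level=logging.DEBUG)

@ignore_and_log_exception(exc_classes=[OSError], logger=LOG)
def flaky_send():
    raise OSError("connection refused")

flaky_send()  # logs a warning instead of raising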
diff --git a/st2common/st2common/metrics/utils.py b/st2common/st2common/metrics/utils.py
index f741743cd2..710aceff15 100644
--- a/st2common/st2common/metrics/utils.py
+++ b/st2common/st2common/metrics/utils.py
@@ -16,10 +16,7 @@
import six
from oslo_config import cfg
-__all__ = [
- 'get_full_key_name',
- 'check_key'
-]
+__all__ = ["get_full_key_name", "check_key"]
def get_full_key_name(key):
@@ -27,14 +24,14 @@ def get_full_key_name(key):
Return full metric key name, taking into account optional prefix which can be specified in the
config.
"""
- parts = ['st2']
+ parts = ["st2"]
if cfg.CONF.metrics.prefix:
parts.append(cfg.CONF.metrics.prefix)
parts.append(key)
- return '.'.join(parts)
+ return ".".join(parts)
def check_key(key):
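get_full_key_name joins the fixed "st2" root, an optional configured prefix, and the metric key with dots. The same logic without the oslo.config dependency; the prefix parameter stands in for cfg.CONF.metrics.prefix:

def full_key_name(key, prefix=None):
    # Produces "st2[.<prefix>].<key>".
    parts = ["st2"]
    if prefix:
        parts.append(prefix)
    parts.append(key)
    return ".".join(parts)

assert full_key_name("action.executions") == "st2.action.executions"
assert full_key_name("action.executions", prefix="prod") == "st2.prod.action.executions"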
diff --git a/st2common/st2common/middleware/cors.py b/st2common/st2common/middleware/cors.py
index 1388e65e63..eaeac86f07 100644
--- a/st2common/st2common/middleware/cors.py
+++ b/st2common/st2common/middleware/cors.py
@@ -42,18 +42,18 @@ def __call__(self, environ, start_response):
def custom_start_response(status, headers, exc_info=None):
headers = ResponseHeaders(headers)
- origin = request.headers.get('Origin')
+ origin = request.headers.get("Origin")
origins = OrderedSet(cfg.CONF.api.allow_origin)
# Build a list of the default allowed origins
public_api_url = cfg.CONF.auth.api_url
# Default gulp development server WebUI URL
- origins.add('http://127.0.0.1:3000')
+ origins.add("http://127.0.0.1:3000")
# By default WebUI simple http server listens on 8080
- origins.add('http://localhost:8080')
- origins.add('http://127.0.0.1:8080')
+ origins.add("http://localhost:8080")
+ origins.add("http://127.0.0.1:8080")
if public_api_url:
# Public API URL
@@ -62,7 +62,7 @@ def custom_start_response(status, headers, exc_info=None):
origins = list(origins)
if origin:
- if '*' in origins:
+ if "*" in origins:
origin_allowed = origin
else:
# See http://www.w3.org/TR/cors/#access-control-allow-origin-response-header
@@ -70,21 +70,32 @@ def custom_start_response(status, headers, exc_info=None):
else:
origin_allowed = list(origins)[0]
- methods_allowed = ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS']
- request_headers_allowed = ['Content-Type', 'Authorization', HEADER_ATTRIBUTE_NAME,
- HEADER_API_KEY_ATTRIBUTE_NAME, REQUEST_ID_HEADER]
- response_headers_allowed = ['Content-Type', 'X-Limit', 'X-Total-Count',
- REQUEST_ID_HEADER]
-
- headers['Access-Control-Allow-Origin'] = origin_allowed
- headers['Access-Control-Allow-Methods'] = ','.join(methods_allowed)
- headers['Access-Control-Allow-Headers'] = ','.join(request_headers_allowed)
- headers['Access-Control-Allow-Credentials'] = 'true'
- headers['Access-Control-Expose-Headers'] = ','.join(response_headers_allowed)
+ methods_allowed = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
+ request_headers_allowed = [
+ "Content-Type",
+ "Authorization",
+ HEADER_ATTRIBUTE_NAME,
+ HEADER_API_KEY_ATTRIBUTE_NAME,
+ REQUEST_ID_HEADER,
+ ]
+ response_headers_allowed = [
+ "Content-Type",
+ "X-Limit",
+ "X-Total-Count",
+ REQUEST_ID_HEADER,
+ ]
+
+ headers["Access-Control-Allow-Origin"] = origin_allowed
+ headers["Access-Control-Allow-Methods"] = ",".join(methods_allowed)
+ headers["Access-Control-Allow-Headers"] = ",".join(request_headers_allowed)
+ headers["Access-Control-Allow-Credentials"] = "true"
+ headers["Access-Control-Expose-Headers"] = ",".join(
+ response_headers_allowed
+ )
return start_response(status, headers._items, exc_info)
- if request.method == 'OPTIONS':
+ if request.method == "OPTIONS":
return Response()(environ, custom_start_response)
else:
return self.app(environ, custom_start_response)
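The CORS middleware works by substituting its own start_response so it can append Access-Control-* headers before delegating to the wrapped WSGI app. A stripped-down sketch of that pattern (the header values and class name are illustrative):

from wsgiref.util import setup_testing_defaults

class HeaderInjectionMiddleware:
    """WSGI middleware that appends headers via a wrapped start_response."""

    def __init__(self, app, extra_headers):
        self.app = app
        self.extra_headers = extra_headers

    def __call__(self, environ, start_response):
        def custom_start_response(status, headers, exc_info=None):
            headers = list(headers) + list(self.extra_headers)
            return start_response(status, headers, exc_info)
        return self.app(environ, custom_start_response)

def hello_app(environ, start_response):
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"hello"]

app = HeaderInjectionMiddleware(
    hello_app, [("Access-Control-Allow-Origin", "http://127.0.0.1:3000")]
)

# Smoke test without a real server:
environ = {}
setup_testing_defaults(environ)
collected = []
body = app(environ, lambda status, headers, exc_info=None: collected.append((status, headers)))
print(collected[0])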
diff --git a/st2common/st2common/middleware/error_handling.py b/st2common/st2common/middleware/error_handling.py
index 478cf3691b..d7ae59cde5 100644
--- a/st2common/st2common/middleware/error_handling.py
+++ b/st2common/st2common/middleware/error_handling.py
@@ -50,13 +50,13 @@ def __call__(self, environ, start_response):
except NotFoundException:
raise exc.HTTPNotFound()
except Exception as e:
- status = getattr(e, 'code', exc.HTTPInternalServerError.code)
+ status = getattr(e, "code", exc.HTTPInternalServerError.code)
- if hasattr(e, 'detail') and not getattr(e, 'comment'):
- setattr(e, 'comment', getattr(e, 'detail'))
+ if hasattr(e, "detail") and not getattr(e, "comment"):
+ setattr(e, "comment", getattr(e, "detail"))
- if hasattr(e, 'body') and isinstance(getattr(e, 'body', None), dict):
- body = getattr(e, 'body', None)
+ if hasattr(e, "body") and isinstance(getattr(e, "body", None), dict):
+ body = getattr(e, "body", None)
else:
body = {}
@@ -69,40 +69,40 @@ def __call__(self, environ, start_response):
elif isinstance(e, db_exceptions.StackStormDBObjectConflictError):
status_code = exc.HTTPConflict.code
message = six.text_type(e)
- body['conflict-id'] = getattr(e, 'conflict_id', None)
+ body["conflict-id"] = getattr(e, "conflict_id", None)
elif isinstance(e, rbac_exceptions.AccessDeniedError):
status_code = exc.HTTPForbidden.code
message = six.text_type(e)
elif isinstance(e, (ValueValidationException, ValueError, ValidationError)):
status_code = exc.HTTPBadRequest.code
- message = getattr(e, 'message', six.text_type(e))
+ message = getattr(e, "message", six.text_type(e))
else:
status_code = exc.HTTPInternalServerError.code
- message = 'Internal Server Error'
+ message = "Internal Server Error"
# Log the error
is_internal_server_error = status_code == exc.HTTPInternalServerError.code
- error_msg = getattr(e, 'comment', six.text_type(e))
+ error_msg = getattr(e, "comment", six.text_type(e))
extra = {
- 'exception_class': e.__class__.__name__,
- 'exception_message': six.text_type(e),
- 'exception_data': e.__dict__
+ "exception_class": e.__class__.__name__,
+ "exception_message": six.text_type(e),
+ "exception_data": e.__dict__,
}
if is_internal_server_error:
- LOG.exception('API call failed: %s', error_msg, extra=extra)
+ LOG.exception("API call failed: %s", error_msg, extra=extra)
else:
- LOG.debug('API call failed: %s', error_msg, extra=extra)
+ LOG.debug("API call failed: %s", error_msg, extra=extra)
if is_debugging_enabled():
LOG.debug(traceback.format_exc())
- body['faultstring'] = message
+ body["faultstring"] = message
response_body = json_encode(body)
headers = {
- 'Content-Type': 'application/json',
- 'Content-Length': str(len(response_body))
+ "Content-Type": "application/json",
+ "Content-Length": str(len(response_body)),
}
resp = Response(response_body, status=status_code, headers=headers)
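The error-handling middleware boils down to: derive a status code from the exception type, build a JSON body carrying a faultstring, and set Content-Type and Content-Length. A compact sketch of that mapping; the exception classes here are placeholders for the st2common ones:

import json

class AccessDeniedError(Exception): ...
class ValidationError(Exception): ...

def error_response(exc):
    if isinstance(exc, AccessDeniedError):
        status = 403
    elif isinstance(exc, (ValueError, ValidationError)):
        status = 400
    else:
        status = 500
    # Internal errors get a generic message so details are not leaked.
    message = str(exc) if status != 500 else "Internal Server Error"
    body = json.dumps({"faultstring": message})
    headers = {
        "Content-Type": "application/json",
        "Content-Length": str(len(body)),
    }
    return status, headers, body

print(error_response(ValueError("bad parameter")))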
diff --git a/st2common/st2common/middleware/instrumentation.py b/st2common/st2common/middleware/instrumentation.py
index 8ff7445f75..e5d01d2223 100644
--- a/st2common/st2common/middleware/instrumentation.py
+++ b/st2common/st2common/middleware/instrumentation.py
@@ -21,10 +21,7 @@
from st2common.util.date import get_datetime_utc_now
from st2common.router import NotFoundException
-__all__ = [
- 'RequestInstrumentationMiddleware',
- 'ResponseInstrumentationMiddleware'
-]
+__all__ = ["RequestInstrumentationMiddleware", "ResponseInstrumentationMiddleware"]
LOG = logging.getLogger(__name__)
@@ -54,10 +51,11 @@ def __call__(self, environ, start_response):
# NOTE: We don't track per-request and per-response metrics for /v1/executions/ and some
# other endpoints because this would result in a lot of unique metrics, which is an
# anti-pattern and causes unnecessary load on the metrics server.
- submit_metrics = endpoint.get('x-submit-metrics', True)
- operation_id = endpoint.get('operationId', None)
- is_get_one_endpoint = bool(operation_id) and (operation_id.endswith('.get') or
- operation_id.endswith('.get_one'))
+ submit_metrics = endpoint.get("x-submit-metrics", True)
+ operation_id = endpoint.get("operationId", None)
+ is_get_one_endpoint = bool(operation_id) and (
+ operation_id.endswith(".get") or operation_id.endswith(".get_one")
+ )
if is_get_one_endpoint:
# NOTE: We don't submit metrics for any get one API endpoint since this would result
@@ -65,22 +63,22 @@ def __call__(self, environ, start_response):
submit_metrics = False
if not submit_metrics:
- LOG.debug('Not submitting request metrics for path: %s' % (request.path))
+ LOG.debug("Not submitting request metrics for path: %s" % (request.path))
return self.app(environ, start_response)
metrics_driver = get_driver()
- key = '%s.request.total' % (self._service_name)
+ key = "%s.request.total" % (self._service_name)
metrics_driver.inc_counter(key)
- key = '%s.request.method.%s' % (self._service_name, request.method)
+ key = "%s.request.method.%s" % (self._service_name, request.method)
metrics_driver.inc_counter(key)
- path = request.path.replace('/', '_')
- key = '%s.request.path.%s' % (self._service_name, path)
+ path = request.path.replace("/", "_")
+ key = "%s.request.path.%s" % (self._service_name, path)
metrics_driver.inc_counter(key)
- if self._service_name == 'stream':
+ if self._service_name == "stream":
# For stream service, we also record current number of open connections.
# Due to the way stream service works, we need to utilize eventlet posthook to
# correctly set the counter when the connection is closed / full response is returned.
@@ -88,34 +86,34 @@ def __call__(self, environ, start_response):
# hooks for details
# Increase request counter
- key = '%s.request' % (self._service_name)
+ key = "%s.request" % (self._service_name)
metrics_driver.inc_counter(key)
# Increase "total number of connections" gauge
- metrics_driver.inc_gauge('stream.connections', 1)
+ metrics_driver.inc_gauge("stream.connections", 1)
start_time = get_datetime_utc_now()
def update_metrics_hook(env):
# Hook which is called at the very end after all the response has been sent and
# connection closed
- time_delta = (get_datetime_utc_now() - start_time)
+ time_delta = get_datetime_utc_now() - start_time
duration = time_delta.total_seconds()
# Send total request time
metrics_driver.time(key, duration)
# Decrease "current number of connections" gauge
- metrics_driver.dec_gauge('stream.connections', 1)
+ metrics_driver.dec_gauge("stream.connections", 1)
# NOTE: Some tests mock environ, and there the 'eventlet.posthooks' key is not available
- if 'eventlet.posthooks' in environ:
- environ['eventlet.posthooks'].append((update_metrics_hook, (), {}))
+ if "eventlet.posthooks" in environ:
+ environ["eventlet.posthooks"].append((update_metrics_hook, (), {}))
return self.app(environ, start_response)
else:
# Track and time current number of processing requests
- key = '%s.request' % (self._service_name)
+ key = "%s.request" % (self._service_name)
with CounterWithTimer(key=key):
return self.app(environ, start_response)
@@ -138,11 +136,12 @@ def __init__(self, app, router, service_name):
def __call__(self, environ, start_response):
# Track and time current number of processing requests
def custom_start_response(status, headers, exc_info=None):
- status_code = int(status.split(' ')[0])
+ status_code = int(status.split(" ")[0])
metrics_driver = get_driver()
- metrics_driver.inc_counter('%s.response.status.%s' % (self._service_name,
- status_code))
+ metrics_driver.inc_counter(
+ "%s.response.status.%s" % (self._service_name, status_code)
+ )
return start_response(status, headers, exc_info)
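The request-instrumentation hunks derive several counter keys from the service name, the HTTP method, and the path with slashes flattened to underscores. The key construction in isolation:

def request_metric_keys(service_name, method, path):
    # Mirrors the key shapes used above, e.g. "api.request.method.GET".
    flat_path = path.replace("/", "_")
    return [
        "%s.request.total" % service_name,
        "%s.request.method.%s" % (service_name, method),
        "%s.request.path.%s" % (service_name, flat_path),
    ]

print(request_metric_keys("api", "GET", "/v1/actions"))
# ['api.request.total', 'api.request.method.GET', 'api.request.path._v1_actions']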
diff --git a/st2common/st2common/middleware/logging.py b/st2common/st2common/middleware/logging.py
index d41622ff29..a044e2c59b 100644
--- a/st2common/st2common/middleware/logging.py
+++ b/st2common/st2common/middleware/logging.py
@@ -33,7 +33,7 @@
SECRET_QUERY_PARAMS = [
QUERY_PARAM_ATTRIBUTE_NAME,
- QUERY_PARAM_API_KEY_ATTRIBUTE_NAME
+ QUERY_PARAM_API_KEY_ATTRIBUTE_NAME,
] + MASKED_ATTRIBUTES_BLACKLIST
try:
@@ -68,21 +68,23 @@ def __call__(self, environ, start_response):
# Log the incoming request
values = {
- 'method': request.method,
- 'path': request.path,
- 'remote_addr': request.remote_addr,
- 'query': query_params,
- 'request_id': request.headers.get(REQUEST_ID_HEADER, None)
+ "method": request.method,
+ "path": request.path,
+ "remote_addr": request.remote_addr,
+ "query": query_params,
+ "request_id": request.headers.get(REQUEST_ID_HEADER, None),
}
- LOG.info('%(request_id)s - %(method)s %(path)s with query=%(query)s' %
- values, extra=values)
+ LOG.info(
+ "%(request_id)s - %(method)s %(path)s with query=%(query)s" % values,
+ extra=values,
+ )
def custom_start_response(status, headers, exc_info=None):
- status_code.append(int(status.split(' ')[0]))
+ status_code.append(int(status.split(" ")[0]))
for name, value in headers:
- if name.lower() == 'content-length':
+ if name.lower() == "content-length":
content_length.append(int(value))
break
@@ -95,7 +97,7 @@ def custom_start_response(status, headers, exc_info=None):
except NotFoundException:
endpoint = {}
- log_result = endpoint.get('x-log-result', True)
+ log_result = endpoint.get("x-log-result", True)
if isinstance(retval, (types.GeneratorType, itertools.chain)):
# Note: We don't log the result when the return value is a generator, because this would
@@ -105,22 +107,28 @@ def custom_start_response(status, headers, exc_info=None):
# Log the response
values = {
- 'method': request.method,
- 'path': request.path,
- 'remote_addr': request.remote_addr,
- 'status': status_code[0],
- 'runtime': float("{0:.3f}".format((clock() - start_time) * 10**3)),
- 'content_length': content_length[0] if content_length else len(b''.join(retval)),
- 'request_id': request.headers.get(REQUEST_ID_HEADER, None)
+ "method": request.method,
+ "path": request.path,
+ "remote_addr": request.remote_addr,
+ "status": status_code[0],
+ "runtime": float("{0:.3f}".format((clock() - start_time) * 10 ** 3)),
+ "content_length": content_length[0]
+ if content_length
+ else len(b"".join(retval)),
+ "request_id": request.headers.get(REQUEST_ID_HEADER, None),
}
- log_msg = '%(request_id)s - %(status)s %(content_length)s %(runtime)sms' % (values)
+ log_msg = "%(request_id)s - %(status)s %(content_length)s %(runtime)sms" % (
+ values
+ )
LOG.info(log_msg, extra=values)
if log_result:
- values['result'] = retval[0]
- log_msg = ('%(request_id)s - %(status)s %(content_length)s %(runtime)sms\n%(result)s' %
- (values))
+ values["result"] = retval[0]
+ log_msg = (
+ "%(request_id)s - %(status)s %(content_length)s %(runtime)sms\n%(result)s"
+ % (values)
+ )
LOG.debug(log_msg, extra=values)
return retval
diff --git a/st2common/st2common/middleware/streaming.py b/st2common/st2common/middleware/streaming.py
index 8f48dedbcf..eb09084b30 100644
--- a/st2common/st2common/middleware/streaming.py
+++ b/st2common/st2common/middleware/streaming.py
@@ -16,9 +16,7 @@
from __future__ import absolute_import
import fnmatch
-__all__ = [
- 'StreamingMiddleware'
-]
+__all__ = ["StreamingMiddleware"]
class StreamingMiddleware(object):
@@ -32,7 +30,7 @@ def __call__(self, environ, start_response):
# middleware is not important since it acts as pass-through.
matches = False
- req_path = environ.get('PATH_INFO', None)
+ req_path = environ.get("PATH_INFO", None)
if not self._path_whitelist:
matches = True
@@ -43,6 +41,6 @@ def __call__(self, environ, start_response):
break
if matches:
- environ['eventlet.minimum_write_chunk_size'] = 0
+ environ["eventlet.minimum_write_chunk_size"] = 0
return self.app(environ, start_response)
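StreamingMiddleware decides whether to set eventlet.minimum_write_chunk_size by fnmatch-ing PATH_INFO against a whitelist of glob patterns, where an empty whitelist matches everything. The matching logic on its own (the patterns are illustrative):

import fnmatch

def path_matches(req_path, whitelist):
    # An empty whitelist matches everything, as in the middleware above.
    if not whitelist:
        return True
    return any(fnmatch.fnmatch(req_path, pattern) for pattern in whitelist)

whitelist = ["/v1/stream/*", "/v1/executions/*/output*"]
print(path_matches("/v1/stream/events", whitelist))  # True
print(path_matches("/v1/actions", whitelist))        # False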
diff --git a/st2common/st2common/models/api/action.py b/st2common/st2common/models/api/action.py
index 70eaeddad9..1924f54460 100644
--- a/st2common/st2common/models/api/action.py
+++ b/st2common/st2common/models/api/action.py
@@ -23,7 +23,10 @@
from st2common.models.api.base import BaseAPI
from st2common.models.api.base import APIUIDMixin
from st2common.models.api.tag import TagsHelper
-from st2common.models.api.notification import (NotificationSubSchemaAPI, NotificationsHelper)
+from st2common.models.api.notification import (
+ NotificationSubSchemaAPI,
+ NotificationsHelper,
+)
from st2common.models.db.action import ActionDB
from st2common.models.db.actionalias import ActionAliasDB
from st2common.models.db.executionstate import ActionExecutionStateDB
@@ -34,17 +37,16 @@
__all__ = [
- 'ActionAPI',
- 'ActionCreateAPI',
- 'LiveActionAPI',
- 'LiveActionCreateAPI',
- 'RunnerTypeAPI',
-
- 'AliasExecutionAPI',
- 'AliasMatchAndExecuteInputAPI',
- 'ActionAliasAPI',
- 'ActionAliasMatchAPI',
- 'ActionAliasHelpAPI'
+ "ActionAPI",
+ "ActionCreateAPI",
+ "LiveActionAPI",
+ "LiveActionCreateAPI",
+ "RunnerTypeAPI",
+ "AliasExecutionAPI",
+ "AliasMatchAndExecuteInputAPI",
+ "ActionAliasAPI",
+ "ActionAliasMatchAPI",
+ "ActionAliasHelpAPI",
]
@@ -56,6 +58,7 @@ class RunnerTypeAPI(BaseAPI):
The representation of a RunnerType in the system. A RunnerType
has a one-to-one mapping to a particular ActionRunner implementation.
"""
+
model = RunnerTypeDB
schema = {
"title": "Runner",
@@ -65,42 +68,40 @@ class RunnerTypeAPI(BaseAPI):
"id": {
"description": "The unique identifier for the action runner.",
"type": "string",
- "default": None
- },
- "uid": {
- "type": "string"
+ "default": None,
},
+ "uid": {"type": "string"},
"name": {
"description": "The name of the action runner.",
"type": "string",
- "required": True
+ "required": True,
},
"description": {
"description": "The description of the action runner.",
- "type": "string"
+ "type": "string",
},
"enabled": {
"description": "Enable or disable the action runner.",
"type": "boolean",
- "default": True
+ "default": True,
},
"runner_package": {
"description": "The python package that implements the "
- "action runner for this type.",
+ "action runner for this type.",
"type": "string",
- "required": False
+ "required": False,
},
"runner_module": {
"description": "The python module that implements the "
- "action runner for this type.",
+ "action runner for this type.",
"type": "string",
- "required": True
+ "required": True,
},
"query_module": {
"description": "The python module that implements the "
- "results tracker (querier) for the runner.",
+ "results tracker (querier) for the runner.",
"type": "string",
- "required": False
+ "required": False,
},
"runner_parameters": {
"description": "Input parameters for the action runner.",
@@ -108,24 +109,22 @@ class RunnerTypeAPI(BaseAPI):
"patternProperties": {
r"^\w+$": util_schema.get_action_parameters_schema()
},
- 'additionalProperties': False
+ "additionalProperties": False,
},
"output_key": {
"description": "Default key to expect results to be published to.",
"type": "string",
- "required": False
+ "required": False,
},
"output_schema": {
"description": "Schema for the runner's output.",
"type": "object",
- "patternProperties": {
- r"^\w+$": util_schema.get_action_output_schema()
- },
- 'additionalProperties': False,
- "default": {}
+ "patternProperties": {r"^\w+$": util_schema.get_action_output_schema()},
+ "additionalProperties": False,
+ "default": {},
},
},
- "additionalProperties": False
+ "additionalProperties": False,
}
def __init__(self, **kw):
@@ -138,25 +137,34 @@ def __init__(self, **kw):
# modified one
for key, value in kw.items():
setattr(self, key, value)
- if not hasattr(self, 'runner_parameters'):
- setattr(self, 'runner_parameters', dict())
+ if not hasattr(self, "runner_parameters"):
+ setattr(self, "runner_parameters", dict())
@classmethod
def to_model(cls, runner_type):
name = runner_type.name
description = runner_type.description
- enabled = getattr(runner_type, 'enabled', True)
- runner_package = getattr(runner_type, 'runner_package', runner_type.runner_module)
+ enabled = getattr(runner_type, "enabled", True)
+ runner_package = getattr(
+ runner_type, "runner_package", runner_type.runner_module
+ )
runner_module = str(runner_type.runner_module)
- runner_parameters = getattr(runner_type, 'runner_parameters', dict())
- output_key = getattr(runner_type, 'output_key', None)
- output_schema = getattr(runner_type, 'output_schema', dict())
- query_module = getattr(runner_type, 'query_module', None)
-
- model = cls.model(name=name, description=description, enabled=enabled,
- runner_package=runner_package, runner_module=runner_module,
- runner_parameters=runner_parameters, output_schema=output_schema,
- query_module=query_module, output_key=output_key)
+ runner_parameters = getattr(runner_type, "runner_parameters", dict())
+ output_key = getattr(runner_type, "output_key", None)
+ output_schema = getattr(runner_type, "output_schema", dict())
+ query_module = getattr(runner_type, "query_module", None)
+
+ model = cls.model(
+ name=name,
+ description=description,
+ enabled=enabled,
+ runner_package=runner_package,
+ runner_module=runner_module,
+ runner_parameters=runner_parameters,
+ output_schema=output_schema,
+ query_module=query_module,
+ output_key=output_key,
+ )
return model
@@ -174,44 +182,42 @@ class ActionAPI(BaseAPI, APIUIDMixin):
"properties": {
"id": {
"description": "The unique identifier for the action.",
- "type": "string"
+ "type": "string",
},
"ref": {
"description": "System computed user friendly reference for the action. \
Provided value will be overridden by computed value.",
- "type": "string"
- },
- "uid": {
- "type": "string"
+ "type": "string",
},
+ "uid": {"type": "string"},
"name": {
"description": "The name of the action.",
"type": "string",
- "required": True
+ "required": True,
},
"description": {
"description": "The description of the action.",
- "type": "string"
+ "type": "string",
},
"enabled": {
"description": "Enable or disable the action from invocation.",
"type": "boolean",
- "default": True
+ "default": True,
},
"runner_type": {
"description": "The type of runner that executes the action.",
"type": "string",
- "required": True
+ "required": True,
},
"entry_point": {
"description": "The entry point for the action.",
"type": "string",
- "default": ""
+ "default": "",
},
"pack": {
"description": "The content pack this action belongs to.",
"type": "string",
- "default": DEFAULT_PACK_NAME
+ "default": DEFAULT_PACK_NAME,
},
"parameters": {
"description": "Input parameters for the action.",
@@ -219,22 +225,20 @@ class ActionAPI(BaseAPI, APIUIDMixin):
"patternProperties": {
r"^\w+$": util_schema.get_action_parameters_schema()
},
- 'additionalProperties': False,
- "default": {}
+ "additionalProperties": False,
+ "default": {},
},
"output_schema": {
"description": "Schema for the action's output.",
"type": "object",
- "patternProperties": {
- r"^\w+$": util_schema.get_action_output_schema()
- },
- 'additionalProperties': False,
- "default": {}
+ "patternProperties": {r"^\w+$": util_schema.get_action_output_schema()},
+ "additionalProperties": False,
+ "default": {},
},
"tags": {
"description": "User associated metadata assigned to this object.",
"type": "array",
- "items": {"type": "object"}
+ "items": {"type": "object"},
},
"notify": {
"description": "Notification settings for action.",
@@ -242,52 +246,52 @@ class ActionAPI(BaseAPI, APIUIDMixin):
"properties": {
"on-complete": NotificationSubSchemaAPI,
"on-failure": NotificationSubSchemaAPI,
- "on-success": NotificationSubSchemaAPI
+ "on-success": NotificationSubSchemaAPI,
},
- "additionalProperties": False
+ "additionalProperties": False,
},
"metadata_file": {
"description": "Path to the metadata file relative to the pack directory.",
"type": "string",
- "default": ""
- }
+ "default": "",
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
def __init__(self, **kw):
for key, value in kw.items():
setattr(self, key, value)
- if not hasattr(self, 'parameters'):
- setattr(self, 'parameters', dict())
- if not hasattr(self, 'entry_point'):
- setattr(self, 'entry_point', '')
+ if not hasattr(self, "parameters"):
+ setattr(self, "parameters", dict())
+ if not hasattr(self, "entry_point"):
+ setattr(self, "entry_point", "")
@classmethod
def from_model(cls, model, mask_secrets=False):
action = cls._from_model(model)
- action['runner_type'] = action.get('runner_type', {}).get('name', None)
- action['tags'] = TagsHelper.from_model(model.tags)
+ action["runner_type"] = action.get("runner_type", {}).get("name", None)
+ action["tags"] = TagsHelper.from_model(model.tags)
- if getattr(model, 'notify', None):
- action['notify'] = NotificationsHelper.from_model(model.notify)
+ if getattr(model, "notify", None):
+ action["notify"] = NotificationsHelper.from_model(model.notify)
return cls(**action)
@classmethod
def to_model(cls, action):
- name = getattr(action, 'name', None)
- description = getattr(action, 'description', None)
- enabled = bool(getattr(action, 'enabled', True))
+ name = getattr(action, "name", None)
+ description = getattr(action, "description", None)
+ enabled = bool(getattr(action, "enabled", True))
entry_point = str(action.entry_point)
pack = str(action.pack)
- runner_type = {'name': str(action.runner_type)}
- parameters = getattr(action, 'parameters', dict())
- output_schema = getattr(action, 'output_schema', dict())
- tags = TagsHelper.to_model(getattr(action, 'tags', []))
+ runner_type = {"name": str(action.runner_type)}
+ parameters = getattr(action, "parameters", dict())
+ output_schema = getattr(action, "output_schema", dict())
+ tags = TagsHelper.to_model(getattr(action, "tags", []))
ref = ResourceReference.to_string_reference(pack=pack, name=name)
- if getattr(action, 'notify', None):
+ if getattr(action, "notify", None):
notify = NotificationsHelper.to_model(action.notify)
else:
# We use embedded document model for ``notify`` in action model. If notify is
@@ -296,12 +300,22 @@ def to_model(cls, action):
# to use an empty document.
notify = NotificationsHelper.to_model({})
- metadata_file = getattr(action, 'metadata_file', None)
-
- model = cls.model(name=name, description=description, enabled=enabled,
- entry_point=entry_point, pack=pack, runner_type=runner_type,
- tags=tags, parameters=parameters, output_schema=output_schema,
- notify=notify, ref=ref, metadata_file=metadata_file)
+ metadata_file = getattr(action, "metadata_file", None)
+
+ model = cls.model(
+ name=name,
+ description=description,
+ enabled=enabled,
+ entry_point=entry_point,
+ pack=pack,
+ runner_type=runner_type,
+ tags=tags,
+ parameters=parameters,
+ output_schema=output_schema,
+ notify=notify,
+ ref=ref,
+ metadata_file=metadata_file,
+ )
return model
@@ -310,28 +324,31 @@ class ActionCreateAPI(ActionAPI, APIUIDMixin):
"""
API model for create action operation.
"""
+
schema = copy.deepcopy(ActionAPI.schema)
- schema['properties']['data_files'] = {
- 'description': 'Optional action script and data files which are written to the filesystem.',
- 'type': 'array',
- 'items': {
- 'type': 'object',
- 'properties': {
- 'file_path': {
- 'type': 'string',
- 'description': ('Path to the file relative to the pack actions directory '
- '(e.g. my_action.py)'),
- 'required': True
+ schema["properties"]["data_files"] = {
+ "description": "Optional action script and data files which are written to the filesystem.",
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "file_path": {
+ "type": "string",
+ "description": (
+ "Path to the file relative to the pack actions directory "
+ "(e.g. my_action.py)"
+ ),
+ "required": True,
},
- 'content': {
- 'type': 'string',
- 'description': 'Raw file content.',
- 'required': True
+ "content": {
+ "type": "string",
+ "description": "Raw file content.",
+ "required": True,
},
},
- 'additionalProperties': False
+ "additionalProperties": False,
},
- 'default': []
+ "default": [],
}
@@ -339,8 +356,9 @@ class ActionUpdateAPI(ActionAPI, APIUIDMixin):
"""
API model for update action operation.
"""
+
schema = copy.deepcopy(ActionCreateAPI.schema)
- del schema['properties']['pack']['default']
+ del schema["properties"]["pack"]["default"]
class LiveActionAPI(BaseAPI):
@@ -356,27 +374,27 @@ class LiveActionAPI(BaseAPI):
"properties": {
"id": {
"description": "The unique identifier for the action execution.",
- "type": "string"
+ "type": "string",
},
"status": {
"description": "The current status of the action execution.",
"type": "string",
- "enum": LIVEACTION_STATUSES
+ "enum": LIVEACTION_STATUSES,
},
"start_timestamp": {
"description": "The start time when the action is executed.",
"type": "string",
- "pattern": isotime.ISO8601_UTC_REGEX
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
"end_timestamp": {
"description": "The timestamp when the action has finished.",
"type": "string",
- "pattern": isotime.ISO8601_UTC_REGEX
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
"action": {
"description": "Reference to the action to be executed.",
"type": "string",
- "required": True
+ "required": True,
},
"parameters": {
"description": "Input parameters for the action.",
@@ -390,58 +408,56 @@ class LiveActionAPI(BaseAPI):
{"type": "number"},
{"type": "object"},
{"type": "string"},
- {"type": "null"}
+ {"type": "null"},
]
}
},
- 'additionalProperties': False
+ "additionalProperties": False,
},
"result": {
- "anyOf": [{"type": "array"},
- {"type": "boolean"},
- {"type": "integer"},
- {"type": "number"},
- {"type": "object"},
- {"type": "string"}]
- },
- "context": {
- "type": "object"
- },
- "callback": {
- "type": "object"
- },
- "runner_info": {
- "type": "object"
- },
+ "anyOf": [
+ {"type": "array"},
+ {"type": "boolean"},
+ {"type": "integer"},
+ {"type": "number"},
+ {"type": "object"},
+ {"type": "string"},
+ ]
+ },
+ "context": {"type": "object"},
+ "callback": {"type": "object"},
+ "runner_info": {"type": "object"},
"notify": {
"description": "Notification settings for liveaction.",
"type": "object",
"properties": {
"on-complete": NotificationSubSchemaAPI,
"on-failure": NotificationSubSchemaAPI,
- "on-success": NotificationSubSchemaAPI
+ "on-success": NotificationSubSchemaAPI,
},
- "additionalProperties": False
+ "additionalProperties": False,
},
"delay": {
- "description": ("How long (in milliseconds) to delay the execution before"
- "scheduling."),
+ "description": (
+ "How long (in milliseconds) to delay the execution before"
+ "scheduling."
+ ),
"type": "integer",
- }
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
def from_model(cls, model, mask_secrets=False):
doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets)
if model.start_timestamp:
- doc['start_timestamp'] = isotime.format(model.start_timestamp, offset=False)
+ doc["start_timestamp"] = isotime.format(model.start_timestamp, offset=False)
if model.end_timestamp:
- doc['end_timestamp'] = isotime.format(model.end_timestamp, offset=False)
+ doc["end_timestamp"] = isotime.format(model.end_timestamp, offset=False)
- if getattr(model, 'notify', None):
- doc['notify'] = NotificationsHelper.from_model(model.notify)
+ if getattr(model, "notify", None):
+ doc["notify"] = NotificationsHelper.from_model(model.notify)
return cls(**doc)
@@ -449,32 +465,40 @@ def from_model(cls, model, mask_secrets=False):
def to_model(cls, live_action):
action = live_action.action
- if getattr(live_action, 'start_timestamp', None):
+ if getattr(live_action, "start_timestamp", None):
start_timestamp = isotime.parse(live_action.start_timestamp)
else:
start_timestamp = None
- if getattr(live_action, 'end_timestamp', None):
+ if getattr(live_action, "end_timestamp", None):
end_timestamp = isotime.parse(live_action.end_timestamp)
else:
end_timestamp = None
- status = getattr(live_action, 'status', None)
- parameters = getattr(live_action, 'parameters', dict())
- context = getattr(live_action, 'context', dict())
- callback = getattr(live_action, 'callback', dict())
- result = getattr(live_action, 'result', None)
- delay = getattr(live_action, 'delay', None)
+ status = getattr(live_action, "status", None)
+ parameters = getattr(live_action, "parameters", dict())
+ context = getattr(live_action, "context", dict())
+ callback = getattr(live_action, "callback", dict())
+ result = getattr(live_action, "result", None)
+ delay = getattr(live_action, "delay", None)
- if getattr(live_action, 'notify', None):
+ if getattr(live_action, "notify", None):
notify = NotificationsHelper.to_model(live_action.notify)
else:
notify = None
- model = cls.model(action=action,
- start_timestamp=start_timestamp, end_timestamp=end_timestamp,
- status=status, parameters=parameters, context=context,
- callback=callback, result=result, notify=notify, delay=delay)
+ model = cls.model(
+ action=action,
+ start_timestamp=start_timestamp,
+ end_timestamp=end_timestamp,
+ status=status,
+ parameters=parameters,
+ context=context,
+ callback=callback,
+ result=result,
+ notify=notify,
+ delay=delay,
+ )
return model
@@ -483,11 +507,12 @@ class LiveActionCreateAPI(LiveActionAPI):
"""
API model for action execution create (run action) operations.
"""
+
schema = copy.deepcopy(LiveActionAPI.schema)
- schema['properties']['user'] = {
- 'description': 'User context under which action should run (admins only)',
- 'type': 'string',
- 'default': None
+ schema["properties"]["user"] = {
+ "description": "User context under which action should run (admins only)",
+ "type": "string",
+ "default": None,
}
@@ -496,6 +521,7 @@ class ActionExecutionStateAPI(BaseAPI):
System entity that represents the state of an action in the system.
This is used only in tests for now.
"""
+
model = ActionExecutionStateDB
schema = {
"title": "ActionExecutionState",
@@ -504,25 +530,25 @@ class ActionExecutionStateAPI(BaseAPI):
"properties": {
"id": {
"description": "The unique identifier for the action execution state.",
- "type": "string"
+ "type": "string",
},
"execution_id": {
"type": "string",
"description": "ID of the action execution.",
- "required": True
+ "required": True,
},
"query_context": {
"type": "object",
"description": "query context to be used by querier.",
- "required": True
+ "required": True,
},
"query_module": {
"type": "string",
"description": "Name of the query module.",
- "required": True
- }
+ "required": True,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
@@ -531,8 +557,11 @@ def to_model(cls, state):
query_module = state.query_module
query_context = state.query_context
- model = cls.model(execution_id=execution_id, query_module=query_module,
- query_context=query_context)
+ model = cls.model(
+ execution_id=execution_id,
+ query_module=query_module,
+ query_context=query_context,
+ )
return model
@@ -540,6 +569,7 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin):
"""
Alias for an action in the system.
"""
+
model = ActionAliasDB
schema = {
"title": "ActionAlias",
@@ -548,42 +578,40 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin):
"properties": {
"id": {
"description": "The unique identifier for the action alias.",
- "type": "string"
+ "type": "string",
},
"ref": {
"description": (
"System computed user friendly reference for the alias. "
"Provided value will be overridden by computed value."
),
- "type": "string"
- },
- "uid": {
- "type": "string"
+ "type": "string",
},
+ "uid": {"type": "string"},
"name": {
"type": "string",
"description": "Name of the action alias.",
- "required": True
+ "required": True,
},
"pack": {
"description": "The content pack this actionalias belongs to.",
"type": "string",
- "required": True
+ "required": True,
},
"description": {
"type": "string",
"description": "Description of the action alias.",
- "default": None
+ "default": None,
},
"enabled": {
"description": "Flag indicating of action alias is enabled.",
"type": "boolean",
- "default": True
+ "default": True,
},
"action_ref": {
"type": "string",
"description": "Reference to the aliased action.",
- "required": True
+ "required": True,
},
"formats": {
"type": "array",
@@ -596,13 +624,13 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin):
"display": {"type": "string"},
"representation": {
"type": "array",
- "items": {"type": "string"}
- }
- }
- }
+ "items": {"type": "string"},
+ },
+ },
+ },
]
},
- "description": "Possible parameter format."
+ "description": "Possible parameter format.",
},
"ack": {
"type": "object",
@@ -610,56 +638,65 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin):
"enabled": {"type": "boolean"},
"format": {"type": "string"},
"extra": {"type": "object"},
- "append_url": {"type": "boolean"}
+ "append_url": {"type": "boolean"},
},
- "description": "Acknowledgement message format."
+ "description": "Acknowledgement message format.",
},
"result": {
"type": "object",
"properties": {
"enabled": {"type": "boolean"},
"format": {"type": "string"},
- "extra": {"type": "object"}
+ "extra": {"type": "object"},
},
- "description": "Execution message format."
+ "description": "Execution message format.",
},
"extra": {
"type": "object",
- "description": "Extra parameters, usually adapter-specific."
+ "description": "Extra parameters, usually adapter-specific.",
},
"immutable_parameters": {
"type": "object",
- "description": "Parameters to be passed to the action on every execution."
+ "description": "Parameters to be passed to the action on every execution.",
},
"metadata_file": {
"description": "Path to the metadata file relative to the pack directory.",
"type": "string",
- "default": ""
- }
+ "default": "",
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
def to_model(cls, alias):
name = alias.name
- description = getattr(alias, 'description', None)
+ description = getattr(alias, "description", None)
pack = alias.pack
ref = ResourceReference.to_string_reference(pack=pack, name=name)
- enabled = getattr(alias, 'enabled', True)
+ enabled = getattr(alias, "enabled", True)
action_ref = alias.action_ref
formats = alias.formats
- ack = getattr(alias, 'ack', None)
- result = getattr(alias, 'result', None)
- extra = getattr(alias, 'extra', None)
- immutable_parameters = getattr(alias, 'immutable_parameters', None)
- metadata_file = getattr(alias, 'metadata_file', None)
-
- model = cls.model(name=name, description=description, pack=pack, ref=ref,
- enabled=enabled, action_ref=action_ref, formats=formats,
- ack=ack, result=result, extra=extra,
- immutable_parameters=immutable_parameters,
- metadata_file=metadata_file)
+ ack = getattr(alias, "ack", None)
+ result = getattr(alias, "result", None)
+ extra = getattr(alias, "extra", None)
+ immutable_parameters = getattr(alias, "immutable_parameters", None)
+ metadata_file = getattr(alias, "metadata_file", None)
+
+ model = cls.model(
+ name=name,
+ description=description,
+ pack=pack,
+ ref=ref,
+ enabled=enabled,
+ action_ref=action_ref,
+ formats=formats,
+ ack=ack,
+ result=result,
+ extra=extra,
+ immutable_parameters=immutable_parameters,
+ metadata_file=metadata_file,
+ )
return model
@@ -667,6 +704,7 @@ class AliasExecutionAPI(BaseAPI):
"""
Alias for an action in the system.
"""
+
model = None
schema = {
"title": "AliasExecution",
@@ -676,48 +714,48 @@ class AliasExecutionAPI(BaseAPI):
"name": {
"type": "string",
"description": "Name of the action alias which matched.",
- "required": True
+ "required": True,
},
"format": {
"type": "string",
"description": "Format string which matched.",
- "required": True
+ "required": True,
},
"command": {
"type": "string",
"description": "Command used in chat.",
- "required": True
+ "required": True,
},
"user": {
"type": "string",
"description": "User that requested the execution.",
- "default": "channel" # TODO: This value doesnt get set
+ "default": "channel", # TODO: This value doesnt get set
},
"source_channel": {
"type": "string",
"description": "Channel from which the execution was requested. This is not the "
- "channel as defined by the notification system.",
- "required": True
+ "channel as defined by the notification system.",
+ "required": True,
},
"source_context": {
"type": "object",
"description": "ALL data included with the message (also called the message "
- "envelope). This is currently only used by the Microsoft Teams "
- "adapter.",
- "required": False
+ "envelope). This is currently only used by the Microsoft Teams "
+ "adapter.",
+ "required": False,
},
"notification_channel": {
"type": "string",
"description": "StackStorm notification channel to use to respond.",
- "required": False
+ "required": False,
},
"notification_route": {
"type": "string",
"description": "StackStorm notification route to use to respond.",
- "required": False
- }
+ "required": False,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
@@ -734,6 +772,7 @@ class AliasMatchAndExecuteInputAPI(BaseAPI):
"""
API object used for alias execution "match and execute" API endpoint request payload.
"""
+
model = None
schema = {
"title": "ActionAliasMatchAndExecuteInputAPI",
@@ -743,7 +782,7 @@ class AliasMatchAndExecuteInputAPI(BaseAPI):
"command": {
"type": "string",
"description": "Command used in chat.",
- "required": True
+ "required": True,
},
"user": {
"type": "string",
@@ -753,22 +792,22 @@ class AliasMatchAndExecuteInputAPI(BaseAPI):
"type": "string",
"description": "Channel from which the execution was requested. This is not the \
channel as defined by the notification system.",
- "required": True
+ "required": True,
},
"notification_channel": {
"type": "string",
"description": "StackStorm notification channel to use to respond.",
"required": False,
- "default": None
+ "default": None,
},
"notification_route": {
"type": "string",
"description": "StackStorm notification route to use to respond.",
"required": False,
- "default": None
- }
+ "default": None,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@@ -776,6 +815,7 @@ class ActionAliasMatchAPI(BaseAPI):
"""
API model used for alias match API endpoint.
"""
+
model = None
schema = {
@@ -786,10 +826,10 @@ class ActionAliasMatchAPI(BaseAPI):
"command": {
"type": "string",
"description": "Command string to try to match the aliases against.",
- "required": True
+ "required": True,
}
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
@@ -805,6 +845,7 @@ class ActionAliasHelpAPI(BaseAPI):
"""
API model used to display action-alias help API endpoint.
"""
+
model = None
schema = {
@@ -816,28 +857,28 @@ class ActionAliasHelpAPI(BaseAPI):
"type": "string",
"description": "Find help strings containing keyword.",
"required": False,
- "default": ""
+ "default": "",
},
"pack": {
"type": "string",
"description": "List help strings for a specific pack.",
"required": False,
- "default": ""
+ "default": "",
},
"offset": {
"type": "integer",
"description": "List help strings from the offset position.",
"required": False,
- "default": 0
+ "default": 0,
},
"limit": {
"type": "integer",
"description": "Limit the number of help strings returned.",
"required": False,
- "default": 0
- }
+ "default": 0,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
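Note that these API model schemas use the draft-3 style boolean "required" on individual properties rather than draft-4's list form; st2common validates them with its own CustomValidator. Assuming the jsonschema package is available with draft-3 support, the shape can be approximated with its Draft3Validator:

from jsonschema import Draft3Validator

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string", "required": True},
        "enabled": {"type": "boolean", "default": True},
    },
    "additionalProperties": False,
}

validator = Draft3Validator(schema)
print(list(validator.iter_errors({"name": "my_action"})))  # [] - valid
print([e.message for e in validator.iter_errors({"enabled": True})])
# One error for the missing "name" property (exact wording varies by version).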
diff --git a/st2common/st2common/models/api/actionrunner.py b/st2common/st2common/models/api/actionrunner.py
index d2a2029e32..7b580e1c9b 100644
--- a/st2common/st2common/models/api/actionrunner.py
+++ b/st2common/st2common/models/api/actionrunner.py
@@ -17,7 +17,7 @@
from st2common import log as logging
from st2common.models.api.base import BaseAPI
-__all__ = ['ActionRunnerAPI']
+__all__ = ["ActionRunnerAPI"]
LOG = logging.getLogger(__name__)
@@ -29,12 +29,9 @@ class ActionRunnerAPI(BaseAPI):
Attribute:
...
"""
+
schema = {
- 'type': 'object',
- 'parameters': {
- 'id': {
- 'type': 'string'
- }
- },
- 'additionalProperties': False
+ "type": "object",
+ "parameters": {"id": {"type": "string"}},
+ "additionalProperties": False,
}
diff --git a/st2common/st2common/models/api/auth.py b/st2common/st2common/models/api/auth.py
index 8e5ed34e34..10672e99ec 100644
--- a/st2common/st2common/models/api/auth.py
+++ b/st2common/st2common/models/api/auth.py
@@ -36,13 +36,8 @@ class UserAPI(BaseAPI):
schema = {
"title": "User",
"type": "object",
- "properties": {
- "name": {
- "type": "string",
- "required": True
- }
- },
- "additionalProperties": False
+ "properties": {"name": {"type": "string", "required": True}},
+ "additionalProperties": False,
}
@classmethod
@@ -58,34 +53,25 @@ class TokenAPI(BaseAPI):
"title": "Token",
"type": "object",
"properties": {
- "id": {
- "type": "string"
- },
- "user": {
- "type": ["string", "null"]
- },
- "token": {
- "type": ["string", "null"]
- },
- "ttl": {
- "type": "integer",
- "minimum": 1
- },
+ "id": {"type": "string"},
+ "user": {"type": ["string", "null"]},
+ "token": {"type": ["string", "null"]},
+ "ttl": {"type": "integer", "minimum": 1},
"expiry": {
"type": ["string", "null"],
- "pattern": isotime.ISO8601_UTC_REGEX
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
- "metadata": {
- "type": ["object", "null"]
- }
+ "metadata": {"type": ["object", "null"]},
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
def from_model(cls, model, mask_secrets=False):
doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets)
- doc['expiry'] = isotime.format(model.expiry, offset=False) if model.expiry else None
+ doc["expiry"] = (
+ isotime.format(model.expiry, offset=False) if model.expiry else None
+ )
return cls(**doc)
@classmethod
@@ -104,52 +90,44 @@ class ApiKeyAPI(BaseAPI, APIUIDMixin):
"title": "ApiKey",
"type": "object",
"properties": {
- "id": {
- "type": "string"
- },
- "uid": {
- "type": "string"
- },
- "user": {
- "type": ["string", "null"],
- "default": ""
- },
- "key_hash": {
- "type": ["string", "null"]
- },
- "metadata": {
- "type": ["object", "null"]
- },
- 'created_at': {
- 'description': 'The start time when the action is executed.',
- 'type': 'string',
- 'pattern': isotime.ISO8601_UTC_REGEX
+ "id": {"type": "string"},
+ "uid": {"type": "string"},
+ "user": {"type": ["string", "null"], "default": ""},
+ "key_hash": {"type": ["string", "null"]},
+ "metadata": {"type": ["object", "null"]},
+ "created_at": {
+ "description": "The start time when the action is executed.",
+ "type": "string",
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
"enabled": {
"description": "Enable or disable the action from invocation.",
"type": "boolean",
- "default": True
- }
+ "default": True,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
def from_model(cls, model, mask_secrets=False):
doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets)
- doc['created_at'] = isotime.format(model.created_at, offset=False) if model.created_at \
- else None
+ doc["created_at"] = (
+ isotime.format(model.created_at, offset=False) if model.created_at else None
+ )
return cls(**doc)
@classmethod
def to_model(cls, instance):
# If PrimaryKey ID is provided, - we want to work with existing ST2 API key
- id = getattr(instance, 'id', None)
+ id = getattr(instance, "id", None)
user = str(instance.user) if instance.user else None
- key_hash = getattr(instance, 'key_hash', None)
- metadata = getattr(instance, 'metadata', {})
- enabled = bool(getattr(instance, 'enabled', True))
- model = cls.model(id=id, user=user, key_hash=key_hash, metadata=metadata, enabled=enabled)
+ key_hash = getattr(instance, "key_hash", None)
+ metadata = getattr(instance, "metadata", {})
+ enabled = bool(getattr(instance, "enabled", True))
+ model = cls.model(
+ id=id, user=user, key_hash=key_hash, metadata=metadata, enabled=enabled
+ )
return model
@@ -158,45 +136,35 @@ class ApiKeyCreateResponseAPI(BaseAPI):
"title": "APIKeyCreateResponse",
"type": "object",
"properties": {
- "id": {
- "type": "string"
- },
- "uid": {
- "type": "string"
- },
- "user": {
- "type": ["string", "null"],
- "default": ""
- },
- "key": {
- "type": ["string", "null"]
- },
- "metadata": {
- "type": ["object", "null"]
- },
- 'created_at': {
- 'description': 'The start time when the action is executed.',
- 'type': 'string',
- 'pattern': isotime.ISO8601_UTC_REGEX
+ "id": {"type": "string"},
+ "uid": {"type": "string"},
+ "user": {"type": ["string", "null"], "default": ""},
+ "key": {"type": ["string", "null"]},
+ "metadata": {"type": ["object", "null"]},
+ "created_at": {
+ "description": "The start time when the action is executed.",
+ "type": "string",
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
"enabled": {
"description": "Enable or disable the action from invocation.",
"type": "boolean",
- "default": True
- }
+ "default": True,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
def from_model(cls, model, mask_secrets=False):
doc = cls._from_model(model=model, mask_secrets=mask_secrets)
attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None}
- attrs['created_at'] = isotime.format(model.created_at, offset=False) if model.created_at \
- else None
+ attrs["created_at"] = (
+ isotime.format(model.created_at, offset=False) if model.created_at else None
+ )
# key_hash is ignored.
- attrs.pop('key_hash', None)
+ attrs.pop("key_hash", None)
# key is unknown so the calling code will have to update after conversion.
- attrs['key'] = None
+ attrs["key"] = None
return cls(**attrs)
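TokenAPI and the API key models all format optional timestamps the same way: ISO 8601 UTC when the value is present, None otherwise. A rough stdlib approximation of that conditional formatting (isotime.format itself is not reproduced here):

from datetime import datetime, timezone

def format_optional(dt):
    # Roughly what the from_model() hunks do with isotime.format(..., offset=False).
    if not dt:
        return None
    return dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")

print(format_optional(None))
print(format_optional(datetime(2021, 3, 2, 12, 0, tzinfo=timezone.utc)))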
diff --git a/st2common/st2common/models/api/base.py b/st2common/st2common/models/api/base.py
index 3669291e9c..6c052a43e3 100644
--- a/st2common/st2common/models/api/base.py
+++ b/st2common/st2common/models/api/base.py
@@ -23,10 +23,7 @@
from st2common.util import mongoescape as util_mongodb
from st2common import log as logging
-__all__ = [
- 'BaseAPI',
- 'APIUIDMixin'
-]
+__all__ = ["BaseAPI", "APIUIDMixin"]
LOG = logging.getLogger(__name__)
@@ -43,13 +40,13 @@ def __init__(self, **kw):
def __repr__(self):
name = type(self).__name__
- attrs = ', '.join("'%s': %r" % item for item in six.iteritems(vars(self)))
+ attrs = ", ".join("'%s': %r" % item for item in six.iteritems(vars(self)))
# The format here is so that eval can be applied.
return "%s(**{%s})" % (name, attrs)
def __str__(self):
name = type(self).__name__
- attrs = ', '.join("%s=%r" % item for item in six.iteritems(vars(self)))
+ attrs = ", ".join("%s=%r" % item for item in six.iteritems(vars(self)))
return "%s[%s]" % (name, attrs)
@@ -66,12 +63,16 @@ def validate(self):
"""
from st2common.util import schema as util_schema
- schema = getattr(self, 'schema', {})
+ schema = getattr(self, "schema", {})
attributes = vars(self)
- cleaned = util_schema.validate(instance=attributes, schema=schema,
- cls=util_schema.CustomValidator, use_default=True,
- allow_default_none=True)
+ cleaned = util_schema.validate(
+ instance=attributes,
+ schema=schema,
+ cls=util_schema.CustomValidator,
+ use_default=True,
+ allow_default_none=True,
+ )
# Note: We use type() instead of self.__class__ since self.__class__ confuses pylint
return type(self)(**cleaned)
@@ -80,8 +81,8 @@ def validate(self):
def _from_model(cls, model, mask_secrets=False):
doc = model.to_mongo()
- if '_id' in doc:
- doc['id'] = str(doc.pop('_id'))
+ if "_id" in doc:
+ doc["id"] = str(doc.pop("_id"))
doc = util_mongodb.unescape_chars(doc)
@@ -117,7 +118,7 @@ def to_model(cls, doc):
class APIUIDMixin(object):
- """"
+ """ "
Mixin class for retrieving UID for API objects.
"""
@@ -142,9 +143,11 @@ def has_valid_uid(self):
def cast_argument_value(value_type, value):
if value_type == bool:
+
def cast_func(value):
value = str(value)
- return value.lower() in ['1', 'true']
+ return value.lower() in ["1", "true"]
+
else:
cast_func = value_type
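
A self-contained sketch of cast_argument_value as reformatted above, runnable without st2common; the trailing return is assumed from context, since the hunk only shows how cast_func is chosen:

    def cast_argument_value(value_type, value):
        if value_type == bool:

            def cast_func(value):
                # Booleans arrive as strings, so compare case-insensitively.
                value = str(value)
                return value.lower() in ["1", "true"]

        else:
            cast_func = value_type

        return cast_func(value)


    print(cast_argument_value(bool, "True"))  # True
    print(cast_argument_value(bool, "0"))     # False
    print(cast_argument_value(int, "42"))     # 42
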
diff --git a/st2common/st2common/models/api/execution.py b/st2common/st2common/models/api/execution.py
index 447a8679eb..87a5ff52c0 100644
--- a/st2common/st2common/models/api/execution.py
+++ b/st2common/st2common/models/api/execution.py
@@ -28,10 +28,7 @@
from st2common.models.api.action import RunnerTypeAPI, ActionAPI, LiveActionAPI
from st2common import log as logging
-__all__ = [
- 'ActionExecutionAPI',
- 'ActionExecutionOutputAPI'
-]
+__all__ = ["ActionExecutionAPI", "ActionExecutionOutputAPI"]
LOG = logging.getLogger(__name__)
@@ -48,47 +45,44 @@
class ActionExecutionAPI(BaseAPI):
model = ActionExecutionDB
- SKIP = ['start_timestamp', 'end_timestamp']
+ SKIP = ["start_timestamp", "end_timestamp"]
schema = {
"title": "ActionExecution",
"description": "Record of the execution of an action.",
"type": "object",
"properties": {
- "id": {
- "type": "string",
- "required": True
- },
+ "id": {"type": "string", "required": True},
"trigger": TriggerAPI.schema,
"trigger_type": TriggerTypeAPI.schema,
"trigger_instance": TriggerInstanceAPI.schema,
"rule": RuleAPI.schema,
- "action": REQUIRED_ATTR_SCHEMAS['action'],
- "runner": REQUIRED_ATTR_SCHEMAS['runner'],
- "liveaction": REQUIRED_ATTR_SCHEMAS['liveaction'],
+ "action": REQUIRED_ATTR_SCHEMAS["action"],
+ "runner": REQUIRED_ATTR_SCHEMAS["runner"],
+ "liveaction": REQUIRED_ATTR_SCHEMAS["liveaction"],
"status": {
"description": "The current status of the action execution.",
"type": "string",
- "enum": LIVEACTION_STATUSES
+ "enum": LIVEACTION_STATUSES,
},
"start_timestamp": {
"description": "The start time when the action is executed.",
"type": "string",
- "pattern": isotime.ISO8601_UTC_REGEX
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
"end_timestamp": {
"description": "The timestamp when the action has finished.",
"type": "string",
- "pattern": isotime.ISO8601_UTC_REGEX
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
"elapsed_seconds": {
"description": "Time duration in seconds taken for completion of this execution.",
"type": "number",
- "required": False
+ "required": False,
},
"web_url": {
"description": "History URL for this execution if you want to view in UI.",
"type": "string",
- "required": False
+ "required": False,
},
"parameters": {
"description": "Input parameters for the action.",
@@ -101,28 +95,28 @@ class ActionExecutionAPI(BaseAPI):
{"type": "integer"},
{"type": "number"},
{"type": "object"},
- {"type": "string"}
+ {"type": "string"},
]
}
},
- 'additionalProperties': False
- },
- "context": {
- "type": "object"
+ "additionalProperties": False,
},
+ "context": {"type": "object"},
"result": {
- "anyOf": [{"type": "array"},
- {"type": "boolean"},
- {"type": "integer"},
- {"type": "number"},
- {"type": "object"},
- {"type": "string"}]
+ "anyOf": [
+ {"type": "array"},
+ {"type": "boolean"},
+ {"type": "integer"},
+ {"type": "number"},
+ {"type": "object"},
+ {"type": "string"},
+ ]
},
"parent": {"type": "string"},
"children": {
"type": "array",
"items": {"type": "string"},
- "uniqueItems": True
+ "uniqueItems": True,
},
"log": {
"description": "Contains information about execution state transitions.",
@@ -132,22 +126,21 @@ class ActionExecutionAPI(BaseAPI):
"properties": {
"timestamp": {
"type": "string",
- "pattern": isotime.ISO8601_UTC_REGEX
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
- "status": {
- "type": "string",
- "enum": LIVEACTION_STATUSES
- }
- }
- }
+ "status": {"type": "string", "enum": LIVEACTION_STATUSES},
+ },
+ },
},
"delay": {
- "description": ("How long (in milliseconds) to delay the execution before"
- "scheduling."),
+ "description": (
+ "How long (in milliseconds) to delay the execution before"
+ "scheduling."
+ ),
"type": "integer",
- }
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
@@ -155,16 +148,16 @@ def from_model(cls, model, mask_secrets=False):
doc = cls._from_model(model, mask_secrets=mask_secrets)
start_timestamp = model.start_timestamp
start_timestamp_iso = isotime.format(start_timestamp, offset=False)
- doc['start_timestamp'] = start_timestamp_iso
+ doc["start_timestamp"] = start_timestamp_iso
end_timestamp = model.end_timestamp
if end_timestamp:
end_timestamp_iso = isotime.format(end_timestamp, offset=False)
- doc['end_timestamp'] = end_timestamp_iso
- doc['elapsed_seconds'] = (end_timestamp - start_timestamp).total_seconds()
+ doc["end_timestamp"] = end_timestamp_iso
+ doc["elapsed_seconds"] = (end_timestamp - start_timestamp).total_seconds()
- for entry in doc.get('log', []):
- entry['timestamp'] = isotime.format(entry['timestamp'], offset=False)
+ for entry in doc.get("log", []):
+ entry["timestamp"] = isotime.format(entry["timestamp"], offset=False)
attrs = {attr: value for attr, value in six.iteritems(doc) if value}
return cls(**attrs)
@@ -172,11 +165,11 @@ def from_model(cls, model, mask_secrets=False):
@classmethod
def to_model(cls, instance):
values = {}
- for attr, meta in six.iteritems(cls.schema.get('properties', dict())):
+ for attr, meta in six.iteritems(cls.schema.get("properties", dict())):
if not getattr(instance, attr, None):
continue
- default = copy.deepcopy(meta.get('default', None))
+ default = copy.deepcopy(meta.get("default", None))
value = getattr(instance, attr, default)
# pylint: disable=no-member
@@ -188,8 +181,8 @@ def to_model(cls, instance):
if attr not in ActionExecutionAPI.SKIP:
values[attr] = value
- values['start_timestamp'] = isotime.parse(instance.start_timestamp)
- values['end_timestamp'] = isotime.parse(instance.end_timestamp)
+ values["start_timestamp"] = isotime.parse(instance.start_timestamp)
+ values["end_timestamp"] = isotime.parse(instance.end_timestamp)
model = cls.model(**values)
return model
@@ -198,41 +191,24 @@ def to_model(cls, instance):
class ActionExecutionOutputAPI(BaseAPI):
model = ActionExecutionOutputDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string'
- },
- 'execution_id': {
- 'type': 'string'
- },
- 'action_ref': {
- 'type': 'string'
- },
- 'runner_ref': {
- 'type': 'string'
- },
- 'timestamp': {
- 'type': 'string',
- 'pattern': isotime.ISO8601_UTC_REGEX
- },
- 'output_type': {
- 'type': 'string'
- },
- 'data': {
- 'type': 'string'
- },
- 'delay': {
- 'type': 'integer'
- }
+ "type": "object",
+ "properties": {
+ "id": {"type": "string"},
+ "execution_id": {"type": "string"},
+ "action_ref": {"type": "string"},
+ "runner_ref": {"type": "string"},
+ "timestamp": {"type": "string", "pattern": isotime.ISO8601_UTC_REGEX},
+ "output_type": {"type": "string"},
+ "data": {"type": "string"},
+ "delay": {"type": "integer"},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
def from_model(cls, model, mask_secrets=True):
doc = cls._from_model(model, mask_secrets=mask_secrets)
- doc['timestamp'] = isotime.format(model.timestamp, offset=False)
+ doc["timestamp"] = isotime.format(model.timestamp, offset=False)
attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None}
return cls(**attrs)
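
A minimal, self-contained sketch of the elapsed_seconds computation in ActionExecutionAPI.from_model above, using plain datetime objects in place of the model's timestamps:

    from datetime import datetime

    start_timestamp = datetime(2021, 3, 2, 12, 0, 0)
    end_timestamp = datetime(2021, 3, 2, 12, 0, 42)

    # Mirrors: doc["elapsed_seconds"] = (end_timestamp - start_timestamp).total_seconds()
    print((end_timestamp - start_timestamp).total_seconds())  # 42.0
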
diff --git a/st2common/st2common/models/api/inquiry.py b/st2common/st2common/models/api/inquiry.py
index e3194df28c..a45327aaa7 100644
--- a/st2common/st2common/models/api/inquiry.py
+++ b/st2common/st2common/models/api/inquiry.py
@@ -54,30 +54,11 @@ class InquiryAPI(BaseAPI):
"description": "Record of an Inquiry",
"type": "object",
"properties": {
- "id": {
- "type": "string",
- "required": True
- },
- "route": {
- "type": "string",
- "default": "",
- "required": True
- },
- "ttl": {
- "type": "integer",
- "default": 1440,
- "required": True
- },
- "users": {
- "type": "array",
- "default": [],
- "required": True
- },
- "roles": {
- "type": "array",
- "default": [],
- "required": True
- },
+ "id": {"type": "string", "required": True},
+ "route": {"type": "string", "default": "", "required": True},
+ "ttl": {"type": "integer", "default": 1440, "required": True},
+ "users": {"type": "array", "default": [], "required": True},
+ "roles": {"type": "array", "default": [], "required": True},
"schema": {
"type": "object",
"default": {
@@ -87,30 +68,32 @@ class InquiryAPI(BaseAPI):
"continue": {
"type": "boolean",
"description": "Would you like to continue the workflow?",
- "required": True
+ "required": True,
}
},
},
- "required": True
+ "required": True,
},
- "liveaction": REQUIRED_ATTR_SCHEMAS['liveaction'],
- "runner": REQUIRED_ATTR_SCHEMAS['runner'],
+ "liveaction": REQUIRED_ATTR_SCHEMAS["liveaction"],
+ "runner": REQUIRED_ATTR_SCHEMAS["runner"],
"status": {
"description": "The current status of the action execution.",
"type": "string",
- "enum": LIVEACTION_STATUSES
+ "enum": LIVEACTION_STATUSES,
},
"parent": {"type": "string"},
"result": {
- "anyOf": [{"type": "array"},
- {"type": "boolean"},
- {"type": "integer"},
- {"type": "number"},
- {"type": "object"},
- {"type": "string"}]
- }
+ "anyOf": [
+ {"type": "array"},
+ {"type": "boolean"},
+ {"type": "integer"},
+ {"type": "number"},
+ {"type": "object"},
+ {"type": "string"},
+ ]
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
@@ -118,23 +101,22 @@ def from_model(cls, model, mask_secrets=False):
doc = cls._from_model(model, mask_secrets=mask_secrets)
newdoc = {
- 'id': doc['id'],
- 'runner': doc.get('runner', None),
- 'status': doc.get('status', None),
- 'liveaction': doc.get('liveaction', None),
- 'parent': doc.get('parent', None),
- 'result': doc.get('result', None)
+ "id": doc["id"],
+ "runner": doc.get("runner", None),
+ "status": doc.get("status", None),
+ "liveaction": doc.get("liveaction", None),
+ "parent": doc.get("parent", None),
+ "result": doc.get("result", None),
}
- for field in ['route', 'ttl', 'users', 'roles', 'schema']:
- newdoc[field] = doc['result'].get(field, None)
+ for field in ["route", "ttl", "users", "roles", "schema"]:
+ newdoc[field] = doc["result"].get(field, None)
return cls(**newdoc)
class InquiryResponseAPI(BaseAPI):
- """A more pruned Inquiry model, containing only the fields needed for an API response
- """
+ """A more pruned Inquiry model, containing only the fields needed for an API response"""
model = ActionExecutionDB
schema = {
@@ -142,30 +124,11 @@ class InquiryResponseAPI(BaseAPI):
"description": "Record of an Inquiry",
"type": "object",
"properties": {
- "id": {
- "type": "string",
- "required": True
- },
- "route": {
- "type": "string",
- "default": "",
- "required": True
- },
- "ttl": {
- "type": "integer",
- "default": 1440,
- "required": True
- },
- "users": {
- "type": "array",
- "default": [],
- "required": True
- },
- "roles": {
- "type": "array",
- "default": [],
- "required": True
- },
+ "id": {"type": "string", "required": True},
+ "route": {"type": "string", "default": "", "required": True},
+ "ttl": {"type": "integer", "default": 1440, "required": True},
+ "users": {"type": "array", "default": [], "required": True},
+ "roles": {"type": "array", "default": [], "required": True},
"schema": {
"type": "object",
"default": {
@@ -175,14 +138,14 @@ class InquiryResponseAPI(BaseAPI):
"continue": {
"type": "boolean",
"description": "Would you like to continue the workflow?",
- "required": True
+ "required": True,
}
},
},
- "required": True
- }
+ "required": True,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
@@ -201,9 +164,7 @@ def from_model(cls, model, mask_secrets=False, skip_db=False):
else:
doc = model
- newdoc = {
- "id": doc["id"]
- }
+ newdoc = {"id": doc["id"]}
for field in ["route", "ttl", "users", "roles", "schema"]:
newdoc[field] = doc["result"].get(field)
@@ -211,16 +172,16 @@ def from_model(cls, model, mask_secrets=False, skip_db=False):
@classmethod
def from_inquiry_api(cls, inquiry_api, mask_secrets=False):
- """ Allows translation of InquiryAPI directly to InquiryResponseAPI
+ """Allows translation of InquiryAPI directly to InquiryResponseAPI
This bypasses the DB modeling, since there's no DB model for Inquiries yet.
"""
return cls(
- id=getattr(inquiry_api, 'id', None),
- route=getattr(inquiry_api, 'route', None),
- ttl=getattr(inquiry_api, 'ttl', None),
- users=getattr(inquiry_api, 'users', None),
- roles=getattr(inquiry_api, 'roles', None),
- schema=getattr(inquiry_api, 'schema', None)
+ id=getattr(inquiry_api, "id", None),
+ route=getattr(inquiry_api, "route", None),
+ ttl=getattr(inquiry_api, "ttl", None),
+ users=getattr(inquiry_api, "users", None),
+ roles=getattr(inquiry_api, "roles", None),
+ schema=getattr(inquiry_api, "schema", None),
)
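
from_inquiry_api above relies on getattr with a None default, so attributes missing from the source object degrade to None instead of raising. A standalone sketch with a hypothetical stand-in object, not the real InquiryAPI:

    class FakeInquiryAPI(object):  # illustrative stand-in only
        id = "abc123"
        ttl = 1440

    inquiry_api = FakeInquiryAPI()
    print(getattr(inquiry_api, "ttl", None))    # 1440
    print(getattr(inquiry_api, "roles", None))  # None (attribute not set)
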
diff --git a/st2common/st2common/models/api/keyvalue.py b/st2common/st2common/models/api/keyvalue.py
index 8365350ef7..a19cfcc33e 100644
--- a/st2common/st2common/models/api/keyvalue.py
+++ b/st2common/st2common/models/api/keyvalue.py
@@ -21,9 +21,16 @@
from oslo_config import cfg
import six
-from st2common.constants.keyvalue import FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, ALLOWED_SCOPES
+from st2common.constants.keyvalue import (
+ FULL_SYSTEM_SCOPE,
+ FULL_USER_SCOPE,
+ ALLOWED_SCOPES,
+)
from st2common.constants.keyvalue import SYSTEM_SCOPE, USER_SCOPE
-from st2common.exceptions.keyvalue import CryptoKeyNotSetupException, InvalidScopeException
+from st2common.exceptions.keyvalue import (
+ CryptoKeyNotSetupException,
+ InvalidScopeException,
+)
from st2common.log import logging
from st2common.util import isotime
from st2common.util import date as date_utils
@@ -32,10 +39,7 @@
from st2common.models.system.keyvalue import UserKeyReference
from st2common.models.db.keyvalue import KeyValuePairDB
-__all__ = [
- 'KeyValuePairAPI',
- 'KeyValuePairSetAPI'
-]
+__all__ = ["KeyValuePairAPI", "KeyValuePairSetAPI"]
LOG = logging.getLogger(__name__)
@@ -44,50 +48,29 @@ class KeyValuePairAPI(BaseAPI):
crypto_setup = False
model = KeyValuePairDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string'
+ "type": "object",
+ "properties": {
+ "id": {"type": "string"},
+ "uid": {"type": "string"},
+ "name": {"type": "string"},
+ "description": {"type": "string"},
+ "value": {"type": "string", "required": True},
+ "secret": {"type": "boolean", "required": False, "default": False},
+ "encrypted": {"type": "boolean", "required": False, "default": False},
+ "scope": {
+ "type": "string",
+ "required": False,
+ "default": FULL_SYSTEM_SCOPE,
},
- "uid": {
- "type": "string"
- },
- 'name': {
- 'type': 'string'
- },
- 'description': {
- 'type': 'string'
- },
- 'value': {
- 'type': 'string',
- 'required': True
- },
- 'secret': {
- 'type': 'boolean',
- 'required': False,
- 'default': False
- },
- 'encrypted': {
- 'type': 'boolean',
- 'required': False,
- 'default': False
- },
- 'scope': {
- 'type': 'string',
- 'required': False,
- 'default': FULL_SYSTEM_SCOPE
- },
- 'expire_timestamp': {
- 'type': 'string',
- 'pattern': isotime.ISO8601_UTC_REGEX
+ "expire_timestamp": {
+ "type": "string",
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
# Note: Those values are only used for input
# TODO: Improve
- 'ttl': {
- 'type': 'integer'
- }
+ "ttl": {"type": "integer"},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@staticmethod
@@ -96,19 +79,25 @@ def _setup_crypto():
# Crypto already set up
return
- LOG.info('Checking if encryption is enabled for key-value store.')
+ LOG.info("Checking if encryption is enabled for key-value store.")
KeyValuePairAPI.is_encryption_enabled = cfg.CONF.keyvalue.enable_encryption
- LOG.debug('Encryption enabled? : %s', KeyValuePairAPI.is_encryption_enabled)
+ LOG.debug("Encryption enabled? : %s", KeyValuePairAPI.is_encryption_enabled)
if KeyValuePairAPI.is_encryption_enabled:
KeyValuePairAPI.crypto_key_path = cfg.CONF.keyvalue.encryption_key_path
- LOG.info('Encryption enabled. Looking for key in path %s',
- KeyValuePairAPI.crypto_key_path)
+ LOG.info(
+ "Encryption enabled. Looking for key in path %s",
+ KeyValuePairAPI.crypto_key_path,
+ )
if not os.path.exists(KeyValuePairAPI.crypto_key_path):
- msg = ('Encryption key file does not exist in path %s.' %
- KeyValuePairAPI.crypto_key_path)
+ msg = (
+ "Encryption key file does not exist in path %s."
+ % KeyValuePairAPI.crypto_key_path
+ )
LOG.exception(msg)
- LOG.info('All API requests will now send out BAD_REQUEST ' +
- 'if you ask to store secrets in key value store.')
+ LOG.info(
+ "All API requests will now send out BAD_REQUEST "
+ + "if you ask to store secrets in key value store."
+ )
KeyValuePairAPI.crypto_key = None
else:
KeyValuePairAPI.crypto_key = read_crypto_key(
@@ -123,28 +112,30 @@ def from_model(cls, model, mask_secrets=True):
doc = cls._from_model(model, mask_secrets=mask_secrets)
- if getattr(model, 'expire_timestamp', None) and model.expire_timestamp:
- doc['expire_timestamp'] = isotime.format(model.expire_timestamp, offset=False)
+ if getattr(model, "expire_timestamp", None) and model.expire_timestamp:
+ doc["expire_timestamp"] = isotime.format(
+ model.expire_timestamp, offset=False
+ )
encrypted = False
- secret = getattr(model, 'secret', False)
+ secret = getattr(model, "secret", False)
if secret:
encrypted = True
if not mask_secrets and secret:
- doc['value'] = symmetric_decrypt(KeyValuePairAPI.crypto_key, model.value)
+ doc["value"] = symmetric_decrypt(KeyValuePairAPI.crypto_key, model.value)
encrypted = False
- scope = getattr(model, 'scope', SYSTEM_SCOPE)
+ scope = getattr(model, "scope", SYSTEM_SCOPE)
if scope:
- doc['scope'] = scope
+ doc["scope"] = scope
- key = doc.get('name', None)
+ key = doc.get("name", None)
if (scope == USER_SCOPE or scope == FULL_USER_SCOPE) and key:
- doc['user'] = UserKeyReference.get_user(key)
- doc['name'] = UserKeyReference.get_name(key)
+ doc["user"] = UserKeyReference.get_user(key)
+ doc["name"] = UserKeyReference.get_name(key)
- doc['encrypted'] = encrypted
+ doc["encrypted"] = encrypted
attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None}
return cls(**attrs)
@@ -153,21 +144,22 @@ def to_model(cls, kvp):
if not KeyValuePairAPI.crypto_setup:
KeyValuePairAPI._setup_crypto()
- kvp_id = getattr(kvp, 'id', None)
- name = getattr(kvp, 'name', None)
- description = getattr(kvp, 'description', None)
+ kvp_id = getattr(kvp, "id", None)
+ name = getattr(kvp, "name", None)
+ description = getattr(kvp, "description", None)
value = kvp.value
original_value = value
secret = False
- if getattr(kvp, 'ttl', None):
- expire_timestamp = (date_utils.get_datetime_utc_now() +
- datetime.timedelta(seconds=kvp.ttl))
+ if getattr(kvp, "ttl", None):
+ expire_timestamp = date_utils.get_datetime_utc_now() + datetime.timedelta(
+ seconds=kvp.ttl
+ )
else:
expire_timestamp = None
- encrypted = getattr(kvp, 'encrypted', False)
- secret = getattr(kvp, 'secret', False)
+ encrypted = getattr(kvp, "encrypted", False)
+ secret = getattr(kvp, "secret", False)
# If user transmitted the value in a pre-encrypted format, we perform the decryption here
# to ensure data integrity. Besides that, we store data as-is.
@@ -182,9 +174,11 @@ def to_model(cls, kvp):
try:
symmetric_decrypt(KeyValuePairAPI.crypto_key, value)
except Exception:
- msg = ('Failed to verify the integrity of the provided value for key "%s". Ensure '
- 'that the value is encrypted with the correct key and not corrupted.' %
- (name))
+ msg = (
+ 'Failed to verify the integrity of the provided value for key "%s". Ensure '
+ "that the value is encrypted with the correct key and not corrupted."
+ % (name)
+ )
raise ValueError(msg)
# Additional safety check to ensure that the value hasn't been decrypted
@@ -194,30 +188,39 @@ def to_model(cls, kvp):
value = symmetric_encrypt(KeyValuePairAPI.crypto_key, value)
- scope = getattr(kvp, 'scope', FULL_SYSTEM_SCOPE)
+ scope = getattr(kvp, "scope", FULL_SYSTEM_SCOPE)
if scope not in ALLOWED_SCOPES:
- raise InvalidScopeException('Invalid scope "%s"! Allowed scopes are %s.' % (
- scope, ALLOWED_SCOPES)
+ raise InvalidScopeException(
+ 'Invalid scope "%s"! Allowed scopes are %s.' % (scope, ALLOWED_SCOPES)
)
# NOTE: For security reasons, encrypted always implies secret=True. See comment
# above for explanation.
if encrypted and not secret:
- raise ValueError('encrypted option can only be used in combination with secret '
- 'option')
+ raise ValueError(
+ "encrypted option can only be used in combination with secret " "option"
+ )
- model = cls.model(id=kvp_id, name=name, description=description, value=value,
- secret=secret, scope=scope,
- expire_timestamp=expire_timestamp)
+ model = cls.model(
+ id=kvp_id,
+ name=name,
+ description=description,
+ value=value,
+ secret=secret,
+ scope=scope,
+ expire_timestamp=expire_timestamp,
+ )
return model
@classmethod
def _verif_key_is_set_up(cls, name):
if not KeyValuePairAPI.crypto_key:
- msg = ('Crypto key not found in %s. Unable to encrypt / decrypt value for key %s.' %
- (KeyValuePairAPI.crypto_key_path, name))
+ msg = "Crypto key not found in %s. Unable to encrypt / decrypt value for key %s." % (
+ KeyValuePairAPI.crypto_key_path,
+ name,
+ )
raise CryptoKeyNotSetupException(msg)
@@ -227,13 +230,12 @@ class KeyValuePairSetAPI(KeyValuePairAPI):
"""
schema = copy.deepcopy(KeyValuePairAPI.schema)
- schema['properties']['ttl'] = {
- 'description': 'Items TTL',
- 'type': 'integer'
- }
- schema['properties']['user'] = {
- 'description': ('User to which the value should be scoped to. Only applicable to '
- 'scope == user'),
- 'type': 'string',
- 'default': None
+ schema["properties"]["ttl"] = {"description": "Items TTL", "type": "integer"}
+ schema["properties"]["user"] = {
+ "description": (
+ "User to which the value should be scoped to. Only applicable to "
+ "scope == user"
+ ),
+ "type": "string",
+ "default": None,
}
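
A standalone sketch of the ttl handling in KeyValuePairAPI.to_model above; datetime.datetime.utcnow() stands in here for date_utils.get_datetime_utc_now():

    import datetime

    ttl = 300  # seconds until the key-value pair expires
    expire_timestamp = datetime.datetime.utcnow() + datetime.timedelta(
        seconds=ttl
    )
    print(expire_timestamp.isoformat())
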
diff --git a/st2common/st2common/models/api/notification.py b/st2common/st2common/models/api/notification.py
index fef0545f26..9d80ddbf7f 100644
--- a/st2common/st2common/models/api/notification.py
+++ b/st2common/st2common/models/api/notification.py
@@ -19,57 +19,60 @@
NotificationSubSchemaAPI = {
"type": "object",
"properties": {
- "message": {
- "type": "string",
- "description": "Message to use for notification"
- },
+ "message": {"type": "string", "description": "Message to use for notification"},
"data": {
"type": "object",
- "description": "Data to be sent as part of notification"
+ "description": "Data to be sent as part of notification",
},
"routes": {
"type": "array",
- "description": "Channels to post notifications to."
+ "description": "Channels to post notifications to.",
},
"channels": { # Deprecated. Only here for backward compatibility.
"type": "array",
- "description": "Channels to post notifications to."
+ "description": "Channels to post notifications to.",
},
},
- "additionalProperties": False
+ "additionalProperties": False,
}
class NotificationsHelper(object):
-
@staticmethod
def to_model(notify_api_object):
- if notify_api_object.get('on-success', None):
- on_success = NotificationsHelper._to_model_sub_schema(notify_api_object['on-success'])
+ if notify_api_object.get("on-success", None):
+ on_success = NotificationsHelper._to_model_sub_schema(
+ notify_api_object["on-success"]
+ )
else:
on_success = None
- if notify_api_object.get('on-complete', None):
+ if notify_api_object.get("on-complete", None):
on_complete = NotificationsHelper._to_model_sub_schema(
- notify_api_object['on-complete'])
+ notify_api_object["on-complete"]
+ )
else:
on_complete = None
- if notify_api_object.get('on-failure', None):
- on_failure = NotificationsHelper._to_model_sub_schema(notify_api_object['on-failure'])
+ if notify_api_object.get("on-failure", None):
+ on_failure = NotificationsHelper._to_model_sub_schema(
+ notify_api_object["on-failure"]
+ )
else:
on_failure = None
- model = NotificationSchema(on_success=on_success, on_failure=on_failure,
- on_complete=on_complete)
+ model = NotificationSchema(
+ on_success=on_success, on_failure=on_failure, on_complete=on_complete
+ )
return model
@staticmethod
def _to_model_sub_schema(notification_settings_json):
- message = notification_settings_json.get('message', None)
- data = notification_settings_json.get('data', {})
- routes = (notification_settings_json.get('routes', None) or
- notification_settings_json.get('channels', []))
+ message = notification_settings_json.get("message", None)
+ data = notification_settings_json.get("data", {})
+ routes = notification_settings_json.get(
+ "routes", None
+ ) or notification_settings_json.get("channels", [])
model = NotificationSubSchema(message=message, data=data, routes=routes)
return model
@@ -77,15 +80,18 @@ def _to_model_sub_schema(notification_settings_json):
@staticmethod
def from_model(notify_model):
notify = {}
- if getattr(notify_model, 'on_complete', None):
- notify['on-complete'] = NotificationsHelper._from_model_sub_schema(
- notify_model.on_complete)
- if getattr(notify_model, 'on_success', None):
- notify['on-success'] = NotificationsHelper._from_model_sub_schema(
- notify_model.on_success)
- if getattr(notify_model, 'on_failure', None):
- notify['on-failure'] = NotificationsHelper._from_model_sub_schema(
- notify_model.on_failure)
+ if getattr(notify_model, "on_complete", None):
+ notify["on-complete"] = NotificationsHelper._from_model_sub_schema(
+ notify_model.on_complete
+ )
+ if getattr(notify_model, "on_success", None):
+ notify["on-success"] = NotificationsHelper._from_model_sub_schema(
+ notify_model.on_success
+ )
+ if getattr(notify_model, "on_failure", None):
+ notify["on-failure"] = NotificationsHelper._from_model_sub_schema(
+ notify_model.on_failure
+ )
return notify
@@ -93,13 +99,14 @@ def from_model(notify_model):
def _from_model_sub_schema(notify_sub_schema_model):
notify_sub_schema = {}
- if getattr(notify_sub_schema_model, 'message', None):
- notify_sub_schema['message'] = notify_sub_schema_model.message
- if getattr(notify_sub_schema_model, 'data', None):
- notify_sub_schema['data'] = notify_sub_schema_model.data
- routes = (getattr(notify_sub_schema_model, 'routes') or
- getattr(notify_sub_schema_model, 'channels'))
+ if getattr(notify_sub_schema_model, "message", None):
+ notify_sub_schema["message"] = notify_sub_schema_model.message
+ if getattr(notify_sub_schema_model, "data", None):
+ notify_sub_schema["data"] = notify_sub_schema_model.data
+ routes = getattr(notify_sub_schema_model, "routes") or getattr(
+ notify_sub_schema_model, "channels"
+ )
if routes:
- notify_sub_schema['routes'] = routes
+ notify_sub_schema["routes"] = routes
return notify_sub_schema
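
The routes handling above prefers "routes" and only falls back to the deprecated "channels" key when "routes" is absent. A standalone sketch with an illustrative settings dict:

    settings = {"message": "Deploy finished", "channels": ["#ops"]}
    routes = settings.get("routes", None) or settings.get("channels", [])
    print(routes)  # ['#ops']
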
diff --git a/st2common/st2common/models/api/pack.py b/st2common/st2common/models/api/pack.py
index 02c6d00f63..6de2893427 100644
--- a/st2common/st2common/models/api/pack.py
+++ b/st2common/st2common/models/api/pack.py
@@ -37,16 +37,14 @@
from st2common.util.pack import validate_config_against_schema
__all__ = [
- 'PackAPI',
- 'ConfigSchemaAPI',
- 'ConfigAPI',
-
- 'ConfigItemSetAPI',
-
- 'PackInstallRequestAPI',
- 'PackRegisterRequestAPI',
- 'PackSearchRequestAPI',
- 'PackAsyncAPI'
+ "PackAPI",
+ "ConfigSchemaAPI",
+ "ConfigAPI",
+ "ConfigItemSetAPI",
+ "PackInstallRequestAPI",
+ "PackRegisterRequestAPI",
+ "PackSearchRequestAPI",
+ "PackAsyncAPI",
]
LOG = logging.getLogger(__name__)
@@ -55,124 +53,117 @@
class PackAPI(BaseAPI):
model = PackDB
schema = {
- 'type': 'object',
- 'description': 'Content pack schema.',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'description': 'Unique identifier for the pack.',
- 'default': None
+ "type": "object",
+ "description": "Content pack schema.",
+ "properties": {
+ "id": {
+ "type": "string",
+ "description": "Unique identifier for the pack.",
+ "default": None,
},
- 'name': {
- 'type': 'string',
- 'description': 'Display name of the pack. If the name only contains lowercase'
- 'letters, digits and underscores, the "ref" field is not required.',
- 'required': True
+ "name": {
+ "type": "string",
+ "description": "Display name of the pack. If the name only contains lowercase"
+ 'letters, digits and underscores, the "ref" field is not required.',
+ "required": True,
},
- 'ref': {
- 'type': 'string',
- 'description': 'Reference for the pack, used as an internal id.',
- 'default': None,
- 'pattern': PACK_REF_WHITELIST_REGEX
+ "ref": {
+ "type": "string",
+ "description": "Reference for the pack, used as an internal id.",
+ "default": None,
+ "pattern": PACK_REF_WHITELIST_REGEX,
},
- 'uid': {
- 'type': 'string'
+ "uid": {"type": "string"},
+ "description": {
+ "type": "string",
+ "description": "Brief description of the pack and the service it integrates with.",
+ "required": True,
},
- 'description': {
- 'type': 'string',
- 'description': 'Brief description of the pack and the service it integrates with.',
- 'required': True
+ "keywords": {
+ "type": "array",
+ "description": "Keywords describing the pack.",
+ "items": {"type": "string"},
+ "default": [],
},
- 'keywords': {
- 'type': 'array',
- 'description': 'Keywords describing the pack.',
- 'items': {'type': 'string'},
- 'default': []
+ "version": {
+ "type": "string",
+ "description": "Pack version. Must follow the semver format "
+ '(for instance, "0.1.0").',
+ "pattern": PACK_VERSION_REGEX,
+ "required": True,
},
- 'version': {
- 'type': 'string',
- 'description': 'Pack version. Must follow the semver format '
- '(for instance, "0.1.0").',
- 'pattern': PACK_VERSION_REGEX,
- 'required': True
+ "stackstorm_version": {
+ "type": "string",
+ "description": 'Required StackStorm version. Examples: ">1.6.0", '
+ '">=1.8.0, <2.2.0"',
+ "pattern": ST2_VERSION_REGEX,
},
- 'stackstorm_version': {
- 'type': 'string',
- 'description': 'Required StackStorm version. Examples: ">1.6.0", '
- '">=1.8.0, <2.2.0"',
- 'pattern': ST2_VERSION_REGEX,
+ "python_versions": {
+ "type": "array",
+ "description": (
+ "Major Python versions supported by this pack. E.g. "
+ '"2" for Python 2.7.x and "3" for Python 3.6.x'
+ ),
+ "items": {"type": "string", "enum": ["2", "3"]},
+ "minItems": 1,
+ "maxItems": 2,
+ "uniqueItems": True,
+ "additionalItems": True,
},
- 'python_versions': {
- 'type': 'array',
- 'description': ('Major Python versions supported by this pack. E.g. '
- '"2" for Python 2.7.x and "3" for Python 3.6.x'),
- 'items': {
- 'type': 'string',
- 'enum': [
- '2',
- '3'
- ]
- },
- 'minItems': 1,
- 'maxItems': 2,
- 'uniqueItems': True,
- 'additionalItems': True
+ "author": {
+ "type": "string",
+ "description": "Pack author or authors.",
+ "required": True,
},
- 'author': {
- 'type': 'string',
- 'description': 'Pack author or authors.',
- 'required': True
+ "email": {
+ "type": "string",
+ "description": "E-mail of the pack author.",
+ "format": "email",
},
- 'email': {
- 'type': 'string',
- 'description': 'E-mail of the pack author.',
- 'format': 'email'
+ "contributors": {
+ "type": "array",
+ "items": {"type": "string", "maxLength": 100},
+ "description": (
+ "A list of people who have contributed to the pack. Format is: "
+ "Name e.g. Tomaz Muraus ."
+ ),
},
- 'contributors': {
- 'type': 'array',
- 'items': {
- 'type': 'string',
- 'maxLength': 100
- },
- 'description': ('A list of people who have contributed to the pack. Format is: '
- 'Name <email> e.g. Tomaz Muraus <tomaz@stackstorm.com>.')
+ "files": {
+ "type": "array",
+ "description": "A list of files inside the pack.",
+ "items": {"type": "string"},
+ "default": [],
},
- 'files': {
- 'type': 'array',
- 'description': 'A list of files inside the pack.',
- 'items': {'type': 'string'},
- 'default': []
+ "dependencies": {
+ "type": "array",
+ "description": "A list of other StackStorm packs this pack depends upon. "
+ 'The same format as in "st2 pack install" is used: '
+ '"[=]".',
+ "items": {"type": "string"},
+ "default": [],
},
- 'dependencies': {
- 'type': 'array',
- 'description': 'A list of other StackStorm packs this pack depends upon. '
- 'The same format as in "st2 pack install" is used: '
- '"[=]".',
- 'items': {'type': 'string'},
- 'default': []
+ "system": {
+ "type": "object",
+ "description": "Specification for the system components and packages "
+ "required for the pack.",
+ "default": {},
},
- 'system': {
- 'type': 'object',
- 'description': 'Specification for the system components and packages '
- 'required for the pack.',
- 'default': {}
+ "path": {
+ "type": "string",
+ "description": "Location of the pack on disk in st2 system.",
+ "required": False,
},
- 'path': {
- 'type': 'string',
- 'description': 'Location of the pack on disk in st2 system.',
- 'required': False
- }
},
# NOTE: We add this here explicitly so we can gracefully add new attributes to pack.yaml
# without breaking existing installations
- 'additionalProperties': True
+ "additionalProperties": True,
}
def __init__(self, **values):
# Note: If some version values are not explicitly surrounded by quotes they are recognized
# as numbers so we cast them to string
- if values.get('version', None):
- values['version'] = str(values['version'])
+ if values.get("version", None):
+ values["version"] = str(values["version"])
super(PackAPI, self).__init__(**values)
@@ -186,17 +177,21 @@ def validate(self):
# Invalid version
if "Failed validating 'pattern' in schema['properties']['version']" in msg:
- new_msg = ('Pack version "%s" doesn\'t follow a valid semver format. Valid '
- 'versions and formats include: 0.1.0, 0.2.1, 1.1.0, etc.' %
- (self.version))
- new_msg += '\n\n' + msg
+ new_msg = (
+ 'Pack version "%s" doesn\'t follow a valid semver format. Valid '
+ "versions and formats include: 0.1.0, 0.2.1, 1.1.0, etc."
+ % (self.version)
+ )
+ new_msg += "\n\n" + msg
raise jsonschema.ValidationError(new_msg)
# Invalid ref / name
if "Failed validating 'pattern' in schema['properties']['ref']" in msg:
- new_msg = ('Pack ref / name can only contain valid word characters (a-z, 0-9 and '
- '_), dashes are not allowed.')
- new_msg += '\n\n' + msg
+ new_msg = (
+ "Pack ref / name can only contain valid word characters (a-z, 0-9 and "
+ "_), dashes are not allowed."
+ )
+ new_msg += "\n\n" + msg
raise jsonschema.ValidationError(new_msg)
raise e
@@ -206,24 +201,35 @@ def to_model(cls, pack):
ref = pack.ref
name = pack.name
description = pack.description
- keywords = getattr(pack, 'keywords', [])
+ keywords = getattr(pack, "keywords", [])
version = str(pack.version)
- stackstorm_version = getattr(pack, 'stackstorm_version', None)
- python_versions = getattr(pack, 'python_versions', [])
+ stackstorm_version = getattr(pack, "stackstorm_version", None)
+ python_versions = getattr(pack, "python_versions", [])
author = pack.author
email = pack.email
- contributors = getattr(pack, 'contributors', [])
- files = getattr(pack, 'files', [])
- pack_dir = getattr(pack, 'path', None)
- dependencies = getattr(pack, 'dependencies', [])
- system = getattr(pack, 'system', {})
-
- model = cls.model(ref=ref, name=name, description=description, keywords=keywords,
- version=version, author=author, email=email, contributors=contributors,
- files=files, dependencies=dependencies, system=system,
- stackstorm_version=stackstorm_version, path=pack_dir,
- python_versions=python_versions)
+ contributors = getattr(pack, "contributors", [])
+ files = getattr(pack, "files", [])
+ pack_dir = getattr(pack, "path", None)
+ dependencies = getattr(pack, "dependencies", [])
+ system = getattr(pack, "system", {})
+
+ model = cls.model(
+ ref=ref,
+ name=name,
+ description=description,
+ keywords=keywords,
+ version=version,
+ author=author,
+ email=email,
+ contributors=contributors,
+ files=files,
+ dependencies=dependencies,
+ system=system,
+ stackstorm_version=stackstorm_version,
+ path=pack_dir,
+ python_versions=python_versions,
+ )
return model
@@ -236,11 +242,11 @@ class ConfigSchemaAPI(BaseAPI):
"properties": {
"id": {
"description": "The unique identifier for the config schema.",
- "type": "string"
+ "type": "string",
},
"pack": {
"description": "The content pack this config schema belongs to.",
- "type": "string"
+ "type": "string",
},
"attributes": {
"description": "Config schema attributes.",
@@ -248,11 +254,11 @@ class ConfigSchemaAPI(BaseAPI):
"patternProperties": {
r"^\w+$": util_schema.get_action_parameters_schema()
},
- 'additionalProperties': False,
- "default": {}
- }
+ "additionalProperties": False,
+ "default": {},
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
@@ -273,19 +279,19 @@ class ConfigAPI(BaseAPI):
"properties": {
"id": {
"description": "The unique identifier for the config.",
- "type": "string"
+ "type": "string",
},
"pack": {
"description": "The content pack this config belongs to.",
- "type": "string"
+ "type": "string",
},
"values": {
"description": "Config values.",
"type": "object",
- "default": {}
- }
+ "default": {},
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
def validate(self, validate_against_schema=False):
@@ -310,13 +316,15 @@ def _validate_config_values_against_schema(self):
instance = self.values or {}
schema = config_schema_db.attributes or {}
- configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/')
- config_path = os.path.join(configs_path, '%s.yaml' % (self.pack))
+ configs_path = os.path.join(cfg.CONF.system.base_path, "configs/")
+ config_path = os.path.join(configs_path, "%s.yaml" % (self.pack))
- cleaned = validate_config_against_schema(config_schema=schema,
- config_object=instance,
- config_path=config_path,
- pack_name=self.pack)
+ cleaned = validate_config_against_schema(
+ config_schema=schema,
+ config_object=instance,
+ config_path=config_path,
+ pack_name=self.pack,
+ )
return cleaned
@@ -330,15 +338,14 @@ def to_model(cls, config):
class ConfigUpdateRequestAPI(BaseAPI):
- schema = {
- "type": "object"
- }
+ schema = {"type": "object"}
class ConfigItemSetAPI(BaseAPI):
"""
API class used with the config set API endpoint.
"""
+
model = None
schema = {
"title": "",
@@ -348,30 +355,27 @@ class ConfigItemSetAPI(BaseAPI):
"name": {
"description": "Config item name (key)",
"type": "string",
- "required": True
+ "required": True,
},
"value": {
"description": "Config item value.",
"type": ["string", "number", "boolean", "array", "object"],
- "required": True
+ "required": True,
},
"scope": {
"description": "Config item scope (system / user)",
"type": "string",
"default": SYSTEM_SCOPE,
- "enum": [
- SYSTEM_SCOPE,
- USER_SCOPE
- ]
+ "enum": [SYSTEM_SCOPE, USER_SCOPE],
},
"user": {
"description": "User for user-scoped items (only available to admins).",
"type": "string",
"required": False,
- "default": None
- }
+ "default": None,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@@ -379,15 +383,13 @@ class PackInstallRequestAPI(BaseAPI):
schema = {
"type": "object",
"properties": {
- "packs": {
- "type": "array"
- },
+ "packs": {"type": "array"},
"force": {
"type": "boolean",
"description": "Force pack installation",
- "default": False
- }
- }
+ "default": False,
+ },
+ },
}
@@ -395,24 +397,14 @@ class PackRegisterRequestAPI(BaseAPI):
schema = {
"type": "object",
"properties": {
- "types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- "packs": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
+ "types": {"type": "array", "items": {"type": "string"}},
+ "packs": {"type": "array", "items": {"type": "string"}},
"fail_on_failure": {
"type": "boolean",
"description": "True to fail on failure",
- "default": True
- }
- }
+ "default": True,
+ },
+ },
}
@@ -438,18 +430,13 @@ class PackSearchRequestAPI(BaseAPI):
},
"additionalProperties": False,
},
- ]
+ ],
}
class PackAsyncAPI(BaseAPI):
schema = {
"type": "object",
- "properties": {
- "execution_id": {
- "type": "string",
- "required": True
- }
- },
- "additionalProperties": False
+ "properties": {"execution_id": {"type": "string", "required": True}},
+ "additionalProperties": False,
}
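
PackAPI.__init__ above guards against YAML parsing unquoted version values as numbers. A standalone sketch of that normalization:

    values = {"version": 0.1}  # e.g. "version: 0.1" in pack.yaml, unquoted
    if values.get("version", None):
        values["version"] = str(values["version"])
    print(values)  # {'version': '0.1'}
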
diff --git a/st2common/st2common/models/api/policy.py b/st2common/st2common/models/api/policy.py
index a46dad9eda..211560d453 100644
--- a/st2common/st2common/models/api/policy.py
+++ b/st2common/st2common/models/api/policy.py
@@ -22,7 +22,7 @@
from st2common.util import schema as util_schema
-__all__ = ['PolicyTypeAPI']
+__all__ = ["PolicyTypeAPI"]
LOG = logging.getLogger(__name__)
@@ -33,55 +33,34 @@ class PolicyTypeAPI(BaseAPI, APIUIDMixin):
"title": "Policy Type",
"type": "object",
"properties": {
- "id": {
- "type": "string",
- "default": None
- },
- 'uid': {
- 'type': 'string'
- },
- "name": {
- "type": "string",
- "required": True
- },
- "resource_type": {
- "enum": ["action"],
- "required": True
- },
- "ref": {
- "type": "string"
- },
- "description": {
- "type": "string"
- },
- "enabled": {
- "type": "boolean",
- "default": True
- },
- "module": {
- "type": "string",
- "required": True
- },
+ "id": {"type": "string", "default": None},
+ "uid": {"type": "string"},
+ "name": {"type": "string", "required": True},
+ "resource_type": {"enum": ["action"], "required": True},
+ "ref": {"type": "string"},
+ "description": {"type": "string"},
+ "enabled": {"type": "boolean", "default": True},
+ "module": {"type": "string", "required": True},
"parameters": {
"type": "object",
- "patternProperties": {
- r"^\w+$": util_schema.get_draft_schema()
- },
- 'additionalProperties': False
- }
+ "patternProperties": {r"^\w+$": util_schema.get_draft_schema()},
+ "additionalProperties": False,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
@classmethod
def to_model(cls, instance):
- return cls.model(name=str(instance.name),
- description=getattr(instance, 'description', None),
- resource_type=str(instance.resource_type),
- ref=getattr(instance, 'ref', None),
- enabled=getattr(instance, 'enabled', None),
- module=str(instance.module),
- parameters=getattr(instance, 'parameters', dict()))
+ return cls.model(
+ name=str(instance.name),
+ description=getattr(instance, "description", None),
+ resource_type=str(instance.resource_type),
+ ref=getattr(instance, "ref", None),
+ enabled=getattr(instance, "enabled", None),
+ module=str(instance.module),
+ parameters=getattr(instance, "parameters", dict()),
+ )
class PolicyAPI(BaseAPI, APIUIDMixin):
@@ -90,38 +69,15 @@ class PolicyAPI(BaseAPI, APIUIDMixin):
"title": "Policy",
"type": "object",
"properties": {
- "id": {
- "type": "string",
- "default": None
- },
- 'uid': {
- 'type': 'string'
- },
- "name": {
- "type": "string",
- "required": True
- },
- "pack": {
- "type": "string"
- },
- "ref": {
- "type": "string"
- },
- "description": {
- "type": "string"
- },
- "enabled": {
- "type": "boolean",
- "default": True
- },
- "resource_ref": {
- "type": "string",
- "required": True
- },
- "policy_type": {
- "type": "string",
- "required": True
- },
+ "id": {"type": "string", "default": None},
+ "uid": {"type": "string"},
+ "name": {"type": "string", "required": True},
+ "pack": {"type": "string"},
+ "ref": {"type": "string"},
+ "description": {"type": "string"},
+ "enabled": {"type": "boolean", "default": True},
+ "resource_ref": {"type": "string", "required": True},
+ "policy_type": {"type": "string", "required": True},
"parameters": {
"type": "object",
"patternProperties": {
@@ -132,20 +88,19 @@ class PolicyAPI(BaseAPI, APIUIDMixin):
{"type": "integer"},
{"type": "number"},
{"type": "object"},
- {"type": "string"}
+ {"type": "string"},
]
}
},
- 'additionalProperties': False
-
+ "additionalProperties": False,
},
"metadata_file": {
"description": "Path to the metadata file relative to the pack directory.",
"type": "string",
- "default": ""
- }
+ "default": "",
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
def validate(self):
@@ -156,15 +111,19 @@ def validate(self):
# pylint: disable=no-member
policy_type_db = PolicyType.get_by_ref(cleaned.policy_type)
if not policy_type_db:
- raise ValueError('Referenced policy_type "%s" doesnt exist' % (cleaned.policy_type))
+ raise ValueError(
+ 'Referenced policy_type "%s" doesn\'t exist' % (cleaned.policy_type)
+ )
parameters_schema = policy_type_db.parameters
- parameters = getattr(cleaned, 'parameters', {})
+ parameters = getattr(cleaned, "parameters", {})
schema = util_schema.get_schema_for_resource_parameters(
- parameters_schema=parameters_schema)
+ parameters_schema=parameters_schema
+ )
validator = util_schema.get_validator()
- cleaned_parameters = util_schema.validate(parameters, schema, validator, use_default=True,
- allow_default_none=True)
+ cleaned_parameters = util_schema.validate(
+ parameters, schema, validator, use_default=True, allow_default_none=True
+ )
cleaned.parameters = cleaned_parameters
@@ -172,13 +131,15 @@ def validate(self):
@classmethod
def to_model(cls, instance):
- return cls.model(id=getattr(instance, 'id', None),
- name=str(instance.name),
- description=getattr(instance, 'description', None),
- pack=str(instance.pack),
- ref=getattr(instance, 'ref', None),
- enabled=getattr(instance, 'enabled', None),
- resource_ref=str(instance.resource_ref),
- policy_type=str(instance.policy_type),
- parameters=getattr(instance, 'parameters', dict()),
- metadata_file=getattr(instance, 'metadata_file', None))
+ return cls.model(
+ id=getattr(instance, "id", None),
+ name=str(instance.name),
+ description=getattr(instance, "description", None),
+ pack=str(instance.pack),
+ ref=getattr(instance, "ref", None),
+ enabled=getattr(instance, "enabled", None),
+ resource_ref=str(instance.resource_ref),
+ policy_type=str(instance.policy_type),
+ parameters=getattr(instance, "parameters", dict()),
+ metadata_file=getattr(instance, "metadata_file", None),
+ )
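
A minimal sketch of how a patternProperties schema like the ones above behaves, validated here with the plain jsonschema library rather than the st2common.util.schema helpers:

    import jsonschema

    schema = {
        "type": "object",
        "patternProperties": {r"^\w+$": {"type": "integer"}},
        "additionalProperties": False,
    }

    jsonschema.validate({"threshold": 10}, schema)  # passes
    # jsonschema.validate({"bad key!": 10}, schema) would raise a
    # ValidationError: "bad key!" matches no pattern, and
    # additionalProperties is False.
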
diff --git a/st2common/st2common/models/api/rbac.py b/st2common/st2common/models/api/rbac.py
index 556793b7a6..ffaff75409 100644
--- a/st2common/st2common/models/api/rbac.py
+++ b/st2common/st2common/models/api/rbac.py
@@ -25,67 +25,55 @@
from st2common.util.uid import parse_uid
__all__ = [
- 'RoleAPI',
- 'UserRoleAssignmentAPI',
-
- 'RoleDefinitionFileFormatAPI',
- 'UserRoleAssignmentFileFormatAPI',
-
- 'AuthGroupToRoleMapAssignmentFileFormatAPI'
+ "RoleAPI",
+ "UserRoleAssignmentAPI",
+ "RoleDefinitionFileFormatAPI",
+ "UserRoleAssignmentFileFormatAPI",
+ "AuthGroupToRoleMapAssignmentFileFormatAPI",
]
class RoleAPI(BaseAPI):
model = RoleDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'default': None
- },
- 'name': {
- 'type': 'string',
- 'required': True
- },
- 'description': {
- 'type': 'string'
- },
- 'permission_grant_ids': {
- 'type': 'array',
- 'items': {
- 'type': 'string'
- }
- },
- 'permission_grant_objects': {
- 'type': 'array',
- 'items': {
- 'type': 'object'
- }
- }
+ "type": "object",
+ "properties": {
+ "id": {"type": "string", "default": None},
+ "name": {"type": "string", "required": True},
+ "description": {"type": "string"},
+ "permission_grant_ids": {"type": "array", "items": {"type": "string"}},
+ "permission_grant_objects": {"type": "array", "items": {"type": "object"}},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
- def from_model(cls, model, mask_secrets=False, retrieve_permission_grant_objects=True):
+ def from_model(
+ cls, model, mask_secrets=False, retrieve_permission_grant_objects=True
+ ):
role = cls._from_model(model, mask_secrets=mask_secrets)
# Convert ObjectIDs to strings
- role['permission_grant_ids'] = [str(permission_grant) for permission_grant in
- model.permission_grants]
+ role["permission_grant_ids"] = [
+ str(permission_grant) for permission_grant in model.permission_grants
+ ]
# Retrieve and include corresponding permission grant objects
if retrieve_permission_grant_objects:
from st2common.persistence.rbac import PermissionGrant
- permission_grant_dbs = PermissionGrant.query(id__in=role['permission_grants'])
+
+ permission_grant_dbs = PermissionGrant.query(
+ id__in=role["permission_grants"]
+ )
permission_grant_apis = []
for permission_grant_db in permission_grant_dbs:
- permission_grant_api = PermissionGrantAPI.from_model(permission_grant_db)
+ permission_grant_api = PermissionGrantAPI.from_model(
+ permission_grant_db
+ )
permission_grant_apis.append(permission_grant_api)
- role['permission_grant_objects'] = permission_grant_apis
+ role["permission_grant_objects"] = permission_grant_apis
return cls(**role)
@@ -93,56 +81,30 @@ def from_model(cls, model, mask_secrets=False, retrieve_permission_grant_objects
class UserRoleAssignmentAPI(BaseAPI):
model = UserRoleAssignmentDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'default': None
- },
- 'user': {
- 'type': 'string',
- 'required': True
- },
- 'role': {
- 'type': 'string',
- 'required': True
- },
- 'description': {
- 'type': 'string'
- },
- 'is_remote': {
- 'type': 'boolean'
- },
- 'source': {
- 'type': 'string'
- }
+ "type": "object",
+ "properties": {
+ "id": {"type": "string", "default": None},
+ "user": {"type": "string", "required": True},
+ "role": {"type": "string", "required": True},
+ "description": {"type": "string"},
+ "is_remote": {"type": "boolean"},
+ "source": {"type": "string"},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
class PermissionGrantAPI(BaseAPI):
model = PermissionGrantDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'default': None
- },
- 'resource_uid': {
- 'type': 'string',
- 'required': True
- },
- 'resource_type': {
- 'type': 'string',
- 'required': True
- },
- 'permission_types': {
- 'type': 'array'
- }
+ "type": "object",
+ "properties": {
+ "id": {"type": "string", "default": None},
+ "resource_uid": {"type": "string", "required": True},
+ "resource_type": {"type": "string", "required": True},
+ "permission_types": {"type": "array"},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@@ -152,53 +114,55 @@ class RoleDefinitionFileFormatAPI(BaseAPI):
"""
schema = {
- 'type': 'object',
- 'properties': {
- 'name': {
- 'type': 'string',
- 'description': 'Role name',
- 'required': True,
- 'default': None
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string",
+ "description": "Role name",
+ "required": True,
+ "default": None,
},
- 'description': {
- 'type': 'string',
- 'description': 'Role description',
- 'required': False
+ "description": {
+ "type": "string",
+ "description": "Role description",
+ "required": False,
},
- 'enabled': {
- 'type': 'boolean',
- 'description': ('Flag indicating if this role is enabled. Note: Disabled roles '
- 'are simply ignored when loading definitions from disk.'),
- 'default': True
+ "enabled": {
+ "type": "boolean",
+ "description": (
+ "Flag indicating if this role is enabled. Note: Disabled roles "
+ "are simply ignored when loading definitions from disk."
+ ),
+ "default": True,
},
- 'permission_grants': {
- 'type': 'array',
- 'items': {
- 'type': 'object',
- 'properties': {
- 'resource_uid': {
- 'type': 'string',
- 'description': 'UID of a resource to which this grant applies to.',
- 'required': False,
- 'default': None
+ "permission_grants": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "resource_uid": {
+ "type": "string",
+ "description": "UID of a resource to which this grant applies to.",
+ "required": False,
+ "default": None,
},
- 'permission_types': {
- 'type': 'array',
- 'description': 'A list of permission types to grant',
- 'uniqueItems': True,
- 'items': {
- 'type': 'string',
+ "permission_types": {
+ "type": "array",
+ "description": "A list of permission types to grant",
+ "uniqueItems": True,
+ "items": {
+ "type": "string",
# Note: We perform additional validation based on the
# resource type in another place
- 'enum': PermissionType.get_valid_values()
+ "enum": PermissionType.get_valid_values(),
},
- 'default': []
- }
- }
- }
- }
+ "default": [],
+ },
+ },
+ },
+ },
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
def validate(self):
@@ -208,31 +172,43 @@ def validate(self):
# Custom validation
# Validate that only the correct permission types are used
- permission_grants = getattr(self, 'permission_grants', [])
+ permission_grants = getattr(self, "permission_grants", [])
for permission_grant in permission_grants:
- resource_uid = permission_grant.get('resource_uid', None)
- permission_types = permission_grant.get('permission_types', [])
+ resource_uid = permission_grant.get("resource_uid", None)
+ permission_types = permission_grant.get("permission_types", [])
if resource_uid:
# Permission types which apply to a resource
resource_type, _ = parse_uid(uid=resource_uid)
- valid_permission_types = PermissionType.get_valid_permissions_for_resource_type(
- resource_type=resource_type)
+ valid_permission_types = (
+ PermissionType.get_valid_permissions_for_resource_type(
+ resource_type=resource_type
+ )
+ )
for permission_type in permission_types:
if permission_type not in valid_permission_types:
- message = ('Invalid permission type "%s" for resource type "%s"' %
- (permission_type, resource_type))
+ message = (
+ 'Invalid permission type "%s" for resource type "%s"'
+ % (
+ permission_type,
+ resource_type,
+ )
+ )
raise ValueError(message)
else:
# Right now we only support single permission type (list) which is global and
# doesn't apply to a resource
for permission_type in permission_types:
if permission_type not in GLOBAL_PERMISSION_TYPES:
- valid_global_permission_types = ', '.join(GLOBAL_PERMISSION_TYPES)
- message = ('Invalid permission type "%s". Valid global permission types '
- 'which can be used without a resource id are: %s' %
- (permission_type, valid_global_permission_types))
+ valid_global_permission_types = ", ".join(
+ GLOBAL_PERMISSION_TYPES
+ )
+ message = (
+ 'Invalid permission type "%s". Valid global permission types '
+ "which can be used without a resource id are: %s"
+ % (permission_type, valid_global_permission_types)
+ )
raise ValueError(message)
return cleaned
@@ -252,52 +228,52 @@ def validate(self, validate_role_exists=False):
if validate_role_exists:
# Validate that the referenced roles exist in the db
rbac_service = get_rbac_backend().get_service_class()
- rbac_service.validate_roles_exists(role_names=self.roles) # pylint: disable=no-member
-
+ # pylint: disable=no-member
+ rbac_service.validate_roles_exists(role_names=self.roles)
+ # pylint: enable=no-member
return cleaned
class UserRoleAssignmentFileFormatAPI(BaseAPI):
schema = {
- 'type': 'object',
- 'properties': {
- 'username': {
- 'type': 'string',
- 'description': 'Username',
- 'required': True,
- 'default': None
+ "type": "object",
+ "properties": {
+ "username": {
+ "type": "string",
+ "description": "Username",
+ "required": True,
+ "default": None,
},
- 'description': {
- 'type': 'string',
- 'description': 'Assignment description',
- 'required': False,
- 'default': None
+ "description": {
+ "type": "string",
+ "description": "Assignment description",
+ "required": False,
+ "default": None,
},
- 'enabled': {
- 'type': 'boolean',
- 'description': ('Flag indicating if this assignment is enabled. Note: Disabled '
- 'assignments are simply ignored when loading definitions from '
- ' disk.'),
- 'default': True
+ "enabled": {
+ "type": "boolean",
+ "description": (
+ "Flag indicating if this assignment is enabled. Note: Disabled "
+ "assignments are simply ignored when loading definitions from "
+ " disk."
+ ),
+ "default": True,
},
- 'roles': {
- 'type': 'array',
- 'description': 'Roles assigned to this user',
- 'uniqueItems': True,
- 'items': {
- 'type': 'string'
- },
- 'required': True
+ "roles": {
+ "type": "array",
+ "description": "Roles assigned to this user",
+ "uniqueItems": True,
+ "items": {"type": "string"},
+ "required": True,
+ },
+ "file_path": {
+ "type": "string",
+ "description": "Path of the file of where this assignment comes from.",
+ "default": None,
+ "required": False,
},
- 'file_path': {
- 'type': 'string',
- 'description': 'Path of the file of where this assignment comes from.',
- 'default': None,
- 'required': False
- }
-
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
def validate(self, validate_role_exists=False):
@@ -307,44 +283,46 @@ def validate(self, validate_role_exists=False):
class AuthGroupToRoleMapAssignmentFileFormatAPI(BaseAPI):
schema = {
- 'type': 'object',
- 'properties': {
- 'group': {
- 'type': 'string',
- 'description': 'Name of the group as returned by auth backend.',
- 'required': True
+ "type": "object",
+ "properties": {
+ "group": {
+ "type": "string",
+ "description": "Name of the group as returned by auth backend.",
+ "required": True,
},
- 'description': {
- 'type': 'string',
- 'description': 'Mapping description',
- 'required': False,
- 'default': None
+ "description": {
+ "type": "string",
+ "description": "Mapping description",
+ "required": False,
+ "default": None,
},
- 'enabled': {
- 'type': 'boolean',
- 'description': ('Flag indicating if this mapping is enabled. Note: Disabled '
- 'assignments are simply ignored when loading definitions from '
- ' disk.'),
- 'default': True
+ "enabled": {
+ "type": "boolean",
+ "description": (
+ "Flag indicating if this mapping is enabled. Note: Disabled "
+ "assignments are simply ignored when loading definitions from "
+ " disk."
+ ),
+ "default": True,
},
- 'roles': {
- 'type': 'array',
- 'description': ('StackStorm roles which are assigned to each user which belongs '
- 'to that group.'),
- 'uniqueItems': True,
- 'items': {
- 'type': 'string'
- },
- 'required': True
+ "roles": {
+ "type": "array",
+ "description": (
+ "StackStorm roles which are assigned to each user which belongs "
+ "to that group."
+ ),
+ "uniqueItems": True,
+ "items": {"type": "string"},
+ "required": True,
+ },
+ "file_path": {
+ "type": "string",
+ "description": "Path of the file of where this assignment comes from.",
+ "default": None,
+ "required": False,
},
- 'file_path': {
- 'type': 'string',
- 'description': 'Path of the file of where this assignment comes from.',
- 'default': None,
- 'required': False
- }
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
def validate(self, validate_role_exists=False):
diff --git a/st2common/st2common/models/api/rule.py b/st2common/st2common/models/api/rule.py
index 716eeec207..8919a2ffc9 100644
--- a/st2common/st2common/models/api/rule.py
+++ b/st2common/st2common/models/api/rule.py
@@ -20,7 +20,12 @@
from st2common.models.api.base import BaseAPI
from st2common.models.api.base import APIUIDMixin
from st2common.models.api.tag import TagsHelper
-from st2common.models.db.rule import RuleDB, RuleTypeDB, RuleTypeSpecDB, ActionExecutionSpecDB
+from st2common.models.db.rule import (
+ RuleDB,
+ RuleTypeDB,
+ RuleTypeSpecDB,
+ ActionExecutionSpecDB,
+)
from st2common.models.system.common import ResourceReference
from st2common.persistence.trigger import Trigger
import st2common.services.triggers as TriggerService
@@ -30,61 +35,52 @@
class RuleTypeSpec(BaseAPI):
schema = {
- 'type': 'object',
- 'properties': {
- 'ref': {
- 'type': 'string',
- 'required': True
- },
- 'parameters': {
- 'type': 'object'
- }
+ "type": "object",
+ "properties": {
+ "ref": {"type": "string", "required": True},
+ "parameters": {"type": "object"},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
class RuleTypeAPI(BaseAPI):
model = RuleTypeDB
schema = {
- 'title': 'RuleType',
- 'description': 'A specific type of rule.',
- 'type': 'object',
- 'properties': {
- 'id': {
- 'description': 'The unique identifier for the rule type.',
- 'type': 'string',
- 'default': None
- },
- 'name': {
- 'description': 'The name for the rule type.',
- 'type': 'string',
- 'required': True
+ "title": "RuleType",
+ "description": "A specific type of rule.",
+ "type": "object",
+ "properties": {
+ "id": {
+ "description": "The unique identifier for the rule type.",
+ "type": "string",
+ "default": None,
},
- 'description': {
- 'description': 'The description of the rule type.',
- 'type': 'string'
+ "name": {
+ "description": "The name for the rule type.",
+ "type": "string",
+ "required": True,
},
- 'enabled': {
- 'type': 'boolean',
- 'default': True
+ "description": {
+ "description": "The description of the rule type.",
+ "type": "string",
},
- 'parameters': {
- 'type': 'object'
- }
+ "enabled": {"type": "boolean", "default": True},
+ "parameters": {"type": "object"},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
def to_model(cls, rule_type):
- name = getattr(rule_type, 'name', None)
- description = getattr(rule_type, 'description', None)
- enabled = getattr(rule_type, 'enabled', False)
- parameters = getattr(rule_type, 'parameters', {})
+ name = getattr(rule_type, "name", None)
+ description = getattr(rule_type, "description", None)
+ enabled = getattr(rule_type, "enabled", False)
+ parameters = getattr(rule_type, "parameters", {})
- return cls.model(name=name, description=description, enabled=enabled,
- parameters=parameters)
+ return cls.model(
+ name=name, description=description, enabled=enabled, parameters=parameters
+ )
class RuleAPI(BaseAPI, APIUIDMixin):
@@ -113,100 +109,60 @@ class RuleAPI(BaseAPI, APIUIDMixin):
status: enabled or disabled. If disabled, occurrence of the trigger
does not lead to execution of an action and vice-versa.
"""
+
model = RuleDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'default': None
- },
+ "type": "object",
+ "properties": {
+ "id": {"type": "string", "default": None},
"ref": {
"description": (
"System computed user friendly reference for the rule. "
"Provided value will be overridden by computed value."
),
- "type": "string"
- },
- 'uid': {
- 'type': 'string'
- },
- 'name': {
- 'type': 'string',
- 'required': True
- },
- 'pack': {
- 'type': 'string',
- 'default': DEFAULT_PACK_NAME
- },
- 'description': {
- 'type': 'string'
+ "type": "string",
},
- 'type': RuleTypeSpec.schema,
- 'trigger': {
- 'type': 'object',
- 'required': True,
- 'properties': {
- 'type': {
- 'type': 'string',
- 'required': True
- },
- 'description': {
- 'type': 'string',
- 'require': False
- },
- 'parameters': {
- 'type': 'object',
- 'default': {}
- },
- 'ref': {
- 'type': 'string',
- 'required': False
- }
+ "uid": {"type": "string"},
+ "name": {"type": "string", "required": True},
+ "pack": {"type": "string", "default": DEFAULT_PACK_NAME},
+ "description": {"type": "string"},
+ "type": RuleTypeSpec.schema,
+ "trigger": {
+ "type": "object",
+ "required": True,
+ "properties": {
+ "type": {"type": "string", "required": True},
+ "description": {"type": "string", "require": False},
+ "parameters": {"type": "object", "default": {}},
+ "ref": {"type": "string", "required": False},
},
- 'additionalProperties': True
- },
- 'criteria': {
- 'type': 'object',
- 'default': {}
- },
- 'action': {
- 'type': 'object',
- 'required': True,
- 'properties': {
- 'ref': {
- 'type': 'string',
- 'required': True
- },
- 'description': {
- 'type': 'string',
- 'require': False
- },
- 'parameters': {
- 'type': 'object'
- }
+ "additionalProperties": True,
+ },
+ "criteria": {"type": "object", "default": {}},
+ "action": {
+ "type": "object",
+ "required": True,
+ "properties": {
+ "ref": {"type": "string", "required": True},
+ "description": {"type": "string", "require": False},
+ "parameters": {"type": "object"},
},
- 'additionalProperties': False
- },
- 'enabled': {
- 'type': 'boolean',
- 'default': False
- },
- 'context': {
- 'type': 'object'
+ "additionalProperties": False,
},
+ "enabled": {"type": "boolean", "default": False},
+ "context": {"type": "object"},
"tags": {
"description": "User associated metadata assigned to this object.",
"type": "array",
- "items": {"type": "object"}
+ "items": {"type": "object"},
},
"metadata_file": {
"description": "Path to the metadata file relative to the pack directory.",
"type": "string",
- "default": ""
- }
+ "default": "",
+ },
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
@@ -215,58 +171,62 @@ def from_model(cls, model, mask_secrets=False, ignore_missing_trigger=False):
trigger_db = reference.get_model_by_resource_ref(Trigger, model.trigger)
if not ignore_missing_trigger and not trigger_db:
- raise ValueError('Missing TriggerDB object for rule %s' % (rule['id']))
+ raise ValueError("Missing TriggerDB object for rule %s" % (rule["id"]))
if trigger_db:
- rule['trigger'] = {
- 'type': trigger_db.type,
- 'parameters': trigger_db.parameters,
- 'ref': model.trigger
+ rule["trigger"] = {
+ "type": trigger_db.type,
+ "parameters": trigger_db.parameters,
+ "ref": model.trigger,
}
- rule['tags'] = TagsHelper.from_model(model.tags)
+ rule["tags"] = TagsHelper.from_model(model.tags)
return cls(**rule)
@classmethod
def to_model(cls, rule):
kwargs = {}
- kwargs['name'] = getattr(rule, 'name', None)
- kwargs['description'] = getattr(rule, 'description', None)
+ kwargs["name"] = getattr(rule, "name", None)
+ kwargs["description"] = getattr(rule, "description", None)
# Validate trigger parameters
# Note: This must happen before we create the trigger, otherwise trigger
# creation could fail with a cryptic error
- trigger = getattr(rule, 'trigger', {})
- trigger_type_ref = trigger.get('type', None)
- parameters = trigger.get('parameters', {})
+ trigger = getattr(rule, "trigger", {})
+ trigger_type_ref = trigger.get("type", None)
+ parameters = trigger.get("parameters", {})
- validator.validate_trigger_parameters(trigger_type_ref=trigger_type_ref,
- parameters=parameters)
+ validator.validate_trigger_parameters(
+ trigger_type_ref=trigger_type_ref, parameters=parameters
+ )
# Create a trigger for the provided rule
trigger_db = TriggerService.create_trigger_db_from_rule(rule)
- kwargs['trigger'] = reference.get_str_resource_ref_from_model(trigger_db)
+ kwargs["trigger"] = reference.get_str_resource_ref_from_model(trigger_db)
- kwargs['pack'] = getattr(rule, 'pack', DEFAULT_PACK_NAME)
- kwargs['ref'] = ResourceReference.to_string_reference(pack=kwargs['pack'],
- name=kwargs['name'])
+ kwargs["pack"] = getattr(rule, "pack", DEFAULT_PACK_NAME)
+ kwargs["ref"] = ResourceReference.to_string_reference(
+ pack=kwargs["pack"], name=kwargs["name"]
+ )
# Validate criteria
- kwargs['criteria'] = dict(getattr(rule, 'criteria', {}))
- validator.validate_criteria(kwargs['criteria'])
+ kwargs["criteria"] = dict(getattr(rule, "criteria", {}))
+ validator.validate_criteria(kwargs["criteria"])
- kwargs['action'] = ActionExecutionSpecDB(ref=rule.action['ref'],
- parameters=rule.action.get('parameters', {}))
+ kwargs["action"] = ActionExecutionSpecDB(
+ ref=rule.action["ref"], parameters=rule.action.get("parameters", {})
+ )
- rule_type = dict(getattr(rule, 'type', {}))
+ rule_type = dict(getattr(rule, "type", {}))
if rule_type:
- kwargs['type'] = RuleTypeSpecDB(ref=rule_type['ref'],
- parameters=rule_type.get('parameters', {}))
+ kwargs["type"] = RuleTypeSpecDB(
+ ref=rule_type["ref"], parameters=rule_type.get("parameters", {})
+ )
- kwargs['enabled'] = getattr(rule, 'enabled', False)
- kwargs['context'] = getattr(rule, 'context', dict())
- kwargs['tags'] = TagsHelper.to_model(getattr(rule, 'tags', []))
- kwargs['metadata_file'] = getattr(rule, 'metadata_file', None)
+ kwargs["enabled"] = getattr(rule, "enabled", False)
+ kwargs["context"] = getattr(rule, "context", dict())
+ kwargs["tags"] = TagsHelper.to_model(getattr(rule, "tags", []))
+ kwargs["metadata_file"] = getattr(rule, "metadata_file", None)
model = cls.model(**kwargs)
return model
@@ -277,13 +237,5 @@ class RuleViewAPI(RuleAPI):
# Always deep-copy to avoid breaking the original.
schema = copy.deepcopy(RuleAPI.schema)
# Update the schema to include the description properties
- schema['properties']['action'].update({
- 'description': {
- 'type': 'string'
- }
- })
- schema['properties']['trigger'].update({
- 'description': {
- 'type': 'string'
- }
- })
+ schema["properties"]["action"].update({"description": {"type": "string"}})
+ schema["properties"]["trigger"].update({"description": {"type": "string"}})
diff --git a/st2common/st2common/models/api/rule_enforcement.py b/st2common/st2common/models/api/rule_enforcement.py
index c950b59bfe..d7aa1bc873 100644
--- a/st2common/st2common/models/api/rule_enforcement.py
+++ b/st2common/st2common/models/api/rule_enforcement.py
@@ -28,95 +28,98 @@
from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUSES
from st2common.util import isotime
-__all__ = [
- 'RuleEnforcementAPI',
- 'RuleEnforcementViewAPI',
-
- 'RuleReferenceSpecDB'
-]
+__all__ = ["RuleEnforcementAPI", "RuleEnforcementViewAPI", "RuleReferenceSpecDB"]
class RuleReferenceSpec(BaseAPI):
schema = {
- 'type': 'object',
- 'properties': {
- 'ref': {
- 'type': 'string',
- 'required': True,
+ "type": "object",
+ "properties": {
+ "ref": {
+ "type": "string",
+ "required": True,
},
- 'uid': {
- 'type': 'string',
- 'required': True,
+ "uid": {
+ "type": "string",
+ "required": True,
},
- 'id': {
- 'type': 'string',
- 'required': False,
+ "id": {
+ "type": "string",
+ "required": False,
},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
class RuleEnforcementAPI(BaseAPI):
model = RuleEnforcementDB
schema = {
- 'title': 'RuleEnforcement',
- 'description': 'A specific instance of rule enforcement.',
- 'type': 'object',
- 'properties': {
- 'trigger_instance_id': {
- 'description': 'The unique identifier for the trigger instance ' +
- 'that flipped the rule.',
- 'type': 'string',
- 'required': True
+ "title": "RuleEnforcement",
+ "description": "A specific instance of rule enforcement.",
+ "type": "object",
+ "properties": {
+ "trigger_instance_id": {
+ "description": "The unique identifier for the trigger instance "
+ + "that flipped the rule.",
+ "type": "string",
+ "required": True,
},
- 'execution_id': {
- 'description': 'ID of the action execution that was invoked as a response.',
- 'type': 'string'
+ "execution_id": {
+ "description": "ID of the action execution that was invoked as a response.",
+ "type": "string",
},
- 'failure_reason': {
- 'description': 'Reason for failure to execute the action specified in the rule.',
- 'type': 'string'
+ "failure_reason": {
+ "description": "Reason for failure to execute the action specified in the rule.",
+ "type": "string",
},
- 'rule': RuleReferenceSpec.schema,
- 'enforced_at': {
- 'description': 'Timestamp when rule enforcement happened.',
- 'type': 'string',
- 'required': True
+ "rule": RuleReferenceSpec.schema,
+ "enforced_at": {
+ "description": "Timestamp when rule enforcement happened.",
+ "type": "string",
+ "required": True,
},
"status": {
"description": "Rule enforcement status.",
"type": "string",
- "enum": RULE_ENFORCEMENT_STATUSES
+ "enum": RULE_ENFORCEMENT_STATUSES,
},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
def to_model(cls, rule_enforcement):
- trigger_instance_id = getattr(rule_enforcement, 'trigger_instance_id', None)
- execution_id = getattr(rule_enforcement, 'execution_id', None)
- enforced_at = getattr(rule_enforcement, 'enforced_at', None)
- failure_reason = getattr(rule_enforcement, 'failure_reason', None)
- status = getattr(rule_enforcement, 'status', RULE_ENFORCEMENT_STATUS_SUCCEEDED)
-
- rule_ref_model = dict(getattr(rule_enforcement, 'rule', {}))
- rule = RuleReferenceSpecDB(ref=rule_ref_model['ref'], id=rule_ref_model['id'],
- uid=rule_ref_model['uid'])
+ trigger_instance_id = getattr(rule_enforcement, "trigger_instance_id", None)
+ execution_id = getattr(rule_enforcement, "execution_id", None)
+ enforced_at = getattr(rule_enforcement, "enforced_at", None)
+ failure_reason = getattr(rule_enforcement, "failure_reason", None)
+ status = getattr(rule_enforcement, "status", RULE_ENFORCEMENT_STATUS_SUCCEEDED)
+
+ rule_ref_model = dict(getattr(rule_enforcement, "rule", {}))
+ rule = RuleReferenceSpecDB(
+ ref=rule_ref_model["ref"],
+ id=rule_ref_model["id"],
+ uid=rule_ref_model["uid"],
+ )
if enforced_at:
enforced_at = isotime.parse(enforced_at)
- return cls.model(trigger_instance_id=trigger_instance_id, execution_id=execution_id,
- failure_reason=failure_reason, enforced_at=enforced_at, rule=rule,
- status=status)
+ return cls.model(
+ trigger_instance_id=trigger_instance_id,
+ execution_id=execution_id,
+ failure_reason=failure_reason,
+ enforced_at=enforced_at,
+ rule=rule,
+ status=status,
+ )
@classmethod
def from_model(cls, model, mask_secrets=False):
doc = cls._from_model(model, mask_secrets=mask_secrets)
enforced_at = isotime.format(model.enforced_at, offset=False)
- doc['enforced_at'] = enforced_at
+ doc["enforced_at"] = enforced_at
attrs = {attr: value for attr, value in six.iteritems(doc) if value}
return cls(**attrs)
@@ -126,7 +129,7 @@ class RuleEnforcementViewAPI(RuleEnforcementAPI):
schema = copy.deepcopy(RuleEnforcementAPI.schema)
# Update the schema to include additional execution properties
- schema['properties']['execution'] = copy.deepcopy(ActionExecutionAPI.schema)
+ schema["properties"]["execution"] = copy.deepcopy(ActionExecutionAPI.schema)
# Update the schema to include additional trigger instance properties
- schema['properties']['trigger_instance'] = copy.deepcopy(TriggerInstanceAPI.schema)
+ schema["properties"]["trigger_instance"] = copy.deepcopy(TriggerInstanceAPI.schema)
diff --git a/st2common/st2common/models/api/sensor.py b/st2common/st2common/models/api/sensor.py
index af9c687611..a2ba978adf 100644
--- a/st2common/st2common/models/api/sensor.py
+++ b/st2common/st2common/models/api/sensor.py
@@ -22,53 +22,34 @@
class SensorTypeAPI(BaseAPI):
model = SensorTypeDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'default': None
- },
- 'ref': {
- 'type': 'string'
- },
- 'uid': {
- 'type': 'string'
- },
- 'name': {
- 'type': 'string',
- 'required': True
- },
- 'pack': {
- 'type': 'string'
- },
- 'description': {
- 'type': 'string'
- },
- 'artifact_uri': {
- 'type': 'string',
- },
- 'entry_point': {
- 'type': 'string',
- },
- 'enabled': {
- 'description': 'Enable or disable the sensor.',
- 'type': 'boolean',
- 'default': True
+ "type": "object",
+ "properties": {
+ "id": {"type": "string", "default": None},
+ "ref": {"type": "string"},
+ "uid": {"type": "string"},
+ "name": {"type": "string", "required": True},
+ "pack": {"type": "string"},
+ "description": {"type": "string"},
+ "artifact_uri": {
+ "type": "string",
},
- 'trigger_types': {
- 'type': 'array',
- 'default': []
+ "entry_point": {
+ "type": "string",
},
- 'poll_interval': {
- 'type': 'number'
+ "enabled": {
+ "description": "Enable or disable the sensor.",
+ "type": "boolean",
+ "default": True,
},
+ "trigger_types": {"type": "array", "default": []},
+ "poll_interval": {"type": "number"},
"metadata_file": {
"description": "Path to the metadata file relative to the pack directory.",
"type": "string",
- "default": ""
- }
+ "default": "",
+ },
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
diff --git a/st2common/st2common/models/api/tag.py b/st2common/st2common/models/api/tag.py
index 78d92568e0..0ed763a51f 100644
--- a/st2common/st2common/models/api/tag.py
+++ b/st2common/st2common/models/api/tag.py
@@ -16,19 +16,19 @@
from __future__ import absolute_import
from st2common.models.db.stormbase import TagField
-__all__ = [
- 'TagsHelper'
-]
+__all__ = ["TagsHelper"]
class TagsHelper(object):
-
@staticmethod
def to_model(tags):
tags = tags or []
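+        # Illustrative mapping (names and values are hypothetical): an API tag list
+        # such as [{"name": "env", "value": "prod"}] becomes
+        # [TagField(name="env", value="prod")].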
- return [TagField(name=tag.get('name', ''), value=tag.get('value', '')) for tag in tags]
+ return [
+ TagField(name=tag.get("name", ""), value=tag.get("value", ""))
+ for tag in tags
+ ]
@staticmethod
def from_model(tags):
tags = tags or []
- return [{'name': tag.name, 'value': tag.value} for tag in tags]
+ return [{"name": tag.name, "value": tag.value} for tag in tags]
diff --git a/st2common/st2common/models/api/trace.py b/st2common/st2common/models/api/trace.py
index f09faf6d36..4ce8ec8942 100644
--- a/st2common/st2common/models/api/trace.py
+++ b/st2common/st2common/models/api/trace.py
@@ -21,141 +21,148 @@
TraceComponentAPISchema = {
- 'type': 'object',
- 'properties': {
- 'object_id': {
- 'type': 'string',
- 'description': 'Id of the component',
- 'required': True
+ "type": "object",
+ "properties": {
+ "object_id": {
+ "type": "string",
+ "description": "Id of the component",
+ "required": True,
},
- 'ref': {
- 'type': 'string',
- 'description': 'ref of the component',
- 'required': False
+ "ref": {
+ "type": "string",
+ "description": "ref of the component",
+ "required": False,
},
- 'updated_at': {
- 'description': 'The start time when the action is executed.',
- 'type': 'string',
- 'pattern': isotime.ISO8601_UTC_REGEX
+ "updated_at": {
+ "description": "The start time when the action is executed.",
+ "type": "string",
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
- 'caused_by': {
- 'type': 'object',
- 'description': 'Component that is the cause or the predecesor.',
- 'properties': {
- 'id': {
- 'description': 'Id of the causal component.',
- 'type': 'string'
+ "caused_by": {
+ "type": "object",
+ "description": "Component that is the cause or the predecesor.",
+ "properties": {
+ "id": {"description": "Id of the causal component.", "type": "string"},
+ "type": {
+ "description": "Type of the causal component.",
+ "type": "string",
},
- 'type': {
- 'description': 'Type of the causal component.',
- 'type': 'string'
- }
- }
- }
+ },
+ },
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
class TraceAPI(BaseAPI, APIUIDMixin):
model = TraceDB
schema = {
- 'title': 'Trace',
- 'desciption': 'Trace is a collection of all TriggerInstances, Rules and ActionExecutions \
+ "title": "Trace",
+ "desciption": "Trace is a collection of all TriggerInstances, Rules and ActionExecutions \
that represent an activity which begins with the introduction of a \
TriggerInstance or request of an ActionExecution and ends with the \
- completion of an ActionExecution.',
- 'type': 'object',
- 'properties': {
- 'id': {
- 'description': 'The unique identifier for a Trace.',
- 'type': 'string',
- 'default': None
+ completion of an ActionExecution.",
+ "type": "object",
+ "properties": {
+ "id": {
+ "description": "The unique identifier for a Trace.",
+ "type": "string",
+ "default": None,
},
- 'trace_tag': {
- 'description': 'User assigned identifier for each Trace.',
- 'type': 'string',
- 'required': True
+ "trace_tag": {
+ "description": "User assigned identifier for each Trace.",
+ "type": "string",
+ "required": True,
},
- 'action_executions': {
- 'description': 'All ActionExecutions belonging to a Trace.',
- 'type': 'array',
- 'items': TraceComponentAPISchema
+ "action_executions": {
+ "description": "All ActionExecutions belonging to a Trace.",
+ "type": "array",
+ "items": TraceComponentAPISchema,
},
- 'rules': {
- 'description': 'All rules that applied as part of a Trace.',
- 'type': 'array',
- 'items': TraceComponentAPISchema
+ "rules": {
+ "description": "All rules that applied as part of a Trace.",
+ "type": "array",
+ "items": TraceComponentAPISchema,
},
- 'trigger_instances': {
- 'description': 'All TriggerInstances fired during a Trace.',
- 'type': 'array',
- 'items': TraceComponentAPISchema
+ "trigger_instances": {
+ "description": "All TriggerInstances fired during a Trace.",
+ "type": "array",
+ "items": TraceComponentAPISchema,
},
- 'start_timestamp': {
- 'description': 'Timestamp when the Trace is started.',
- 'type': 'string',
- 'pattern': isotime.ISO8601_UTC_REGEX
+ "start_timestamp": {
+ "description": "Timestamp when the Trace is started.",
+ "type": "string",
+ "pattern": isotime.ISO8601_UTC_REGEX,
},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
def to_component_model(cls, component):
values = {
- 'object_id': component['object_id'],
- 'ref': component['ref'],
- 'caused_by': component.get('caused_by', {})
+ "object_id": component["object_id"],
+ "ref": component["ref"],
+ "caused_by": component.get("caused_by", {}),
}
- updated_at = component.get('updated_at', None)
+ updated_at = component.get("updated_at", None)
if updated_at:
- values['updated_at'] = isotime.parse(updated_at)
+ values["updated_at"] = isotime.parse(updated_at)
return TraceComponentDB(**values)
@classmethod
def to_model(cls, instance):
- values = {
- 'trace_tag': instance.trace_tag
- }
- action_executions = getattr(instance, 'action_executions', [])
- action_executions = [TraceAPI.to_component_model(component=action_execution)
- for action_execution in action_executions]
- values['action_executions'] = action_executions
-
- rules = getattr(instance, 'rules', [])
+ values = {"trace_tag": instance.trace_tag}
+ action_executions = getattr(instance, "action_executions", [])
+ action_executions = [
+ TraceAPI.to_component_model(component=action_execution)
+ for action_execution in action_executions
+ ]
+ values["action_executions"] = action_executions
+
+ rules = getattr(instance, "rules", [])
rules = [TraceAPI.to_component_model(component=rule) for rule in rules]
- values['rules'] = rules
+ values["rules"] = rules
- trigger_instances = getattr(instance, 'trigger_instances', [])
- trigger_instances = [TraceAPI.to_component_model(component=trigger_instance)
- for trigger_instance in trigger_instances]
- values['trigger_instances'] = trigger_instances
+ trigger_instances = getattr(instance, "trigger_instances", [])
+ trigger_instances = [
+ TraceAPI.to_component_model(component=trigger_instance)
+ for trigger_instance in trigger_instances
+ ]
+ values["trigger_instances"] = trigger_instances
- start_timestamp = getattr(instance, 'start_timestamp', None)
+ start_timestamp = getattr(instance, "start_timestamp", None)
if start_timestamp:
- values['start_timestamp'] = isotime.parse(start_timestamp)
+ values["start_timestamp"] = isotime.parse(start_timestamp)
return cls.model(**values)
@classmethod
def from_component_model(cls, component_model):
- return {'object_id': component_model.object_id,
- 'ref': component_model.ref,
- 'updated_at': isotime.format(component_model.updated_at, offset=False),
- 'caused_by': component_model.caused_by}
+ return {
+ "object_id": component_model.object_id,
+ "ref": component_model.ref,
+ "updated_at": isotime.format(component_model.updated_at, offset=False),
+ "caused_by": component_model.caused_by,
+ }
@classmethod
def from_model(cls, model, mask_secrets=False):
instance = cls._from_model(model, mask_secrets=mask_secrets)
- instance['start_timestamp'] = isotime.format(model.start_timestamp, offset=False)
+ instance["start_timestamp"] = isotime.format(
+ model.start_timestamp, offset=False
+ )
if model.action_executions:
- instance['action_executions'] = [cls.from_component_model(action_execution)
- for action_execution in model.action_executions]
+ instance["action_executions"] = [
+ cls.from_component_model(action_execution)
+ for action_execution in model.action_executions
+ ]
if model.rules:
- instance['rules'] = [cls.from_component_model(rule) for rule in model.rules]
+ instance["rules"] = [cls.from_component_model(rule) for rule in model.rules]
if model.trigger_instances:
- instance['trigger_instances'] = [cls.from_component_model(trigger_instance)
- for trigger_instance in model.trigger_instances]
+ instance["trigger_instances"] = [
+ cls.from_component_model(trigger_instance)
+ for trigger_instance in model.trigger_instances
+ ]
return cls(**instance)
@@ -173,12 +180,13 @@ class TraceContext(object):
Optional property.
:type trace_tag: ``str``
"""
+
def __init__(self, id_=None, trace_tag=None):
self.id_ = id_
self.trace_tag = trace_tag
def __str__(self):
- return '{id_: %s, trace_tag: %s}' % (self.id_, self.trace_tag)
+ return "{id_: %s, trace_tag: %s}" % (self.id_, self.trace_tag)
def __json__(self):
return vars(self)
diff --git a/st2common/st2common/models/api/trigger.py b/st2common/st2common/models/api/trigger.py
index af88027fe0..cdb2cd9ddd 100644
--- a/st2common/st2common/models/api/trigger.py
+++ b/st2common/st2common/models/api/trigger.py
@@ -23,140 +23,113 @@
from st2common.models.db.trigger import TriggerTypeDB, TriggerDB, TriggerInstanceDB
from st2common.models.system.common import ResourceReference
-DATE_FORMAT = '%Y-%m-%d %H:%M:%S.%f'
+DATE_FORMAT = "%Y-%m-%d %H:%M:%S.%f"
class TriggerTypeAPI(BaseAPI):
model = TriggerTypeDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'default': None
- },
- 'ref': {
- 'type': 'string'
- },
- 'uid': {
- 'type': 'string'
- },
- 'name': {
- 'type': 'string',
- 'required': True
- },
- 'pack': {
- 'type': 'string'
- },
- 'description': {
- 'type': 'string'
- },
- 'payload_schema': {
- 'type': 'object',
- 'default': {}
- },
- 'parameters_schema': {
- 'type': 'object',
- 'default': {}
- },
- 'tags': {
- 'description': 'User associated metadata assigned to this object.',
- 'type': 'array',
- 'items': {'type': 'object'}
+ "type": "object",
+ "properties": {
+ "id": {"type": "string", "default": None},
+ "ref": {"type": "string"},
+ "uid": {"type": "string"},
+ "name": {"type": "string", "required": True},
+ "pack": {"type": "string"},
+ "description": {"type": "string"},
+ "payload_schema": {"type": "object", "default": {}},
+ "parameters_schema": {"type": "object", "default": {}},
+ "tags": {
+ "description": "User associated metadata assigned to this object.",
+ "type": "array",
+ "items": {"type": "object"},
},
"metadata_file": {
"description": "Path to the metadata file relative to the pack directory.",
"type": "string",
- "default": ""
- }
+ "default": "",
+ },
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
def to_model(cls, trigger_type):
- name = getattr(trigger_type, 'name', None)
- description = getattr(trigger_type, 'description', None)
- pack = getattr(trigger_type, 'pack', None)
- payload_schema = getattr(trigger_type, 'payload_schema', {})
- parameters_schema = getattr(trigger_type, 'parameters_schema', {})
- tags = TagsHelper.to_model(getattr(trigger_type, 'tags', []))
- metadata_file = getattr(trigger_type, 'metadata_file', None)
-
- model = cls.model(name=name, description=description, pack=pack,
- payload_schema=payload_schema, parameters_schema=parameters_schema,
- tags=tags, metadata_file=metadata_file)
+ name = getattr(trigger_type, "name", None)
+ description = getattr(trigger_type, "description", None)
+ pack = getattr(trigger_type, "pack", None)
+ payload_schema = getattr(trigger_type, "payload_schema", {})
+ parameters_schema = getattr(trigger_type, "parameters_schema", {})
+ tags = TagsHelper.to_model(getattr(trigger_type, "tags", []))
+ metadata_file = getattr(trigger_type, "metadata_file", None)
+
+ model = cls.model(
+ name=name,
+ description=description,
+ pack=pack,
+ payload_schema=payload_schema,
+ parameters_schema=parameters_schema,
+ tags=tags,
+ metadata_file=metadata_file,
+ )
return model
@classmethod
def from_model(cls, model, mask_secrets=False):
triggertype = cls._from_model(model, mask_secrets=mask_secrets)
- triggertype['tags'] = TagsHelper.from_model(model.tags)
+ triggertype["tags"] = TagsHelper.from_model(model.tags)
return cls(**triggertype)
class TriggerAPI(BaseAPI):
model = TriggerDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'default': None
- },
- 'ref': {
- 'type': 'string'
- },
- 'uid': {
- 'type': 'string'
- },
- 'name': {
- 'type': 'string'
- },
- 'pack': {
- 'type': 'string'
- },
- 'type': {
- 'type': 'string',
- 'required': True
- },
- 'parameters': {
- 'type': 'object'
- },
- 'description': {
- 'type': 'string'
- }
+ "type": "object",
+ "properties": {
+ "id": {"type": "string", "default": None},
+ "ref": {"type": "string"},
+ "uid": {"type": "string"},
+ "name": {"type": "string"},
+ "pack": {"type": "string"},
+ "type": {"type": "string", "required": True},
+ "parameters": {"type": "object"},
+ "description": {"type": "string"},
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
def from_model(cls, model, mask_secrets=False):
trigger = cls._from_model(model, mask_secrets=mask_secrets)
# Hide ref count from API.
- trigger.pop('ref_count', None)
+ trigger.pop("ref_count", None)
return cls(**trigger)
@classmethod
def to_model(cls, trigger):
- name = getattr(trigger, 'name', None)
- description = getattr(trigger, 'description', None)
- pack = getattr(trigger, 'pack', None)
- _type = getattr(trigger, 'type', None)
- parameters = getattr(trigger, 'parameters', {})
+ name = getattr(trigger, "name", None)
+ description = getattr(trigger, "description", None)
+ pack = getattr(trigger, "pack", None)
+ _type = getattr(trigger, "type", None)
+ parameters = getattr(trigger, "parameters", {})
if _type and not parameters:
trigger_type_ref = ResourceReference.from_string_reference(_type)
name = trigger_type_ref.name
- if hasattr(trigger, 'name') and trigger.name:
+ if hasattr(trigger, "name") and trigger.name:
name = trigger.name
else:
# assign a name if none is provided.
name = str(uuid.uuid4())
- model = cls.model(name=name, description=description, pack=pack, type=_type,
- parameters=parameters)
+ model = cls.model(
+ name=name,
+ description=description,
+ pack=pack,
+ type=_type,
+ parameters=parameters,
+ )
return model
def to_dict(self):
@@ -167,38 +140,29 @@ def to_dict(self):
class TriggerInstanceAPI(BaseAPI):
model = TriggerInstanceDB
schema = {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string'
- },
- 'occurrence_time': {
- 'type': 'string',
- 'pattern': isotime.ISO8601_UTC_REGEX
- },
- 'payload': {
- 'type': 'object'
- },
- 'trigger': {
- 'type': 'string',
- 'default': None,
- 'required': True
+ "type": "object",
+ "properties": {
+ "id": {"type": "string"},
+ "occurrence_time": {"type": "string", "pattern": isotime.ISO8601_UTC_REGEX},
+ "payload": {"type": "object"},
+ "trigger": {"type": "string", "default": None, "required": True},
+ "status": {
+ "type": "string",
+ "default": None,
+ "enum": TRIGGER_INSTANCE_STATUSES,
},
- 'status': {
- 'type': 'string',
- 'default': None,
- 'enum': TRIGGER_INSTANCE_STATUSES
- }
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
@classmethod
def from_model(cls, model, mask_secrets=False):
instance = cls._from_model(model, mask_secrets=mask_secrets)
- if instance.get('occurrence_time', None):
- instance['occurrence_time'] = isotime.format(instance['occurrence_time'], offset=False)
+ if instance.get("occurrence_time", None):
+ instance["occurrence_time"] = isotime.format(
+ instance["occurrence_time"], offset=False
+ )
return cls(**instance)
@@ -209,6 +173,10 @@ def to_model(cls, instance):
occurrence_time = isotime.parse(instance.occurrence_time)
status = instance.status
- model = cls.model(trigger=trigger, payload=payload, occurrence_time=occurrence_time,
- status=status)
+ model = cls.model(
+ trigger=trigger,
+ payload=payload,
+ occurrence_time=occurrence_time,
+ status=status,
+ )
return model
diff --git a/st2common/st2common/models/api/webhook.py b/st2common/st2common/models/api/webhook.py
index 9d1a37ed1d..eb7b04a29b 100644
--- a/st2common/st2common/models/api/webhook.py
+++ b/st2common/st2common/models/api/webhook.py
@@ -15,20 +15,15 @@
from st2common.models.api.base import BaseAPI
-__all___ = [
- 'WebhookBodyAPI'
-]
+__all___ = ["WebhookBodyAPI"]
class WebhookBodyAPI(BaseAPI):
schema = {
- 'type': 'object',
- 'properties': {
+ "type": "object",
+ "properties": {
# Holds actual webhook body
- 'data': {
- 'type': ['object', 'array'],
- 'required': True
- }
+ "data": {"type": ["object", "array"], "required": True}
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
diff --git a/st2common/st2common/models/base.py b/st2common/st2common/models/base.py
index 342daf7028..35d5c884a7 100644
--- a/st2common/st2common/models/base.py
+++ b/st2common/st2common/models/base.py
@@ -17,9 +17,7 @@
Common model related classes.
"""
-__all__ = [
- 'DictSerializableClassMixin'
-]
+__all__ = ["DictSerializableClassMixin"]
class DictSerializableClassMixin(object):
diff --git a/st2common/st2common/models/db/__init__.py b/st2common/st2common/models/db/__init__.py
index 4fd51b4f61..ee7261facd 100644
--- a/st2common/st2common/models/db/__init__.py
+++ b/st2common/st2common/models/db/__init__.py
@@ -40,32 +40,30 @@
LOG = logging.getLogger(__name__)
MODEL_MODULE_NAMES = [
- 'st2common.models.db.auth',
- 'st2common.models.db.action',
- 'st2common.models.db.actionalias',
- 'st2common.models.db.keyvalue',
- 'st2common.models.db.execution',
- 'st2common.models.db.executionstate',
- 'st2common.models.db.execution_queue',
- 'st2common.models.db.liveaction',
- 'st2common.models.db.notification',
- 'st2common.models.db.pack',
- 'st2common.models.db.policy',
- 'st2common.models.db.rbac',
- 'st2common.models.db.rule',
- 'st2common.models.db.rule_enforcement',
- 'st2common.models.db.runner',
- 'st2common.models.db.sensor',
- 'st2common.models.db.trace',
- 'st2common.models.db.trigger',
- 'st2common.models.db.webhook',
- 'st2common.models.db.workflow'
+ "st2common.models.db.auth",
+ "st2common.models.db.action",
+ "st2common.models.db.actionalias",
+ "st2common.models.db.keyvalue",
+ "st2common.models.db.execution",
+ "st2common.models.db.executionstate",
+ "st2common.models.db.execution_queue",
+ "st2common.models.db.liveaction",
+ "st2common.models.db.notification",
+ "st2common.models.db.pack",
+ "st2common.models.db.policy",
+ "st2common.models.db.rbac",
+ "st2common.models.db.rule",
+ "st2common.models.db.rule_enforcement",
+ "st2common.models.db.runner",
+ "st2common.models.db.sensor",
+ "st2common.models.db.trace",
+ "st2common.models.db.trigger",
+ "st2common.models.db.webhook",
+ "st2common.models.db.workflow",
]
# A list of model names for which we don't perform extra index cleanup
-INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [
- 'PermissionGrantDB'
-]
+INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = ["PermissionGrantDB"]
# Reference to DB model classes used for db_ensure_indexes
# NOTE: This variable is populated lazily inside get_model_classes()
@@ -86,55 +84,78 @@ def get_model_classes():
result = []
for module_name in MODEL_MODULE_NAMES:
module = importlib.import_module(module_name)
- model_classes = getattr(module, 'MODELS', [])
+ model_classes = getattr(module, "MODELS", [])
result.extend(model_classes)
MODEL_CLASSES = result
return MODEL_CLASSES
-def _db_connect(db_name, db_host, db_port, username=None, password=None,
- ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
- ssl_ca_certs=None, authentication_mechanism=None, ssl_match_hostname=True):
-
- if '://' in db_host:
+def _db_connect(
+ db_name,
+ db_host,
+ db_port,
+ username=None,
+ password=None,
+ ssl=False,
+ ssl_keyfile=None,
+ ssl_certfile=None,
+ ssl_cert_reqs=None,
+ ssl_ca_certs=None,
+ authentication_mechanism=None,
+ ssl_match_hostname=True,
+):
+
+ if "://" in db_host:
# Hostname is provided as a URI string. Make sure we don't log the password in case one is
# included as part of the URI string.
uri_dict = uri_parser.parse_uri(db_host)
- username_string = uri_dict.get('username', username) or username
+ username_string = uri_dict.get("username", username) or username
- if uri_dict.get('username', None) and username:
+ if uri_dict.get("username", None) and username:
# Username argument has precedence over connection string username
username_string = username
hostnames = get_host_names_for_uri_dict(uri_dict=uri_dict)
- if len(uri_dict['nodelist']) > 1:
- host_string = '%s (replica set)' % (hostnames)
+ if len(uri_dict["nodelist"]) > 1:
+ host_string = "%s (replica set)" % (hostnames)
else:
host_string = hostnames
else:
- host_string = '%s:%s' % (db_host, db_port)
+ host_string = "%s:%s" % (db_host, db_port)
username_string = username
- LOG.info('Connecting to database "%s" @ "%s" as user "%s".' % (db_name, host_string,
- str(username_string)))
-
- ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile,
- ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
- authentication_mechanism=authentication_mechanism,
- ssl_match_hostname=ssl_match_hostname)
+ LOG.info(
+ 'Connecting to database "%s" @ "%s" as user "%s".'
+ % (db_name, host_string, str(username_string))
+ )
+
+ ssl_kwargs = _get_ssl_kwargs(
+ ssl=ssl,
+ ssl_keyfile=ssl_keyfile,
+ ssl_certfile=ssl_certfile,
+ ssl_cert_reqs=ssl_cert_reqs,
+ ssl_ca_certs=ssl_ca_certs,
+ authentication_mechanism=authentication_mechanism,
+ ssl_match_hostname=ssl_match_hostname,
+ )
# NOTE: We intentionally set "serverSelectionTimeoutMS" to 3 seconds. By default it's set to
# 30 seconds, which means it will block up to 30 seconds and fail if there are any SSL related
# or other errors
connection_timeout = cfg.CONF.database.connection_timeout
- connection = mongoengine.connection.connect(db_name, host=db_host,
- port=db_port, tz_aware=True,
- username=username, password=password,
- connectTimeoutMS=connection_timeout,
- serverSelectionTimeoutMS=connection_timeout,
- **ssl_kwargs)
+ connection = mongoengine.connection.connect(
+ db_name,
+ host=db_host,
+ port=db_port,
+ tz_aware=True,
+ username=username,
+ password=password,
+ connectTimeoutMS=connection_timeout,
+ serverSelectionTimeoutMS=connection_timeout,
+ **ssl_kwargs,
+ )
# NOTE: Since pymongo 3.0, connect() method is lazy and not blocking (always returns success)
# so we need to issue a command / query to check if connection has been
@@ -142,32 +163,55 @@ def _db_connect(db_name, db_host, db_port, username=None, password=None,
# See http://api.mongodb.com/python/current/api/pymongo/mongo_client.html for details
try:
# The ismaster command is cheap and does not require auth
- connection.admin.command('ismaster')
+ connection.admin.command("ismaster")
except (ConnectionFailure, ServerSelectionTimeoutError) as e:
# NOTE: ServerSelectionTimeoutError can also be thrown if SSLHandShake fails in the server
# Sadly the client doesn't include more information about the error so in such scenarios
# user needs to check MongoDB server log
- LOG.error('Failed to connect to database "%s" @ "%s" as user "%s": %s' %
- (db_name, host_string, str(username_string), six.text_type(e)))
+ LOG.error(
+ 'Failed to connect to database "%s" @ "%s" as user "%s": %s'
+ % (db_name, host_string, str(username_string), six.text_type(e))
+ )
raise e
- LOG.info('Successfully connected to database "%s" @ "%s" as user "%s".' % (
- db_name, host_string, str(username_string)))
+ LOG.info(
+ 'Successfully connected to database "%s" @ "%s" as user "%s".'
+ % (db_name, host_string, str(username_string))
+ )
return connection
-def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True,
- ssl=False, ssl_keyfile=None, ssl_certfile=None,
- ssl_cert_reqs=None, ssl_ca_certs=None,
- authentication_mechanism=None, ssl_match_hostname=True):
-
- connection = _db_connect(db_name, db_host, db_port, username=username,
- password=password, ssl=ssl, ssl_keyfile=ssl_keyfile,
- ssl_certfile=ssl_certfile,
- ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
- authentication_mechanism=authentication_mechanism,
- ssl_match_hostname=ssl_match_hostname)
+def db_setup(
+ db_name,
+ db_host,
+ db_port,
+ username=None,
+ password=None,
+ ensure_indexes=True,
+ ssl=False,
+ ssl_keyfile=None,
+ ssl_certfile=None,
+ ssl_cert_reqs=None,
+ ssl_ca_certs=None,
+ authentication_mechanism=None,
+ ssl_match_hostname=True,
+):
+
+ connection = _db_connect(
+ db_name,
+ db_host,
+ db_port,
+ username=username,
+ password=password,
+ ssl=ssl,
+ ssl_keyfile=ssl_keyfile,
+ ssl_certfile=ssl_certfile,
+ ssl_cert_reqs=ssl_cert_reqs,
+ ssl_ca_certs=ssl_ca_certs,
+ authentication_mechanism=authentication_mechanism,
+ ssl_match_hostname=ssl_match_hostname,
+ )
# Create all the indexes upfront to prevent race-conditions caused by
# lazy index creation
@@ -192,7 +236,7 @@ def db_ensure_indexes(model_classes=None):
ensured for all the models.
:type model_classes: ``list``
"""
- LOG.debug('Ensuring database indexes...')
+ LOG.debug("Ensuring database indexes...")
if not model_classes:
model_classes = get_model_classes()
@@ -210,34 +254,44 @@ def db_ensure_indexes(model_classes=None):
# Note: This condition would only be encountered when upgrading an existing
# StackStorm installation from MongoDB 3.2 to 3.4.
msg = six.text_type(e)
- if 'already exists with different options' in msg and 'uid_1' in msg:
+ if "already exists with different options" in msg and "uid_1" in msg:
drop_obsolete_types_indexes(model_class=model_class)
else:
raise e
except Exception as e:
tb_msg = traceback.format_exc()
- msg = 'Failed to ensure indexes for model "%s": %s' % (class_name, six.text_type(e))
- msg += '\n\n' + tb_msg
+ msg = 'Failed to ensure indexes for model "%s": %s' % (
+ class_name,
+ six.text_type(e),
+ )
+ msg += "\n\n" + tb_msg
exc_cls = type(e)
raise exc_cls(msg)
if model_class.__name__ in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST:
- LOG.debug('Skipping index cleanup for blacklisted model "%s"...' % (class_name))
+ LOG.debug(
+ 'Skipping index cleanup for blacklisted model "%s"...' % (class_name)
+ )
continue
removed_count = cleanup_extra_indexes(model_class=model_class)
if removed_count:
- LOG.debug('Removed "%s" extra indexes for model "%s"' % (removed_count, class_name))
+ LOG.debug(
+ 'Removed "%s" extra indexes for model "%s"'
+ % (removed_count, class_name)
+ )
- LOG.debug('Indexes are ensured for models: %s' %
- ', '.join(sorted((model_class.__name__ for model_class in model_classes))))
+ LOG.debug(
+ "Indexes are ensured for models: %s"
+ % ", ".join(sorted((model_class.__name__ for model_class in model_classes)))
+ )
def cleanup_extra_indexes(model_class):
"""
Finds any extra indexes and removes those from mongodb.
"""
- extra_indexes = model_class.compare_indexes().get('extra', None)
+ extra_indexes = model_class.compare_indexes().get("extra", None)
if not extra_indexes:
return 0
@@ -248,10 +302,14 @@ def cleanup_extra_indexes(model_class):
for extra_index in extra_indexes:
try:
c.drop_index(extra_index)
- LOG.debug('Dropped index %s for model %s.', extra_index, model_class.__name__)
+ LOG.debug(
+ "Dropped index %s for model %s.", extra_index, model_class.__name__
+ )
removed_count += 1
except OperationFailure:
- LOG.warning('Attempt to cleanup index %s failed.', extra_index, exc_info=True)
+ LOG.warning(
+ "Attempt to cleanup index %s failed.", extra_index, exc_info=True
+ )
return removed_count
@@ -266,14 +324,19 @@ def drop_obsolete_types_indexes(model_class):
LOG.debug('Dropping obsolete types index for model "%s"' % (class_name))
collection = model_class._get_collection()
- collection.update({}, {'$unset': {'_types': 1}}, multi=True)
+ collection.update({}, {"$unset": {"_types": 1}}, multi=True)
info = collection.index_information()
- indexes_to_drop = [key for key, value in six.iteritems(info)
- if '_types' in dict(value['key']) or 'types' in value]
+ indexes_to_drop = [
+ key
+ for key, value in six.iteritems(info)
+ if "_types" in dict(value["key"]) or "types" in value
+ ]
- LOG.debug('Will drop obsolete types indexes for model "%s": %s' % (class_name,
- str(indexes_to_drop)))
+ LOG.debug(
+ 'Will drop obsolete types indexes for model "%s": %s'
+ % (class_name, str(indexes_to_drop))
+ )
for index in indexes_to_drop:
collection.drop_index(index)
@@ -286,57 +349,87 @@ def db_teardown():
mongoengine.connection.disconnect()
-def db_cleanup(db_name, db_host, db_port, username=None, password=None,
- ssl=False, ssl_keyfile=None, ssl_certfile=None,
- ssl_cert_reqs=None, ssl_ca_certs=None,
- authentication_mechanism=None, ssl_match_hostname=True):
-
- connection = _db_connect(db_name, db_host, db_port, username=username,
- password=password, ssl=ssl, ssl_keyfile=ssl_keyfile,
- ssl_certfile=ssl_certfile,
- ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs,
- authentication_mechanism=authentication_mechanism,
- ssl_match_hostname=ssl_match_hostname)
-
- LOG.info('Dropping database "%s" @ "%s:%s" as user "%s".',
- db_name, db_host, db_port, str(username))
+def db_cleanup(
+ db_name,
+ db_host,
+ db_port,
+ username=None,
+ password=None,
+ ssl=False,
+ ssl_keyfile=None,
+ ssl_certfile=None,
+ ssl_cert_reqs=None,
+ ssl_ca_certs=None,
+ authentication_mechanism=None,
+ ssl_match_hostname=True,
+):
+
+ connection = _db_connect(
+ db_name,
+ db_host,
+ db_port,
+ username=username,
+ password=password,
+ ssl=ssl,
+ ssl_keyfile=ssl_keyfile,
+ ssl_certfile=ssl_certfile,
+ ssl_cert_reqs=ssl_cert_reqs,
+ ssl_ca_certs=ssl_ca_certs,
+ authentication_mechanism=authentication_mechanism,
+ ssl_match_hostname=ssl_match_hostname,
+ )
+
+ LOG.info(
+ 'Dropping database "%s" @ "%s:%s" as user "%s".',
+ db_name,
+ db_host,
+ db_port,
+ str(username),
+ )
connection.drop_database(db_name)
return connection
-def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
- ssl_ca_certs=None, authentication_mechanism=None, ssl_match_hostname=True):
+def _get_ssl_kwargs(
+ ssl=False,
+ ssl_keyfile=None,
+ ssl_certfile=None,
+ ssl_cert_reqs=None,
+ ssl_ca_certs=None,
+ authentication_mechanism=None,
+ ssl_match_hostname=True,
+):
# NOTE: In pymongo 3.9.0 some of the ssl related arguments have been renamed -
# https://api.mongodb.com/python/current/changelog.html#changes-in-version-3-9-0
# Old names still work, but we should eventually update to new argument names.
ssl_kwargs = {
- 'ssl': ssl,
+ "ssl": ssl,
}
if ssl_keyfile:
- ssl_kwargs['ssl'] = True
- ssl_kwargs['ssl_keyfile'] = ssl_keyfile
+ ssl_kwargs["ssl"] = True
+ ssl_kwargs["ssl_keyfile"] = ssl_keyfile
if ssl_certfile:
- ssl_kwargs['ssl'] = True
- ssl_kwargs['ssl_certfile'] = ssl_certfile
+ ssl_kwargs["ssl"] = True
+ ssl_kwargs["ssl_certfile"] = ssl_certfile
if ssl_cert_reqs:
- if ssl_cert_reqs == 'none':
+ if ssl_cert_reqs == "none":
ssl_cert_reqs = ssl_lib.CERT_NONE
- elif ssl_cert_reqs == 'optional':
+ elif ssl_cert_reqs == "optional":
ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
- elif ssl_cert_reqs == 'required':
+ elif ssl_cert_reqs == "required":
ssl_cert_reqs = ssl_lib.CERT_REQUIRED
- ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs
+ ssl_kwargs["ssl_cert_reqs"] = ssl_cert_reqs
if ssl_ca_certs:
- ssl_kwargs['ssl'] = True
- ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs
+ ssl_kwargs["ssl"] = True
+ ssl_kwargs["ssl_ca_certs"] = ssl_ca_certs
if authentication_mechanism:
- ssl_kwargs['ssl'] = True
- ssl_kwargs['authentication_mechanism'] = authentication_mechanism
- if ssl_kwargs.get('ssl', False):
+ ssl_kwargs["ssl"] = True
+ ssl_kwargs["authentication_mechanism"] = authentication_mechanism
+ if ssl_kwargs.get("ssl", False):
# pass in ssl_match_hostname only if ssl is True. The right default value
# for ssl_match_hostname in almost all cases is True.
- ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname
+ ssl_kwargs["ssl_match_hostname"] = ssl_match_hostname
return ssl_kwargs
@@ -362,9 +455,9 @@ def get_by_pack(self, value):
return self.get(pack=value, raise_exception=True)
def get(self, *args, **kwargs):
- exclude_fields = kwargs.pop('exclude_fields', None)
- raise_exception = kwargs.pop('raise_exception', False)
- only_fields = kwargs.pop('only_fields', None)
+ exclude_fields = kwargs.pop("exclude_fields", None)
+ raise_exception = kwargs.pop("raise_exception", False)
+ only_fields = kwargs.pop("only_fields", None)
args = self._process_arg_filters(args)
@@ -377,14 +470,17 @@ def get(self, *args, **kwargs):
try:
instances = instances.only(*only_fields)
except (mongoengine.errors.LookUpError, AttributeError) as e:
- msg = ('Invalid or unsupported include attribute specified: %s' % six.text_type(e))
+ msg = (
+ "Invalid or unsupported include attribute specified: %s"
+ % six.text_type(e)
+ )
raise ValueError(msg)
instance = instances[0] if instances else None
log_query_and_profile_data_for_queryset(queryset=instances)
if not instance and raise_exception:
- msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs)
+ msg = "Unable to find the %s instance. %s" % (self.model.__name__, kwargs)
raise db_exc.StackStormDBObjectNotFoundError(msg)
return instance
@@ -404,12 +500,12 @@ def count(self, *args, **kwargs):
# **filters):
def query(self, *args, **filters):
# Python 2: Pop keyword parameters that aren't actually filters off of the kwargs
- offset = filters.pop('offset', 0)
- limit = filters.pop('limit', None)
- order_by = filters.pop('order_by', None)
- exclude_fields = filters.pop('exclude_fields', None)
- only_fields = filters.pop('only_fields', None)
- no_dereference = filters.pop('no_dereference', None)
+ offset = filters.pop("offset", 0)
+ limit = filters.pop("limit", None)
+ order_by = filters.pop("order_by", None)
+ exclude_fields = filters.pop("exclude_fields", None)
+ only_fields = filters.pop("only_fields", None)
+ no_dereference = filters.pop("no_dereference", None)
order_by = order_by or []
exclude_fields = exclude_fields or []
@@ -419,7 +515,9 @@ def query(self, *args, **filters):
# Process the filters
# Note: Both of those functions manipulate the "filters" variable so the order
# in which they are called matters
- filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by)
+ filters, order_by = self._process_datetime_range_filters(
+ filters=filters, order_by=order_by
+ )
filters = self._process_null_filters(filters=filters)
result = self.model.objects(*args, **filters)
@@ -429,7 +527,7 @@ def query(self, *args, **filters):
result = result.exclude(*exclude_fields)
except (mongoengine.errors.LookUpError, AttributeError) as e:
field = get_field_name_from_mongoengine_error(e)
- msg = ('Invalid or unsupported exclude attribute specified: %s' % field)
+ msg = "Invalid or unsupported exclude attribute specified: %s" % field
raise ValueError(msg)
if only_fields:
@@ -437,7 +535,7 @@ def query(self, *args, **filters):
result = result.only(*only_fields)
except (mongoengine.errors.LookUpError, AttributeError) as e:
field = get_field_name_from_mongoengine_error(e)
- msg = ('Invalid or unsupported include attribute specified: %s' % field)
+ msg = "Invalid or unsupported include attribute specified: %s" % field
raise ValueError(msg)
if no_dereference:
@@ -450,7 +548,7 @@ def query(self, *args, **filters):
return result
def distinct(self, *args, **kwargs):
- field = kwargs.pop('field')
+ field = kwargs.pop("field")
result = self.model.objects(**kwargs).distinct(field)
log_query_and_profile_data_for_queryset(queryset=result)
return result
@@ -513,8 +611,10 @@ def _process_arg_filters(self, args):
# Create a new QCombination object with the same operation and fixed filters
_args += (visitor.QCombination(arg.operation, children),)
else:
- raise TypeError("Unknown argument type '%s' of argument '%s'"
- % (type(arg), repr(arg)))
+ raise TypeError(
+ "Unknown argument type '%s' of argument '%s'"
+ % (type(arg), repr(arg))
+ )
return _args
@@ -526,35 +626,38 @@ def _process_null_filters(self, filters):
for key, value in six.iteritems(filters):
if value is None:
null_filters[key] = value
- elif isinstance(value, (str, six.text_type)) and value.lower() == 'null':
+ elif isinstance(value, (str, six.text_type)) and value.lower() == "null":
null_filters[key] = value
else:
continue
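+        # Sketch of the rewrite (field name is hypothetical): a filter of
+        # {"end_timestamp": None} becomes {"end_timestamp__exists": False}, so the
+        # query matches documents where that field is absent.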
for key in null_filters.keys():
- result['%s__exists' % (key)] = False
+ result["%s__exists" % (key)] = False
del result[key]
return result
def _process_datetime_range_filters(self, filters, order_by=None):
- ranges = {k: v for k, v in six.iteritems(filters)
- if type(v) in [str, six.text_type] and '..' in v}
+ ranges = {
+ k: v
+ for k, v in six.iteritems(filters)
+ if type(v) in [str, six.text_type] and ".." in v
+ }
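+        # Expected value shape (timestamps are illustrative): a filter such as
+        # {"timestamp": "2021-01-01T00:00:00Z..2021-01-02T00:00:00Z"} is split on
+        # ".." and rewritten into timestamp__gte / timestamp__lte bounds below.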
order_by_list = copy.deepcopy(order_by) if order_by else []
for k, v in six.iteritems(ranges):
- values = v.split('..')
+ values = v.split("..")
dt1 = isotime.parse(values[0])
dt2 = isotime.parse(values[1])
- k__gte = '%s__gte' % k
- k__lte = '%s__lte' % k
+ k__gte = "%s__gte" % k
+ k__lte = "%s__lte" % k
if dt1 < dt2:
query = {k__gte: dt1, k__lte: dt2}
- sort_key, reverse_sort_key = k, '-' + k
+ sort_key, reverse_sort_key = k, "-" + k
else:
query = {k__gte: dt2, k__lte: dt1}
- sort_key, reverse_sort_key = '-' + k, k
+ sort_key, reverse_sort_key = "-" + k, k
del filters[k]
filters.update(query)
@@ -569,7 +672,6 @@ def _process_datetime_range_filters(self, filters, order_by=None):
class ChangeRevisionMongoDBAccess(MongoDBAccess):
-
def insert(self, instance):
instance = self.model.objects.insert(instance)
@@ -585,11 +687,11 @@ def update(self, instance, **kwargs):
return self.save(instance)
def save(self, instance, validate=True):
- if not hasattr(instance, 'id') or not instance.id:
+ if not hasattr(instance, "id") or not instance.id:
return self.insert(instance)
else:
try:
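+                # Optimistic concurrency: the save only applies if the stored document
+                # still carries the rev value we read; a concurrent writer bumps rev
+                # first, and this save then raises SaveConditionError (handled below).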
- save_condition = {'id': instance.id, 'rev': instance.rev}
+ save_condition = {"id": instance.id, "rev": instance.rev}
instance.rev = instance.rev + 1
instance.save(save_condition=save_condition, validate=validate)
except mongoengine.SaveConditionError:
@@ -601,8 +703,8 @@ def save(self, instance, validate=True):
def get_host_names_for_uri_dict(uri_dict):
hosts = []
- for host, port in uri_dict['nodelist']:
- hosts.append('%s:%s' % (host, port))
+ for host, port in uri_dict["nodelist"]:
+ hosts.append("%s:%s" % (host, port))
- hosts = ','.join(hosts)
+ hosts = ",".join(hosts)
return hosts
diff --git a/st2common/st2common/models/db/action.py b/st2common/st2common/models/db/action.py
index 1c28b207c2..52a1ed0374 100644
--- a/st2common/st2common/models/db/action.py
+++ b/st2common/st2common/models/db/action.py
@@ -29,22 +29,26 @@
from st2common.constants.types import ResourceType
__all__ = [
- 'RunnerTypeDB',
- 'ActionDB',
- 'LiveActionDB',
- 'ActionExecutionDB',
- 'ActionExecutionStateDB',
- 'ActionAliasDB'
+ "RunnerTypeDB",
+ "ActionDB",
+ "LiveActionDB",
+ "ActionExecutionDB",
+ "ActionExecutionStateDB",
+ "ActionAliasDB",
]
LOG = logging.getLogger(__name__)
-PACK_SEPARATOR = '.'
+PACK_SEPARATOR = "."
-class ActionDB(stormbase.StormFoundationDB, stormbase.TagsMixin,
- stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin):
+class ActionDB(
+ stormbase.StormFoundationDB,
+ stormbase.TagsMixin,
+ stormbase.ContentPackResourceMixin,
+ stormbase.UIDFieldMixin,
+):
"""
The system entity that represents a Stack Action/Automation in the system.
@@ -56,38 +60,46 @@ class ActionDB(stormbase.StormFoundationDB, stormbase.TagsMixin,
"""
RESOURCE_TYPE = ResourceType.ACTION
- UID_FIELDS = ['pack', 'name']
+ UID_FIELDS = ["pack", "name"]
name = me.StringField(required=True)
ref = me.StringField(required=True)
description = me.StringField()
enabled = me.BooleanField(
- required=True, default=True,
- help_text='A flag indicating whether the action is enabled.')
- entry_point = me.StringField(
required=True,
- help_text='The entry point to the action.')
+ default=True,
+ help_text="A flag indicating whether the action is enabled.",
+ )
+ entry_point = me.StringField(
+ required=True, help_text="The entry point to the action."
+ )
pack = me.StringField(
- required=False,
- help_text='Name of the content pack.',
- unique_with='name')
+ required=False, help_text="Name of the content pack.", unique_with="name"
+ )
runner_type = me.DictField(
- required=True, default={},
- help_text='The action runner to use for executing the action.')
+ required=True,
+ default={},
+ help_text="The action runner to use for executing the action.",
+ )
parameters = stormbase.EscapedDynamicField(
- help_text='The specification for parameters for the action.')
+ help_text="The specification for parameters for the action."
+ )
output_schema = stormbase.EscapedDynamicField(
- help_text='The schema for output of the action.')
+ help_text="The schema for output of the action."
+ )
notify = me.EmbeddedDocumentField(NotificationSchema)
meta = {
- 'indexes': [
- {'fields': ['name']},
- {'fields': ['pack']},
- {'fields': ['ref']},
- ] + (stormbase.ContentPackResourceMixin.get_indexes() +
- stormbase.TagsMixin.get_indexes() +
- stormbase.UIDFieldMixin.get_indexes())
+ "indexes": [
+ {"fields": ["name"]},
+ {"fields": ["pack"]},
+ {"fields": ["ref"]},
+ ]
+ + (
+ stormbase.ContentPackResourceMixin.get_indexes()
+ + stormbase.TagsMixin.get_indexes()
+ + stormbase.UIDFieldMixin.get_indexes()
+ )
}
def __init__(self, *args, **values):
@@ -102,11 +114,17 @@ def is_workflow(self):
:rtype: ``bool``
"""
# pylint: disable=unsubscriptable-object
- return self.runner_type['name'] in WORKFLOW_RUNNER_TYPES
+ return self.runner_type["name"] in WORKFLOW_RUNNER_TYPES
# specialized access objects
action_access = MongoDBAccess(ActionDB)
-MODELS = [ActionDB, ActionExecutionDB, ActionExecutionStateDB, ActionAliasDB,
- LiveActionDB, RunnerTypeDB]
+MODELS = [
+ ActionDB,
+ ActionExecutionDB,
+ ActionExecutionStateDB,
+ ActionAliasDB,
+ LiveActionDB,
+ RunnerTypeDB,
+]
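A note on the UID fields touched throughout these models: UIDFieldMixin builds a document's uid from its RESOURCE_TYPE plus the values of UID_FIELDS, joined with ":" (the same convention get_uid shows for executions further down). A minimal sketch, where the "action" literal is an assumed value of ResourceType.ACTION:

def make_uid(resource_type, *uid_field_values):
    # Sketch of the UIDFieldMixin convention: resource type plus the
    # UID_FIELDS values, joined with ":".
    return ":".join([resource_type] + [str(v) for v in uid_field_values])

# e.g. for ActionDB (UID_FIELDS = ["pack", "name"]):
print(make_uid("action", "core", "local"))  # -> action:core:local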
diff --git a/st2common/st2common/models/db/actionalias.py b/st2common/st2common/models/db/actionalias.py
index a696ff08b4..765630d8a4 100644
--- a/st2common/st2common/models/db/actionalias.py
+++ b/st2common/st2common/models/db/actionalias.py
@@ -21,18 +21,19 @@
from st2common.models.db import stormbase
from st2common.constants.types import ResourceType
-__all__ = [
- 'ActionAliasDB'
-]
+__all__ = ["ActionAliasDB"]
LOG = logging.getLogger(__name__)
-PACK_SEPARATOR = '.'
+PACK_SEPARATOR = "."
-class ActionAliasDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin,
- stormbase.UIDFieldMixin):
+class ActionAliasDB(
+ stormbase.StormFoundationDB,
+ stormbase.ContentPackResourceMixin,
+ stormbase.UIDFieldMixin,
+):
"""
Database entity that represent an Alias for an action.
@@ -46,42 +47,48 @@ class ActionAliasDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMi
"""
RESOURCE_TYPE = ResourceType.ACTION_ALIAS
- UID_FIELDS = ['pack', 'name']
+ UID_FIELDS = ["pack", "name"]
name = me.StringField(required=True)
ref = me.StringField(required=True)
description = me.StringField()
pack = me.StringField(
- required=True,
- help_text='Name of the content pack.',
- unique_with='name')
+ required=True, help_text="Name of the content pack.", unique_with="name"
+ )
enabled = me.BooleanField(
- required=True, default=True,
- help_text='A flag indicating whether the action alias is enabled.')
- action_ref = me.StringField(
required=True,
- help_text='Reference of the Action map this alias.')
+ default=True,
+ help_text="A flag indicating whether the action alias is enabled.",
+ )
+ action_ref = me.StringField(
+ required=True, help_text="Reference of the Action this alias maps to."
+ )
formats = me.ListField(
- help_text='Possible parameter formats that an alias supports.')
+ help_text="Possible parameter formats that an alias supports."
+ )
ack = me.DictField(
- help_text='Parameters pertaining to the acknowledgement message.'
+ help_text="Parameters pertaining to the acknowledgement message."
)
result = me.DictField(
- help_text='Parameters pertaining to the execution result message.'
+ help_text="Parameters pertaining to the execution result message."
)
extra = me.DictField(
- help_text='Additional parameters (usually adapter-specific) not covered in the schema.'
+ help_text="Additional parameters (usually adapter-specific) not covered in the schema."
)
immutable_parameters = me.DictField(
- help_text='Parameters to be passed to the action on every execution.')
+ help_text="Parameters to be passed to the action on every execution."
+ )
meta = {
- 'indexes': [
- {'fields': ['name']},
- {'fields': ['enabled']},
- {'fields': ['formats']},
- ] + (stormbase.ContentPackResourceMixin().get_indexes() +
- stormbase.UIDFieldMixin.get_indexes())
+ "indexes": [
+ {"fields": ["name"]},
+ {"fields": ["enabled"]},
+ {"fields": ["formats"]},
+ ]
+ + (
+ stormbase.ContentPackResourceMixin().get_indexes()
+ + stormbase.UIDFieldMixin.get_indexes()
+ )
}
def __init__(self, *args, **values):
@@ -97,10 +104,12 @@ def get_format_strings(self):
"""
result = []
- formats = getattr(self, 'formats', [])
+ formats = getattr(self, "formats", [])
for format_string in formats:
- if isinstance(format_string, dict) and format_string.get('representation', None):
- result.extend(format_string['representation'])
+ if isinstance(format_string, dict) and format_string.get(
+ "representation", None
+ ):
+ result.extend(format_string["representation"])
else:
result.append(format_string)
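get_format_strings above flattens a mixed list: entries are either plain format strings or dicts whose "representation" key carries several alternatives. A standalone sketch of the same normalization:

def get_format_strings(formats):
    # Standalone sketch of the normalization above: each entry is either a
    # plain format string or a dict carrying a "representation" list.
    result = []
    for format_string in formats:
        if isinstance(format_string, dict) and format_string.get("representation"):
            result.extend(format_string["representation"])
        else:
            result.append(format_string)
    return result

formats = ["run {{cmd}}", {"representation": ["exec {{cmd}}", "do {{cmd}}"]}]
print(get_format_strings(formats))  # three flat format strings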
diff --git a/st2common/st2common/models/db/auth.py b/st2common/st2common/models/db/auth.py
index 7ef30ee017..2531ecb11a 100644
--- a/st2common/st2common/models/db/auth.py
+++ b/st2common/st2common/models/db/auth.py
@@ -25,11 +25,7 @@
from st2common.rbac.backends import get_rbac_backend
from st2common.util import date as date_utils
-__all__ = [
- 'UserDB',
- 'TokenDB',
- 'ApiKeyDB'
-]
+__all__ = ["UserDB", "TokenDB", "ApiKeyDB"]
class UserDB(stormbase.StormFoundationDB):
@@ -42,10 +38,12 @@ class UserDB(stormbase.StormFoundationDB):
is_service: True if this is a service account.
nicknames: Nickname + origin pairs for ChatOps auth.
"""
+
name = me.StringField(required=True, unique=True)
is_service = me.BooleanField(required=True, default=False)
- nicknames = me.DictField(required=False,
- help_text='"Nickname + origin" pairs for ChatOps auth')
+ nicknames = me.DictField(
+ required=False, help_text='"Nickname + origin" pairs for ChatOps auth'
+ )
def get_roles(self, include_remote=True):
"""
@@ -57,7 +55,9 @@ def get_roles(self, include_remote=True):
:rtype: ``list`` of :class:`RoleDB`
"""
rbac_service = get_rbac_backend().get_service_class()
- result = rbac_service.get_roles_for_user(user_db=self, include_remote=include_remote)
+ result = rbac_service.get_roles_for_user(
+ user_db=self, include_remote=include_remote
+ )
return result
def get_permission_assignments(self):
@@ -75,11 +75,13 @@ class TokenDB(stormbase.StormFoundationDB):
expiry: Date when this token expires.
service: True if this is a service (system) token.
"""
+
user = me.StringField(required=True)
token = me.StringField(required=True, unique=True)
expiry = me.DateTimeField(required=True)
- metadata = me.DictField(required=False,
- help_text='Arbitrary metadata associated with this token')
+ metadata = me.DictField(
+ required=False, help_text="Arbitrary metadata associated with this token"
+ )
service = me.BooleanField(required=True, default=False)
@@ -91,23 +93,24 @@ class ApiKeyDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin):
"""
RESOURCE_TYPE = ResourceType.API_KEY
- UID_FIELDS = ['key_hash']
+ UID_FIELDS = ["key_hash"]
user = me.StringField(required=True)
key_hash = me.StringField(required=True, unique=True)
- metadata = me.DictField(required=False,
- help_text='Arbitrary metadata associated with this token')
- created_at = ComplexDateTimeField(default=date_utils.get_datetime_utc_now,
- help_text='The creation time of this ApiKey.')
- enabled = me.BooleanField(required=True, default=True,
- help_text='A flag indicating whether the ApiKey is enabled.')
-
- meta = {
- 'indexes': [
- {'fields': ['user']},
- {'fields': ['key_hash']}
- ]
- }
+ metadata = me.DictField(
+ required=False, help_text="Arbitrary metadata associated with this token"
+ )
+ created_at = ComplexDateTimeField(
+ default=date_utils.get_datetime_utc_now,
+ help_text="The creation time of this ApiKey.",
+ )
+ enabled = me.BooleanField(
+ required=True,
+ default=True,
+ help_text="A flag indicating whether the ApiKey is enabled.",
+ )
+
+ meta = {"indexes": [{"fields": ["user"]}, {"fields": ["key_hash"]}]}
def __init__(self, *args, **values):
super(ApiKeyDB, self).__init__(*args, **values)
@@ -119,8 +122,8 @@ def mask_secrets(self, value):
# In theory the key_hash is safe to return as it is one way. On the other
# hand, given that this is actually a secret, there is no real point in letting
# the hash escape. Since uid contains key_hash, we mask that as well.
- result['key_hash'] = MASKED_ATTRIBUTE_VALUE
- result['uid'] = MASKED_ATTRIBUTE_VALUE
+ result["key_hash"] = MASKED_ATTRIBUTE_VALUE
+ result["uid"] = MASKED_ATTRIBUTE_VALUE
return result
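The mask_secrets hunk above replaces both key_hash and uid, since the uid embeds the hash. A small sketch of that masking step on a plain dict; the actual value of MASKED_ATTRIBUTE_VALUE is assumed here:

MASKED_ATTRIBUTE_VALUE = "********"  # assumed mask value, illustrative only

def mask_api_key(doc):
    # Sketch of ApiKeyDB.mask_secrets above: key_hash is one-way, but the
    # uid embeds it, so both fields are replaced before serialization.
    masked = dict(doc)
    masked["key_hash"] = MASKED_ATTRIBUTE_VALUE
    masked["uid"] = MASKED_ATTRIBUTE_VALUE
    return masked

print(mask_api_key({"user": "alice", "key_hash": "deadbeef", "uid": "api_key:deadbeef"}))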
diff --git a/st2common/st2common/models/db/execution.py b/st2common/st2common/models/db/execution.py
index 3e8f3c7742..a44e5072d6 100644
--- a/st2common/st2common/models/db/execution.py
+++ b/st2common/st2common/models/db/execution.py
@@ -27,10 +27,7 @@
from st2common.util.secrets import mask_secret_parameters
from st2common.constants.types import ResourceType
-__all__ = [
- 'ActionExecutionDB',
- 'ActionExecutionOutputDB'
-]
+__all__ = ["ActionExecutionDB", "ActionExecutionOutputDB"]
LOG = logging.getLogger(__name__)
@@ -38,7 +35,7 @@
class ActionExecutionDB(stormbase.StormFoundationDB):
RESOURCE_TYPE = ResourceType.EXECUTION
- UID_FIELDS = ['id']
+ UID_FIELDS = ["id"]
trigger = stormbase.EscapedDictField()
trigger_type = stormbase.EscapedDictField()
@@ -52,22 +49,25 @@ class ActionExecutionDB(stormbase.StormFoundationDB):
workflow_execution = me.StringField()
task_execution = me.StringField()
status = me.StringField(
- required=True,
- help_text='The current status of the liveaction.')
+ required=True, help_text="The current status of the liveaction."
+ )
start_timestamp = ComplexDateTimeField(
default=date_utils.get_datetime_utc_now,
- help_text='The timestamp when the liveaction was created.')
+ help_text="The timestamp when the liveaction was created.",
+ )
end_timestamp = ComplexDateTimeField(
- help_text='The timestamp when the liveaction has finished.')
+ help_text="The timestamp when the liveaction has finished."
+ )
parameters = stormbase.EscapedDynamicField(
default={},
- help_text='The key-value pairs passed as to the action runner & action.')
+ help_text="The key-value pairs passed as to the action runner & action.",
+ )
result = stormbase.EscapedDynamicField(
- default={},
- help_text='Action defined result.')
+ default={}, help_text="Action defined result."
+ )
context = me.DictField(
- default={},
- help_text='Contextual information on the action execution.')
+ default={}, help_text="Contextual information on the action execution."
+ )
parent = me.StringField()
children = me.ListField(field=me.StringField())
log = me.ListField(field=me.DictField())
@@ -76,49 +76,51 @@ class ActionExecutionDB(stormbase.StormFoundationDB):
web_url = me.StringField(required=False)
meta = {
- 'indexes': [
- {'fields': ['rule.ref']},
- {'fields': ['action.ref']},
- {'fields': ['liveaction.id']},
- {'fields': ['start_timestamp']},
- {'fields': ['end_timestamp']},
- {'fields': ['status']},
- {'fields': ['parent']},
- {'fields': ['rule.name']},
- {'fields': ['runner.name']},
- {'fields': ['trigger.name']},
- {'fields': ['trigger_type.name']},
- {'fields': ['trigger_instance.id']},
- {'fields': ['context.user']},
- {'fields': ['-start_timestamp', 'action.ref', 'status']},
- {'fields': ['workflow_execution']},
- {'fields': ['task_execution']}
+ "indexes": [
+ {"fields": ["rule.ref"]},
+ {"fields": ["action.ref"]},
+ {"fields": ["liveaction.id"]},
+ {"fields": ["start_timestamp"]},
+ {"fields": ["end_timestamp"]},
+ {"fields": ["status"]},
+ {"fields": ["parent"]},
+ {"fields": ["rule.name"]},
+ {"fields": ["runner.name"]},
+ {"fields": ["trigger.name"]},
+ {"fields": ["trigger_type.name"]},
+ {"fields": ["trigger_instance.id"]},
+ {"fields": ["context.user"]},
+ {"fields": ["-start_timestamp", "action.ref", "status"]},
+ {"fields": ["workflow_execution"]},
+ {"fields": ["task_execution"]},
]
}
def get_uid(self):
# TODO Construct id from non id field:
uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=no-member
- return ':'.join(uid)
+ return ":".join(uid)
def mask_secrets(self, value):
result = copy.deepcopy(value)
- liveaction = result['liveaction']
+ liveaction = result["liveaction"]
parameters = {}
# pylint: disable=no-member
- parameters.update(value.get('action', {}).get('parameters', {}))
- parameters.update(value.get('runner', {}).get('runner_parameters', {}))
+ parameters.update(value.get("action", {}).get("parameters", {}))
+ parameters.update(value.get("runner", {}).get("runner_parameters", {}))
secret_parameters = get_secret_parameters(parameters=parameters)
- result['parameters'] = mask_secret_parameters(parameters=result.get('parameters', {}),
- secret_parameters=secret_parameters)
+ result["parameters"] = mask_secret_parameters(
+ parameters=result.get("parameters", {}), secret_parameters=secret_parameters
+ )
- if 'parameters' in liveaction:
- liveaction['parameters'] = mask_secret_parameters(parameters=liveaction['parameters'],
- secret_parameters=secret_parameters)
+ if "parameters" in liveaction:
+ liveaction["parameters"] = mask_secret_parameters(
+ parameters=liveaction["parameters"], secret_parameters=secret_parameters
+ )
- if liveaction.get('action', '') == 'st2.inquiry.respond':
+ if liveaction.get("action", "") == "st2.inquiry.respond":
# Special case to mask parameters for `st2.inquiry.respond` action
# In this case, this execution is just a plain python action, not
# an inquiry, so we don't natively have a handle on the response
@@ -130,22 +132,24 @@ def mask_secrets(self, value):
# it's just a placeholder to tell mask_secret_parameters()
# that this parameter is indeed a secret parameter and to
# mask it.
- result['parameters']['response'] = mask_secret_parameters(
- parameters=liveaction['parameters']['response'],
- secret_parameters={p: 'string' for p in liveaction['parameters']['response']}
+ result["parameters"]["response"] = mask_secret_parameters(
+ parameters=liveaction["parameters"]["response"],
+ secret_parameters={
+ p: "string" for p in liveaction["parameters"]["response"]
+ },
)
# TODO(mierdin): This logic should be moved to the dedicated Inquiry
# data model once it exists.
- if self.runner.get('name') == "inquirer":
+ if self.runner.get("name") == "inquirer":
- schema = result['result'].get('schema', {})
- response = result['result'].get('response', {})
+ schema = result["result"].get("schema", {})
+ response = result["result"].get("response", {})
# We can only mask response secrets if response and schema exist and are
# not empty
if response and schema:
- result['result']['response'] = mask_inquiry_response(response, schema)
+ result["result"]["response"] = mask_inquiry_response(response, schema)
return result
def get_masked_parameters(self):
@@ -155,7 +159,7 @@ def get_masked_parameters(self):
:rtype: ``dict``
"""
serializable_dict = self.to_serializable_dict(mask_secrets=True)
- return serializable_dict['parameters']
+ return serializable_dict["parameters"]
class ActionExecutionOutputDB(stormbase.StormFoundationDB):
@@ -174,22 +178,25 @@ class ActionExecutionOutputDB(stormbase.StormFoundationDB):
data: Actual output data. This could be a line, a chunk, or similar, depending on the
runner.
"""
+
execution_id = me.StringField(required=True)
action_ref = me.StringField(required=True)
runner_ref = me.StringField(required=True)
- timestamp = ComplexDateTimeField(required=True, default=date_utils.get_datetime_utc_now)
- output_type = me.StringField(required=True, default='output')
+ timestamp = ComplexDateTimeField(
+ required=True, default=date_utils.get_datetime_utc_now
+ )
+ output_type = me.StringField(required=True, default="output")
delay = me.IntField()
data = me.StringField()
meta = {
- 'indexes': [
- {'fields': ['execution_id']},
- {'fields': ['action_ref']},
- {'fields': ['runner_ref']},
- {'fields': ['timestamp']},
- {'fields': ['output_type']}
+ "indexes": [
+ {"fields": ["execution_id"]},
+ {"fields": ["action_ref"]},
+ {"fields": ["runner_ref"]},
+ {"fields": ["timestamp"]},
+ {"fields": ["output_type"]},
]
}
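Two masking paths appear in this hunk: ordinary parameters are masked against the action's declared secret parameters, while for st2.inquiry.respond every response key is treated as secret via the {p: "string" for p in ...} comprehension. A simplified, non-recursive stand-in showing the inquiry case:

def mask_secret_parameters(parameters, secret_parameters):
    # Simplified, non-recursive stand-in for
    # st2common.util.secrets.mask_secret_parameters.
    return {
        k: "********" if k in secret_parameters else v
        for k, v in parameters.items()
    }

# For st2.inquiry.respond every response key is treated as a secret, which
# is what the {p: "string" for p in ...} comprehension above builds:
response = {"token": "s3cret", "approved": True}
print(mask_secret_parameters(response, {p: "string" for p in response}))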
diff --git a/st2common/st2common/models/db/execution_queue.py b/st2common/st2common/models/db/execution_queue.py
index 31dcebbd1a..8db0993363 100644
--- a/st2common/st2common/models/db/execution_queue.py
+++ b/st2common/st2common/models/db/execution_queue.py
@@ -25,15 +25,16 @@
from st2common.constants.types import ResourceType
__all__ = [
- 'ActionExecutionSchedulingQueueItemDB',
+ "ActionExecutionSchedulingQueueItemDB",
]
LOG = logging.getLogger(__name__)
-class ActionExecutionSchedulingQueueItemDB(stormbase.StormFoundationDB,
- stormbase.ChangeRevisionFieldMixin):
+class ActionExecutionSchedulingQueueItemDB(
+ stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin
+):
"""
A model which represents a request for execution to be scheduled.
@@ -42,36 +43,45 @@ class ActionExecutionSchedulingQueueItemDB(stormbase.StormFoundationDB,
"""
RESOURCE_TYPE = ResourceType.EXECUTION_REQUEST
- UID_FIELDS = ['id']
+ UID_FIELDS = ["id"]
- liveaction_id = me.StringField(required=True,
- help_text='Foreign key to the LiveActionDB which is to be scheduled')
+ liveaction_id = me.StringField(
+ required=True,
+ help_text="Foreign key to the LiveActionDB which is to be scheduled",
+ )
action_execution_id = me.StringField(
- help_text='Foreign key to the ActionExecutionDB which is to be scheduled')
+ help_text="Foreign key to the ActionExecutionDB which is to be scheduled"
+ )
original_start_timestamp = ComplexDateTimeField(
default=date_utils.get_datetime_utc_now,
- help_text='The timestamp when the liveaction was created and originally be scheduled to '
- 'run.')
+ help_text="The timestamp when the liveaction was created and originally be scheduled to "
+ "run.",
+ )
scheduled_start_timestamp = ComplexDateTimeField(
default=date_utils.get_datetime_utc_now,
- help_text='The timestamp when liveaction is scheduled to run.')
+ help_text="The timestamp when liveaction is scheduled to run.",
+ )
delay = me.IntField()
- handling = me.BooleanField(default=False,
- help_text='Flag indicating if this item is currently being handled / '
- 'processed by a scheduler service')
+ handling = me.BooleanField(
+ default=False,
+ help_text="Flag indicating if this item is currently being handled / "
+ "processed by a scheduler service",
+ )
meta = {
- 'indexes': [
+ "indexes": [
# NOTE: We limit index names to 65 characters total for compatibility with AWS
# DocumentDB.
# See https://github.com/StackStorm/st2/pull/4690 for details.
- {'fields': ['action_execution_id'], 'name': 'ac_exc_id'},
- {'fields': ['liveaction_id'], 'name': 'lv_ac_id'},
- {'fields': ['original_start_timestamp'], 'name': 'orig_s_ts'},
- {'fields': ['scheduled_start_timestamp'], 'name': 'schd_s_ts'},
+ {"fields": ["action_execution_id"], "name": "ac_exc_id"},
+ {"fields": ["liveaction_id"], "name": "lv_ac_id"},
+ {"fields": ["original_start_timestamp"], "name": "orig_s_ts"},
+ {"fields": ["scheduled_start_timestamp"], "name": "schd_s_ts"},
]
}
MODELS = [ActionExecutionSchedulingQueueItemDB]
-EXECUTION_QUEUE_ACCESS = ChangeRevisionMongoDBAccess(ActionExecutionSchedulingQueueItemDB)
+EXECUTION_QUEUE_ACCESS = ChangeRevisionMongoDBAccess(
+ ActionExecutionSchedulingQueueItemDB
+)
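The NOTE in this hunk is worth keeping in mind when adding indexes: mongoengine auto-generates index names from the field list, and the fully qualified name can exceed AWS DocumentDB's limit, hence the explicit short names. A hedged mongoengine sketch of the same pattern (the model name is illustrative, not the st2 one):

import mongoengine as me

class QueueItemSketch(me.Document):
    # Illustrative model: an explicit short "name" keeps the fully
    # qualified index name within AWS DocumentDB's length limit
    # mentioned in the NOTE above.
    liveaction_id = me.StringField(required=True)

    meta = {
        "indexes": [
            {"fields": ["liveaction_id"], "name": "lv_ac_id"},
        ]
    }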
diff --git a/st2common/st2common/models/db/executionstate.py b/st2common/st2common/models/db/executionstate.py
index db949b6658..94b883038d 100644
--- a/st2common/st2common/models/db/executionstate.py
+++ b/st2common/st2common/models/db/executionstate.py
@@ -21,33 +21,32 @@
from st2common.models.db import stormbase
__all__ = [
- 'ActionExecutionStateDB',
+ "ActionExecutionStateDB",
]
LOG = logging.getLogger(__name__)
-PACK_SEPARATOR = '.'
+PACK_SEPARATOR = "."
class ActionExecutionStateDB(stormbase.StormFoundationDB):
"""
- Database entity that represents the state of Action execution.
+ Database entity that represents the state of Action execution.
"""
+
execution_id = me.ObjectIdField(
- required=True,
- unique=True,
- help_text='liveaction ID.')
+ required=True, unique=True, help_text="liveaction ID."
+ )
query_module = me.StringField(
- required=True,
- help_text='Reference to the runner model.')
+ required=True, help_text="Reference to the runner model."
+ )
query_context = me.DictField(
required=True,
- help_text='Context about the action execution that is needed for results query.')
+ help_text="Context about the action execution that is needed for results query.",
+ )
- meta = {
- 'indexes': ['query_module']
- }
+ meta = {"indexes": ["query_module"]}
# specialized access objects
diff --git a/st2common/st2common/models/db/keyvalue.py b/st2common/st2common/models/db/keyvalue.py
index debe58ebbb..ea7fda3b9d 100644
--- a/st2common/st2common/models/db/keyvalue.py
+++ b/st2common/st2common/models/db/keyvalue.py
@@ -21,9 +21,7 @@
from st2common.models.db import MongoDBAccess
from st2common.models.db import stormbase
-__all__ = [
- 'KeyValuePairDB'
-]
+__all__ = ["KeyValuePairDB"]
class KeyValuePairDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin):
@@ -34,22 +32,20 @@ class KeyValuePairDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin):
"""
RESOURCE_TYPE = ResourceType.KEY_VALUE_PAIR
- UID_FIELDS = ['scope', 'name']
+ UID_FIELDS = ["scope", "name"]
- scope = me.StringField(default=FULL_SYSTEM_SCOPE, unique_with='name')
+ scope = me.StringField(default=FULL_SYSTEM_SCOPE, unique_with="name")
name = me.StringField(required=True)
value = me.StringField()
secret = me.BooleanField(default=False)
expire_timestamp = me.DateTimeField()
meta = {
- 'indexes': [
- {'fields': ['name']},
- {
- 'fields': ['expire_timestamp'],
- 'expireAfterSeconds': 0
- }
- ] + stormbase.UIDFieldMixin.get_indexes()
+ "indexes": [
+ {"fields": ["name"]},
+ {"fields": ["expire_timestamp"], "expireAfterSeconds": 0},
+ ]
+ + stormbase.UIDFieldMixin.get_indexes()
}
def __init__(self, *args, **values):
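The expireAfterSeconds: 0 entry above makes expire_timestamp a MongoDB TTL index: the server's TTL monitor (which runs roughly once a minute) removes a document once its timestamp is in the past. A minimal mongoengine sketch of the idea, with an illustrative model name:

import mongoengine as me

class ExpiringPairSketch(me.Document):
    # Illustrative model, not the st2 one: with expireAfterSeconds set to
    # 0, MongoDB's TTL monitor deletes a document once expire_timestamp
    # is in the past.
    name = me.StringField(required=True)
    expire_timestamp = me.DateTimeField()

    meta = {
        "indexes": [
            {"fields": ["expire_timestamp"], "expireAfterSeconds": 0},
        ]
    }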
diff --git a/st2common/st2common/models/db/liveaction.py b/st2common/st2common/models/db/liveaction.py
index 6bc5fd77fa..29f5a13bfc 100644
--- a/st2common/st2common/models/db/liveaction.py
+++ b/st2common/st2common/models/db/liveaction.py
@@ -28,12 +28,12 @@
from st2common.util.secrets import mask_secret_parameters
__all__ = [
- 'LiveActionDB',
+ "LiveActionDB",
]
LOG = logging.getLogger(__name__)
-PACK_SEPARATOR = '.'
+PACK_SEPARATOR = "."
class LiveActionDB(stormbase.StormFoundationDB):
@@ -41,50 +41,56 @@ class LiveActionDB(stormbase.StormFoundationDB):
task_execution = me.StringField()
# TODO: Can status be an enum at the Mongo layer?
status = me.StringField(
- required=True,
- help_text='The current status of the liveaction.')
+ required=True, help_text="The current status of the liveaction."
+ )
start_timestamp = ComplexDateTimeField(
default=date_utils.get_datetime_utc_now,
- help_text='The timestamp when the liveaction was created.')
+ help_text="The timestamp when the liveaction was created.",
+ )
end_timestamp = ComplexDateTimeField(
- help_text='The timestamp when the liveaction has finished.')
+ help_text="The timestamp when the liveaction has finished."
+ )
action = me.StringField(
- required=True,
- help_text='Reference to the action that has to be executed.')
+ required=True, help_text="Reference to the action that has to be executed."
+ )
action_is_workflow = me.BooleanField(
default=False,
- help_text='A flag indicating whether the referenced action is a workflow.')
+ help_text="A flag indicating whether the referenced action is a workflow.",
+ )
parameters = stormbase.EscapedDynamicField(
default={},
- help_text='The key-value pairs passed as to the action runner & execution.')
+ help_text="The key-value pairs passed as to the action runner & execution.",
+ )
result = stormbase.EscapedDynamicField(
- default={},
- help_text='Action defined result.')
+ default={}, help_text="Action defined result."
+ )
context = me.DictField(
- default={},
- help_text='Contextual information on the action execution.')
+ default={}, help_text="Contextual information on the action execution."
+ )
callback = me.DictField(
default={},
- help_text='Callback information for the on completion of action execution.')
+ help_text="Callback information for the on completion of action execution.",
+ )
runner_info = me.DictField(
default={},
- help_text='Information about the runner which executed this live action (hostname, pid).')
+ help_text="Information about the runner which executed this live action (hostname, pid).",
+ )
notify = me.EmbeddedDocumentField(NotificationSchema)
delay = me.IntField(
min_value=0,
- help_text='How long (in milliseconds) to delay the execution before scheduling.'
+ help_text="How long (in milliseconds) to delay the execution before scheduling.",
)
meta = {
- 'indexes': [
- {'fields': ['-start_timestamp', 'action']},
- {'fields': ['start_timestamp']},
- {'fields': ['end_timestamp']},
- {'fields': ['action']},
- {'fields': ['status']},
- {'fields': ['context.trigger_instance.id']},
- {'fields': ['workflow_execution']},
- {'fields': ['task_execution']}
+ "indexes": [
+ {"fields": ["-start_timestamp", "action"]},
+ {"fields": ["start_timestamp"]},
+ {"fields": ["end_timestamp"]},
+ {"fields": ["action"]},
+ {"fields": ["status"]},
+ {"fields": ["context.trigger_instance.id"]},
+ {"fields": ["workflow_execution"]},
+ {"fields": ["task_execution"]},
]
}
@@ -92,7 +98,7 @@ def mask_secrets(self, value):
from st2common.util import action_db
result = copy.deepcopy(value)
- execution_parameters = value['parameters']
+ execution_parameters = value["parameters"]
# TODO: This results in two DB lookups; we should cache the action and runner type object
# for each liveaction...
@@ -104,8 +110,9 @@ def mask_secrets(self, value):
parameters = action_db.get_action_parameters_specs(action_ref=self.action)
secret_parameters = get_secret_parameters(parameters=parameters)
- result['parameters'] = mask_secret_parameters(parameters=execution_parameters,
- secret_parameters=secret_parameters)
+ result["parameters"] = mask_secret_parameters(
+ parameters=execution_parameters, secret_parameters=secret_parameters
+ )
return result
def get_masked_parameters(self):
@@ -115,7 +122,7 @@ def get_masked_parameters(self):
:rtype: ``dict``
"""
serializable_dict = self.to_serializable_dict(mask_secrets=True)
- return serializable_dict['parameters']
+ return serializable_dict["parameters"]
# specialized access objects
diff --git a/st2common/st2common/models/db/marker.py b/st2common/st2common/models/db/marker.py
index 1bddf3f604..7a053e5490 100644
--- a/st2common/st2common/models/db/marker.py
+++ b/st2common/st2common/models/db/marker.py
@@ -20,10 +20,7 @@
from st2common.models.db import stormbase
from st2common.util import date as date_utils
-__all__ = [
- 'MarkerDB',
- 'DumperMarkerDB'
-]
+__all__ = ["MarkerDB", "DumperMarkerDB"]
class MarkerDB(stormbase.StormFoundationDB):
@@ -37,20 +34,21 @@ class MarkerDB(stormbase.StormFoundationDB):
:param updated_at: Timestamp when marker was updated.
:type updated_at: ``datetime.datetime``
"""
+
marker = me.StringField(required=True)
updated_at = ComplexDateTimeField(
default=date_utils.get_datetime_utc_now,
- help_text='The timestamp when the liveaction was created.')
+ help_text="The timestamp when the liveaction was created.",
+ )
- meta = {
- 'abstract': True
- }
+ meta = {"abstract": True}
class DumperMarkerDB(MarkerDB):
"""
Marker model used by Dumper (in exporter).
"""
+
pass
diff --git a/st2common/st2common/models/db/notification.py b/st2common/st2common/models/db/notification.py
index e311f46b75..8ef793887b 100644
--- a/st2common/st2common/models/db/notification.py
+++ b/st2common/st2common/models/db/notification.py
@@ -21,43 +21,47 @@
class NotificationSubSchema(me.EmbeddedDocument):
"""
- Schema for notification settings to be specified for action success/failure.
+ Schema for notification settings to be specified for action success/failure.
"""
+
message = me.StringField()
data = stormbase.EscapedDynamicField(
- default={},
- help_text='Payload to be sent as part of notification.')
+ default={}, help_text="Payload to be sent as part of notification."
+ )
routes = me.ListField(
- default=['notify.default'],
- help_text='Routes to post notifications to.')
- channels = me.ListField( # Deprecated. Only here for backward compatibility reasons.
- default=['notify.default'],
- help_text='Routes to post notifications to.')
+ default=["notify.default"], help_text="Routes to post notifications to."
+ )
+ channels = (
+ me.ListField( # Deprecated. Only here for backward compatibility reasons.
+ default=["notify.default"], help_text="Routes to post notifications to."
+ )
+ )
def __str__(self):
result = []
- result.append('NotificationSubSchema@')
+ result.append("NotificationSubSchema@")
result.append(str(id(self)))
result.append('(message="%s", ' % str(self.message))
result.append('data="%s", ' % str(self.data))
result.append('routes="%s", ' % str(self.routes))
result.append('[**deprecated**]channels="%s")' % str(self.channels))
- return ''.join(result)
+ return "".join(result)
class NotificationSchema(me.EmbeddedDocument):
"""
- Schema for notification settings to be specified for actions.
+ Schema for notification settings to be specified for actions.
"""
+
on_success = me.EmbeddedDocumentField(NotificationSubSchema)
on_failure = me.EmbeddedDocumentField(NotificationSubSchema)
on_complete = me.EmbeddedDocumentField(NotificationSubSchema)
def __str__(self):
result = []
- result.append('NotifySchema@')
+ result.append("NotifySchema@")
result.append(str(id(self)))
result.append('(on_complete="%s", ' % str(self.on_complete))
result.append('on_success="%s", ' % str(self.on_success))
result.append('on_failure="%s")' % str(self.on_failure))
- return ''.join(result)
+ return "".join(result)
diff --git a/st2common/st2common/models/db/pack.py b/st2common/st2common/models/db/pack.py
index cf16910987..c92b009624 100644
--- a/st2common/st2common/models/db/pack.py
+++ b/st2common/st2common/models/db/pack.py
@@ -25,21 +25,16 @@
from st2common.util.secrets import get_secret_parameters
from st2common.util.secrets import mask_secret_parameters
-__all__ = [
- 'PackDB',
- 'ConfigSchemaDB',
- 'ConfigDB'
-]
+__all__ = ["PackDB", "ConfigSchemaDB", "ConfigDB"]
-class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin,
- me.DynamicDocument):
+class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin, me.DynamicDocument):
"""
System entity which represents a pack.
"""
RESOURCE_TYPE = ResourceType.PACK
- UID_FIELDS = ['ref']
+ UID_FIELDS = ["ref"]
ref = me.StringField(required=True, unique=True)
name = me.StringField(required=True, unique=True)
@@ -56,9 +51,7 @@ class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin,
dependencies = me.ListField(field=me.StringField())
system = me.DictField()
- meta = {
- 'indexes': stormbase.UIDFieldMixin.get_indexes()
- }
+ meta = {"indexes": stormbase.UIDFieldMixin.get_indexes()}
def __init__(self, *args, **values):
super(PackDB, self).__init__(*args, **values)
@@ -73,22 +66,24 @@ class ConfigSchemaDB(stormbase.StormFoundationDB):
pack = me.StringField(
required=True,
unique=True,
- help_text='Name of the content pack this schema belongs to.')
+ help_text="Name of the content pack this schema belongs to.",
+ )
attributes = stormbase.EscapedDynamicField(
- help_text='The specification for config schema attributes.')
+ help_text="The specification for config schema attributes."
+ )
class ConfigDB(stormbase.StormFoundationDB):
"""
System entity representing pack config.
"""
+
pack = me.StringField(
required=True,
unique=True,
- help_text='Name of the content pack this config belongs to.')
- values = stormbase.EscapedDynamicField(
- help_text='Config values.',
- default={})
+ help_text="Name of the content pack this config belongs to.",
+ )
+ values = stormbase.EscapedDynamicField(help_text="Config values.", default={})
def mask_secrets(self, value):
"""
@@ -101,11 +96,12 @@ def mask_secrets(self, value):
"""
result = copy.deepcopy(value)
- config_schema = config_schema_access.get_by_pack(result['pack'])
+ config_schema = config_schema_access.get_by_pack(result["pack"])
secret_parameters = get_secret_parameters(parameters=config_schema.attributes)
- result['values'] = mask_secret_parameters(parameters=result['values'],
- secret_parameters=secret_parameters)
+ result["values"] = mask_secret_parameters(
+ parameters=result["values"], secret_parameters=secret_parameters
+ )
return result
diff --git a/st2common/st2common/models/db/policy.py b/st2common/st2common/models/db/policy.py
index 69f709093c..8b9fcafef0 100644
--- a/st2common/st2common/models/db/policy.py
+++ b/st2common/st2common/models/db/policy.py
@@ -23,9 +23,7 @@
from st2common.constants.types import ResourceType
-__all__ = ['PolicyTypeReference',
- 'PolicyTypeDB',
- 'PolicyDB']
+__all__ = ["PolicyTypeReference", "PolicyTypeDB", "PolicyDB"]
LOG = logging.getLogger(__name__)
@@ -34,7 +32,8 @@ class PolicyTypeReference(object):
"""
Class used for referring to policy types which belong to a resource type.
"""
- separator = '.'
+
+ separator = "."
def __init__(self, resource_type=None, name=None):
self.resource_type = self.validate_resource_type(resource_type)
@@ -54,14 +53,15 @@ def is_reference(cls, ref):
@classmethod
def from_string_reference(cls, ref):
- return cls(resource_type=cls.get_resource_type(ref),
- name=cls.get_name(ref))
+ return cls(resource_type=cls.get_resource_type(ref), name=cls.get_name(ref))
@classmethod
def to_string_reference(cls, resource_type=None, name=None):
if not resource_type or not name:
- raise ValueError('Both resource_type and name are required for building ref. '
- 'resource_type=%s, name=%s' % (resource_type, name))
+ raise ValueError(
+ "Both resource_type and name are required for building ref. "
+ "resource_type=%s, name=%s" % (resource_type, name)
+ )
resource_type = cls.validate_resource_type(resource_type)
return cls.separator.join([resource_type, name])
@@ -69,7 +69,7 @@ def to_string_reference(cls, resource_type=None, name=None):
@classmethod
def validate_resource_type(cls, resource_type):
if not resource_type:
- raise ValueError('Resource type should not be empty.')
+ raise ValueError("Resource type should not be empty.")
if cls.separator in resource_type:
raise ValueError('Resource type should not contain "%s".' % cls.separator)
@@ -80,7 +80,7 @@ def validate_resource_type(cls, resource_type):
def get_resource_type(cls, ref):
try:
if not cls.is_reference(ref):
- raise ValueError('%s is not a valid reference.' % ref)
+ raise ValueError("%s is not a valid reference." % ref)
return ref.split(cls.separator, 1)[0]
except (ValueError, IndexError, AttributeError):
@@ -90,15 +90,19 @@ def get_resource_type(cls, ref):
def get_name(cls, ref):
try:
if not cls.is_reference(ref):
- raise ValueError('%s is not a valid reference.' % ref)
+ raise ValueError("%s is not a valid reference." % ref)
return ref.split(cls.separator, 1)[1]
except (ValueError, IndexError, AttributeError):
raise common_models.InvalidReferenceError(ref=ref)
def __repr__(self):
- return ('<%s resource_type=%s,name=%s,ref=%s>' %
- (self.__class__.__name__, self.resource_type, self.name, self.ref))
+ return "<%s resource_type=%s,name=%s,ref=%s>" % (
+ self.__class__.__name__,
+ self.resource_type,
+ self.name,
+ self.ref,
+ )
class PolicyTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin):
@@ -114,29 +118,35 @@ class PolicyTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin):
module: The python module that implements the policy for this type.
parameters: The specification for parameters for the policy type.
"""
+
RESOURCE_TYPE = ResourceType.POLICY_TYPE
- UID_FIELDS = ['resource_type', 'name']
+ UID_FIELDS = ["resource_type", "name"]
ref = me.StringField(required=True)
resource_type = me.StringField(
required=True,
- unique_with='name',
- help_text='The type of resource that this policy type can be applied to.')
+ unique_with="name",
+ help_text="The type of resource that this policy type can be applied to.",
+ )
enabled = me.BooleanField(
required=True,
default=True,
- help_text='A flag indicating whether the runner for this type is enabled.')
+ help_text="A flag indicating whether the runner for this type is enabled.",
+ )
module = me.StringField(
required=True,
- help_text='The python module that implements the policy for this type.')
+ help_text="The python module that implements the policy for this type.",
+ )
parameters = me.DictField(
- help_text='The specification for parameters for the policy type.')
+ help_text="The specification for parameters for the policy type."
+ )
def __init__(self, *args, **kwargs):
super(PolicyTypeDB, self).__init__(*args, **kwargs)
self.uid = self.get_uid()
- self.ref = PolicyTypeReference.to_string_reference(resource_type=self.resource_type,
- name=self.name)
+ self.ref = PolicyTypeReference.to_string_reference(
+ resource_type=self.resource_type, name=self.name
+ )
def get_reference(self):
"""
@@ -147,8 +157,11 @@ def get_reference(self):
return PolicyTypeReference(resource_type=self.resource_type, name=self.name)
-class PolicyDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin,
- stormbase.UIDFieldMixin):
+class PolicyDB(
+ stormbase.StormFoundationDB,
+ stormbase.ContentPackResourceMixin,
+ stormbase.UIDFieldMixin,
+):
"""
The representation for a policy in the system.
@@ -158,43 +171,47 @@ class PolicyDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin,
policy_type: The type of policy.
parameters: The specification of input parameters for the policy.
"""
+
RESOURCE_TYPE = ResourceType.POLICY
- UID_FIELDS = ['pack', 'name']
+ UID_FIELDS = ["pack", "name"]
name = me.StringField(required=True)
ref = me.StringField(required=True)
pack = me.StringField(
required=False,
default=pack_constants.DEFAULT_PACK_NAME,
- unique_with='name',
- help_text='Name of the content pack.')
+ unique_with="name",
+ help_text="Name of the content pack.",
+ )
description = me.StringField()
enabled = me.BooleanField(
required=True,
default=True,
- help_text='A flag indicating whether this policy is enabled in the system.')
+ help_text="A flag indicating whether this policy is enabled in the system.",
+ )
resource_ref = me.StringField(
- required=True,
- help_text='The resource that this policy is applied to.')
+ required=True, help_text="The resource that this policy is applied to."
+ )
policy_type = me.StringField(
- required=True,
- unique_with='resource_ref',
- help_text='The type of policy.')
+ required=True, unique_with="resource_ref", help_text="The type of policy."
+ )
parameters = me.DictField(
- help_text='The specification of input parameters for the policy.')
+ help_text="The specification of input parameters for the policy."
+ )
meta = {
- 'indexes': [
- {'fields': ['name']},
- {'fields': ['resource_ref']},
+ "indexes": [
+ {"fields": ["name"]},
+ {"fields": ["resource_ref"]},
]
}
def __init__(self, *args, **kwargs):
super(PolicyDB, self).__init__(*args, **kwargs)
self.uid = self.get_uid()
- self.ref = common_models.ResourceReference.to_string_reference(pack=self.pack,
- name=self.name)
+ self.ref = common_models.ResourceReference.to_string_reference(
+ pack=self.pack, name=self.name
+ )
MODELS = [PolicyTypeDB, PolicyDB]
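For reference, the PolicyTypeReference refactored above encodes a policy type as "<resource_type>.<name>" and splits on the first separator when parsing. A standalone sketch of the round-trip (names here are illustrative, not the st2 API):

SEPARATOR = "."

def to_string_reference(resource_type, name):
    # Sketch of PolicyTypeReference.to_string_reference above.
    if not resource_type or not name:
        raise ValueError("Both resource_type and name are required for building ref.")
    return SEPARATOR.join([resource_type, name])

def from_string_reference(ref):
    # Splitting on the first "." mirrors get_resource_type / get_name.
    resource_type, _, name = ref.partition(SEPARATOR)
    return resource_type, name

ref = to_string_reference("action", "concurrency")
print(ref, from_string_reference(ref))  # action.concurrency ('action', 'concurrency')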
diff --git a/st2common/st2common/models/db/rbac.py b/st2common/st2common/models/db/rbac.py
index 68b41ea314..bb82ba88cb 100644
--- a/st2common/st2common/models/db/rbac.py
+++ b/st2common/st2common/models/db/rbac.py
@@ -21,14 +21,13 @@
__all__ = [
- 'RoleDB',
- 'UserRoleAssignmentDB',
- 'PermissionGrantDB',
- 'GroupToRoleMappingDB',
-
- 'role_access',
- 'user_role_assignment_access',
- 'permission_grant_access'
+ "RoleDB",
+ "UserRoleAssignmentDB",
+ "PermissionGrantDB",
+ "GroupToRoleMappingDB",
+ "role_access",
+ "user_role_assignment_access",
+ "permission_grant_access",
]
@@ -43,15 +42,16 @@ class RoleDB(stormbase.StormFoundationDB):
permission_grants: A list of IDs of the permission grants which apply to this
role.
"""
+
name = me.StringField(required=True, unique=True)
description = me.StringField()
system = me.BooleanField(default=False)
permission_grants = me.ListField(field=me.StringField())
meta = {
- 'indexes': [
- {'fields': ['name']},
- {'fields': ['system']},
+ "indexes": [
+ {"fields": ["name"]},
+ {"fields": ["system"]},
]
}
@@ -67,9 +67,10 @@ class UserRoleAssignmentDB(stormbase.StormFoundationDB):
and "API" for API assignments.
description: Optional assignment description.
"""
+
user = me.StringField(required=True)
- role = me.StringField(required=True, unique_with=['user', 'source'])
- source = me.StringField(required=True, unique_with=['user', 'role'])
+ role = me.StringField(required=True, unique_with=["user", "source"])
+ source = me.StringField(required=True, unique_with=["user", "role"])
description = me.StringField()
# True if this assignment was created on authentication based on the remote groups provided by
# the auth backends.
@@ -78,12 +79,12 @@ class UserRoleAssignmentDB(stormbase.StormFoundationDB):
is_remote = me.BooleanField(default=False)
meta = {
- 'indexes': [
- {'fields': ['user']},
- {'fields': ['role']},
- {'fields': ['source']},
- {'fields': ['is_remote']},
- {'fields': ['user', 'role']},
+ "indexes": [
+ {"fields": ["user"]},
+ {"fields": ["role"]},
+ {"fields": ["source"]},
+ {"fields": ["is_remote"]},
+ {"fields": ["user", "role"]},
]
}
@@ -98,13 +99,14 @@ class PermissionGrantDB(stormbase.StormFoundationDB):
convenience and to allow for more efficient queries.
permission_types: A list of permission types granted to the resource.
"""
+
resource_uid = me.StringField(required=False)
resource_type = me.StringField(required=False)
permission_types = me.ListField(field=me.StringField())
meta = {
- 'indexes': [
- {'fields': ['resource_uid']},
+ "indexes": [
+ {"fields": ["resource_uid"]},
]
}
@@ -120,12 +122,16 @@ class GroupToRoleMappingDB(stormbase.StormFoundationDB):
and "API" for API assignments.
description: Optional description for this mapping.
"""
+
group = me.StringField(required=True, unique=True)
roles = me.ListField(field=me.StringField())
source = me.StringField()
description = me.StringField()
- enabled = me.BooleanField(required=True, default=True,
- help_text='A flag indicating whether the mapping is enabled.')
+ enabled = me.BooleanField(
+ required=True,
+ default=True,
+ help_text="A flag indicating whether the mapping is enabled.",
+ )
# Specialized access objects
diff --git a/st2common/st2common/models/db/reactor.py b/st2common/st2common/models/db/reactor.py
index dc9f08b58e..8b8032654b 100644
--- a/st2common/st2common/models/db/reactor.py
+++ b/st2common/st2common/models/db/reactor.py
@@ -14,18 +14,17 @@
# limitations under the License.
from __future__ import absolute_import
-from st2common.models.db.rule import (ActionExecutionSpecDB, RuleDB)
+from st2common.models.db.rule import ActionExecutionSpecDB, RuleDB
from st2common.models.db.sensor import SensorTypeDB
-from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB, TriggerInstanceDB)
+from st2common.models.db.trigger import TriggerDB, TriggerTypeDB, TriggerInstanceDB
__all__ = [
- 'ActionExecutionSpecDB',
- 'RuleDB',
- 'SensorTypeDB',
- 'TriggerTypeDB',
- 'TriggerDB',
- 'TriggerInstanceDB'
+ "ActionExecutionSpecDB",
+ "RuleDB",
+ "SensorTypeDB",
+ "TriggerTypeDB",
+ "TriggerDB",
+ "TriggerInstanceDB",
]
-MODELS = [RuleDB, SensorTypeDB, TriggerDB, TriggerInstanceDB,
- TriggerTypeDB]
+MODELS = [RuleDB, SensorTypeDB, TriggerDB, TriggerInstanceDB, TriggerTypeDB]
diff --git a/st2common/st2common/models/db/rule.py b/st2common/st2common/models/db/rule.py
index f056734f8c..f4f26ec669 100644
--- a/st2common/st2common/models/db/rule.py
+++ b/st2common/st2common/models/db/rule.py
@@ -28,25 +28,24 @@
class RuleTypeDB(stormbase.StormBaseDB):
enabled = me.BooleanField(
default=True,
- help_text='A flag indicating whether the runner for this type is enabled.')
+ help_text="A flag indicating whether the runner for this type is enabled.",
+ )
parameters = me.DictField(
- help_text='The specification for parameters for the action.',
- default={})
+ help_text="The specification for parameters for the action.", default={}
+ )
class RuleTypeSpecDB(me.EmbeddedDocument):
- ref = me.StringField(unique=False,
- help_text='Type of rule.',
- default='standard')
+ ref = me.StringField(unique=False, help_text="Type of rule.", default="standard")
parameters = me.DictField(default={})
def __str__(self):
result = []
- result.append('RuleTypeSpecDB@')
+ result.append("RuleTypeSpecDB@")
result.append(str(id(self)))
result.append('(ref="%s", ' % self.ref)
result.append('parameters="%s")' % self.parameters)
- return ''.join(result)
+ return "".join(result)
class ActionExecutionSpecDB(me.EmbeddedDocument):
@@ -55,15 +54,19 @@ class ActionExecutionSpecDB(me.EmbeddedDocument):
def __str__(self):
result = []
- result.append('ActionExecutionSpecDB@')
+ result.append("ActionExecutionSpecDB@")
result.append(str(id(self)))
result.append('(ref="%s", ' % self.ref)
result.append('parameters="%s")' % self.parameters)
- return ''.join(result)
+ return "".join(result)
-class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin,
- stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin):
+class RuleDB(
+ stormbase.StormFoundationDB,
+ stormbase.TagsMixin,
+ stormbase.ContentPackResourceMixin,
+ stormbase.UIDFieldMixin,
+):
"""Specifies the action to invoke on the occurrence of a Trigger. It
also includes the transformation to perform to match the impedance
between the payload of a TriggerInstance and the input of an action.
@@ -74,36 +77,39 @@ class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin,
status: enabled or disabled. If disabled, occurrence of the trigger
does not lead to execution of an action and vice-versa.
"""
+
RESOURCE_TYPE = ResourceType.RULE
- UID_FIELDS = ['pack', 'name']
+ UID_FIELDS = ["pack", "name"]
name = me.StringField(required=True)
ref = me.StringField(required=True)
description = me.StringField()
pack = me.StringField(
- required=False,
- help_text='Name of the content pack.',
- unique_with='name')
+ required=False, help_text="Name of the content pack.", unique_with="name"
+ )
type = me.EmbeddedDocumentField(RuleTypeSpecDB, default=RuleTypeSpecDB())
trigger = me.StringField()
criteria = stormbase.EscapedDictField()
action = me.EmbeddedDocumentField(ActionExecutionSpecDB)
- context = me.DictField(
- default={},
- help_text='Contextual info on the rule'
+ context = me.DictField(default={}, help_text="Contextual info on the rule")
+ enabled = me.BooleanField(
+ required=True,
+ default=True,
+ help_text="Flag indicating whether the rule is enabled.",
)
- enabled = me.BooleanField(required=True, default=True,
- help_text=u'Flag indicating whether the rule is enabled.')
meta = {
- 'indexes': [
- {'fields': ['enabled']},
- {'fields': ['action.ref']},
- {'fields': ['trigger']},
- {'fields': ['context.user']},
- ] + (stormbase.ContentPackResourceMixin.get_indexes() +
- stormbase.TagsMixin.get_indexes() +
- stormbase.UIDFieldMixin.get_indexes())
+ "indexes": [
+ {"fields": ["enabled"]},
+ {"fields": ["action.ref"]},
+ {"fields": ["trigger"]},
+ {"fields": ["context.user"]},
+ ]
+ + (
+ stormbase.ContentPackResourceMixin.get_indexes()
+ + stormbase.TagsMixin.get_indexes()
+ + stormbase.UIDFieldMixin.get_indexes()
+ )
}
def mask_secrets(self, value):
@@ -120,7 +126,7 @@ def mask_secrets(self, value):
"""
result = copy.deepcopy(value)
- action_ref = result.get('action', {}).get('ref', None)
+ action_ref = result.get("action", {}).get("ref", None)
if not action_ref:
return result
@@ -131,9 +137,10 @@ def mask_secrets(self, value):
return result
secret_parameters = get_secret_parameters(parameters=action_db.parameters)
- result['action']['parameters'] = mask_secret_parameters(
- parameters=result['action']['parameters'],
- secret_parameters=secret_parameters)
+ result["action"]["parameters"] = mask_secret_parameters(
+ parameters=result["action"]["parameters"],
+ secret_parameters=secret_parameters,
+ )
return result
@@ -147,8 +154,9 @@ def _get_referenced_action_model(self, action_ref):
:rtype: ``ActionDB``
"""
# NOTE: We need to retrieve pack and name since that's needed for the PK
- action_dbs = Action.query(only_fields=['pack', 'ref', 'name', 'parameters'],
- ref=action_ref, limit=1)
+ action_dbs = Action.query(
+ only_fields=["pack", "ref", "name", "parameters"], ref=action_ref, limit=1
+ )
if action_dbs:
return action_dbs[0]
diff --git a/st2common/st2common/models/db/rule_enforcement.py b/st2common/st2common/models/db/rule_enforcement.py
index 80ea1f14fe..62d2a21faf 100644
--- a/st2common/st2common/models/db/rule_enforcement.py
+++ b/st2common/st2common/models/db/rule_enforcement.py
@@ -24,34 +24,27 @@
from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUS_SUCCEEDED
from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUS_FAILED
-__all__ = [
- 'RuleReferenceSpecDB',
- 'RuleEnforcementDB'
-]
+__all__ = ["RuleReferenceSpecDB", "RuleEnforcementDB"]
class RuleReferenceSpecDB(me.EmbeddedDocument):
- ref = me.StringField(unique=False,
- help_text='Reference to rule.',
- required=True)
- id = me.StringField(required=False,
- help_text='Rule ID.')
- uid = me.StringField(required=True,
- help_text='Rule UID.')
+ ref = me.StringField(unique=False, help_text="Reference to rule.", required=True)
+ id = me.StringField(required=False, help_text="Rule ID.")
+ uid = me.StringField(required=True, help_text="Rule UID.")
def __str__(self):
result = []
- result.append('RuleReferenceSpecDB@')
+ result.append("RuleReferenceSpecDB@")
result.append(str(id(self)))
result.append('(ref="%s", ' % self.ref)
result.append('id="%s", ' % self.id)
result.append('uid="%s")' % self.uid)
- return ''.join(result)
+ return "".join(result)
class RuleEnforcementDB(stormbase.StormFoundationDB, stormbase.TagsMixin):
- UID_FIELDS = ['id']
+ UID_FIELDS = ["id"]
trigger_instance_id = me.StringField(required=True)
execution_id = me.StringField(required=False)
@@ -59,31 +52,34 @@ class RuleEnforcementDB(stormbase.StormFoundationDB, stormbase.TagsMixin):
rule = me.EmbeddedDocumentField(RuleReferenceSpecDB, required=True)
enforced_at = ComplexDateTimeField(
default=date_utils.get_datetime_utc_now,
- help_text='The timestamp when the rule enforcement happened.')
+ help_text="The timestamp when the rule enforcement happened.",
+ )
status = me.StringField(
required=True,
default=RULE_ENFORCEMENT_STATUS_SUCCEEDED,
- help_text='Rule enforcement status.')
+ help_text="Rule enforcement status.",
+ )
meta = {
- 'indexes': [
- {'fields': ['trigger_instance_id']},
- {'fields': ['execution_id']},
- {'fields': ['rule.id']},
- {'fields': ['rule.ref']},
- {'fields': ['enforced_at']},
- {'fields': ['-enforced_at']},
- {'fields': ['-enforced_at', 'rule.ref']},
- {'fields': ['status']},
- ] + stormbase.TagsMixin.get_indexes()
+ "indexes": [
+ {"fields": ["trigger_instance_id"]},
+ {"fields": ["execution_id"]},
+ {"fields": ["rule.id"]},
+ {"fields": ["rule.ref"]},
+ {"fields": ["enforced_at"]},
+ {"fields": ["-enforced_at"]},
+ {"fields": ["-enforced_at", "rule.ref"]},
+ {"fields": ["status"]},
+ ]
+ + stormbase.TagsMixin.get_indexes()
}
def __init__(self, *args, **values):
super(RuleEnforcementDB, self).__init__(*args, **values)
# Set status to succeeded for old / existing RuleEnforcementDB which predate status field
- status = getattr(self, 'status', None)
- failure_reason = getattr(self, 'failure_reason', None)
+ status = getattr(self, "status", None)
+ failure_reason = getattr(self, "failure_reason", None)
if status in [None, RULE_ENFORCEMENT_STATUS_SUCCEEDED] and failure_reason:
self.status = RULE_ENFORCEMENT_STATUS_FAILED
@@ -92,8 +88,8 @@ def __init__(self, *args, **values):
# with a consistent get_uid interface.
def get_uid(self):
# TODO Construct uid from non id field:
- uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=E1101
- return ':'.join(uid)
+ uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=E1101
+ return ":".join(uid)
rule_enforcement_access = MongoDBAccess(RuleEnforcementDB)
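The __init__ backfill above handles documents that predate the status field: if such a document carries a failure_reason, its status is forced to failed. A small sketch of that rule, with assumed values for the status constants:

RULE_ENFORCEMENT_STATUS_SUCCEEDED = "succeeded"  # assumed constant values
RULE_ENFORCEMENT_STATUS_FAILED = "failed"

def backfill_status(status, failure_reason):
    # Sketch of the __init__ backfill above: old documents carry only a
    # failure_reason, so a recorded failure reason forces status to failed.
    if status in [None, RULE_ENFORCEMENT_STATUS_SUCCEEDED] and failure_reason:
        return RULE_ENFORCEMENT_STATUS_FAILED
    return status or RULE_ENFORCEMENT_STATUS_SUCCEEDED

print(backfill_status(None, "trigger payload did not validate"))  # failed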
diff --git a/st2common/st2common/models/db/runner.py b/st2common/st2common/models/db/runner.py
index c2f290f5b4..9097d35be6 100644
--- a/st2common/st2common/models/db/runner.py
+++ b/st2common/st2common/models/db/runner.py
@@ -22,13 +22,13 @@
from st2common.constants.types import ResourceType
__all__ = [
- 'RunnerTypeDB',
+ "RunnerTypeDB",
]
LOG = logging.getLogger(__name__)
-PACK_SEPARATOR = '.'
+PACK_SEPARATOR = "."
class RunnerTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin):
@@ -46,31 +46,37 @@ class RunnerTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin):
"""
RESOURCE_TYPE = ResourceType.RUNNER_TYPE
- UID_FIELDS = ['name']
+ UID_FIELDS = ["name"]
enabled = me.BooleanField(
- required=True, default=True,
- help_text='A flag indicating whether the runner for this type is enabled.')
+ required=True,
+ default=True,
+ help_text="A flag indicating whether the runner for this type is enabled.",
+ )
runner_package = me.StringField(
required=False,
- help_text=('The python package that implements the action runner for this type. If'
- 'not provided it assumes package name equals module name.'))
+ help_text=(
+ "The python package that implements the action runner for this type. If"
+ "not provided it assumes package name equals module name."
+ ),
+ )
runner_module = me.StringField(
required=True,
- help_text='The python module that implements the action runner for this type.')
+ help_text="The python module that implements the action runner for this type.",
+ )
runner_parameters = me.DictField(
- help_text='The specification for parameters for the action runner.')
+ help_text="The specification for parameters for the action runner."
+ )
output_key = me.StringField(
- help_text='Default key to expect results to be published to.')
- output_schema = me.DictField(
- help_text='The schema for runner output.')
+ help_text="Default key to expect results to be published to."
+ )
+ output_schema = me.DictField(help_text="The schema for runner output.")
query_module = me.StringField(
required=False,
- help_text='The python module that implements the query module for this runner.')
+ help_text="The python module that implements the query module for this runner.",
+ )
- meta = {
- 'indexes': stormbase.UIDFieldMixin.get_indexes()
- }
+ meta = {"indexes": stormbase.UIDFieldMixin.get_indexes()}
def __init__(self, *args, **values):
super(RunnerTypeDB, self).__init__(*args, **values)
diff --git a/st2common/st2common/models/db/sensor.py b/st2common/st2common/models/db/sensor.py
index 6517fb3a75..31437ad321 100644
--- a/st2common/st2common/models/db/sensor.py
+++ b/st2common/st2common/models/db/sensor.py
@@ -20,13 +20,12 @@
from st2common.models.db import stormbase
from st2common.constants.types import ResourceType
-__all__ = [
- 'SensorTypeDB'
-]
+__all__ = ["SensorTypeDB"]
-class SensorTypeDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin,
- stormbase.UIDFieldMixin):
+class SensorTypeDB(
+ stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin
+):
"""
Description of a specific type of a sensor (think of it as a sensor
template).
@@ -40,25 +39,29 @@ class SensorTypeDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin,
"""
RESOURCE_TYPE = ResourceType.SENSOR_TYPE
- UID_FIELDS = ['pack', 'name']
+ UID_FIELDS = ["pack", "name"]
name = me.StringField(required=True)
ref = me.StringField(required=True)
- pack = me.StringField(required=True, unique_with='name')
+ pack = me.StringField(required=True, unique_with="name")
artifact_uri = me.StringField()
entry_point = me.StringField()
trigger_types = me.ListField(field=me.StringField())
poll_interval = me.IntField()
- enabled = me.BooleanField(default=True,
- help_text=u'Flag indicating whether the sensor is enabled.')
+ enabled = me.BooleanField(
+ default=True, help_text="Flag indicating whether the sensor is enabled."
+ )
meta = {
- 'indexes': [
- {'fields': ['name']},
- {'fields': ['enabled']},
- {'fields': ['trigger_types']},
- ] + (stormbase.ContentPackResourceMixin.get_indexes() +
- stormbase.UIDFieldMixin.get_indexes())
+ "indexes": [
+ {"fields": ["name"]},
+ {"fields": ["enabled"]},
+ {"fields": ["trigger_types"]},
+ ]
+ + (
+ stormbase.ContentPackResourceMixin.get_indexes()
+ + stormbase.UIDFieldMixin.get_indexes()
+ )
}
def __init__(self, *args, **values):
diff --git a/st2common/st2common/models/db/stormbase.py b/st2common/st2common/models/db/stormbase.py
index bf312c6e4f..50f79dde78 100644
--- a/st2common/st2common/models/db/stormbase.py
+++ b/st2common/st2common/models/db/stormbase.py
@@ -29,17 +29,15 @@
from st2common.constants.types import ResourceType
__all__ = [
- 'StormFoundationDB',
- 'StormBaseDB',
-
- 'EscapedDictField',
- 'EscapedDynamicField',
- 'TagField',
-
- 'RefFieldMixin',
- 'UIDFieldMixin',
- 'TagsMixin',
- 'ContentPackResourceMixin'
+ "StormFoundationDB",
+ "StormBaseDB",
+ "EscapedDictField",
+ "EscapedDynamicField",
+ "TagField",
+ "RefFieldMixin",
+ "UIDFieldMixin",
+ "TagsMixin",
+ "ContentPackResourceMixin",
]
JSON_UNFRIENDLY_TYPES = (datetime.datetime, bson.ObjectId)
@@ -62,17 +60,19 @@ class StormFoundationDB(me.Document, DictSerializableClassMixin):
# don't do that
# see http://docs.mongoengine.org/guide/defining-documents.html#abstract-classes
- meta = {
- 'abstract': True
- }
+ meta = {"abstract": True}
def __str__(self):
attrs = list()
- for k in sorted(self._fields.keys()): # pylint: disable=E1101
+ for k in sorted(self._fields.keys()): # pylint: disable=E1101
v = getattr(self, k)
- v = '"%s"' % str(v) if type(v) in [str, six.text_type, datetime.datetime] else str(v)
- attrs.append('%s=%s' % (k, v))
- return '%s(%s)' % (self.__class__.__name__, ', '.join(attrs))
+ v = (
+ '"%s"' % str(v)
+ if type(v) in [str, six.text_type, datetime.datetime]
+ else str(v)
+ )
+ attrs.append("%s=%s" % (k, v))
+ return "%s(%s)" % (self.__class__.__name__, ", ".join(attrs))
def get_resource_type(self):
return self.RESOURCE_TYPE
@@ -98,7 +98,7 @@ def to_serializable_dict(self, mask_secrets=False):
:rtype: ``dict``
"""
serializable_dict = {}
- for k in sorted(six.iterkeys(self._fields)): # pylint: disable=E1101
+ for k in sorted(six.iterkeys(self._fields)): # pylint: disable=E1101
v = getattr(self, k)
if isinstance(v, JSON_UNFRIENDLY_TYPES):
v = str(v)
@@ -120,17 +120,15 @@ class StormBaseDB(StormFoundationDB):
description = me.StringField()
# see http://docs.mongoengine.org/guide/defining-documents.html#abstract-classes
- meta = {
- 'abstract': True
- }
+ meta = {"abstract": True}
class EscapedDictField(me.DictField):
-
def to_mongo(self, value, use_db_field=True, fields=None):
value = mongoescape.escape_chars(value)
- return super(EscapedDictField, self).to_mongo(value=value, use_db_field=use_db_field,
- fields=fields)
+ return super(EscapedDictField, self).to_mongo(
+ value=value, use_db_field=use_db_field, fields=fields
+ )
def to_python(self, value):
value = super(EscapedDictField, self).to_python(value)
@@ -138,18 +136,18 @@ def to_python(self, value):
def validate(self, value):
if not isinstance(value, dict):
- self.error('Only dictionaries may be used in a DictField')
+ self.error("Only dictionaries may be used in a DictField")
if me.fields.key_not_string(value):
self.error("Invalid dictionary key - documents must have only string keys")
me.base.ComplexBaseField.validate(self, value)
class EscapedDynamicField(me.DynamicField):
-
def to_mongo(self, value, use_db_field=True, fields=None):
value = mongoescape.escape_chars(value)
- return super(EscapedDynamicField, self).to_mongo(value=value, use_db_field=use_db_field,
- fields=fields)
+ return super(EscapedDynamicField, self).to_mongo(
+ value=value, use_db_field=use_db_field, fields=fields
+ )
def to_python(self, value):
value = super(EscapedDynamicField, self).to_python(value)
@@ -161,6 +159,7 @@ class TagField(me.EmbeddedDocument):
To be attached to a db model object for the purpose of providing supplemental
information.
"""
+
name = me.StringField(max_length=1024)
value = me.StringField(max_length=1024)
@@ -169,11 +168,12 @@ class TagsMixin(object):
"""
Mixin to include tags on an object.
"""
+
tags = me.ListField(field=me.EmbeddedDocumentField(TagField))
@classmethod
def get_indexes(cls):
- return ['tags.name', 'tags.value']
+ return ["tags.name", "tags.value"]
class RefFieldMixin(object):
@@ -192,7 +192,7 @@ class UIDFieldMixin(object):
the system.
"""
- UID_SEPARATOR = ':' # TODO: Move to constants
+ UID_SEPARATOR = ":" # TODO: Move to constants
RESOURCE_TYPE = abc.abstractproperty
UID_FIELDS = abc.abstractproperty
@@ -205,13 +205,7 @@ def get_indexes(cls):
# models in the database before ensure_indexes() is called.
# This field gets populated in the constructor which means it will be lazily assigned next
         # time the model is saved (e.g. once register-content is run).
- indexes = [
- {
- 'fields': ['uid'],
- 'unique': True,
- 'sparse': True
- }
- ]
+ indexes = [{"fields": ["uid"], "unique": True, "sparse": True}]
return indexes
def get_uid(self):
@@ -224,7 +218,7 @@ def get_uid(self):
parts.append(self.RESOURCE_TYPE)
for field in self.UID_FIELDS:
- value = getattr(self, field, None) or ''
+ value = getattr(self, field, None) or ""
parts.append(value)
uid = self.UID_SEPARATOR.join(parts)
@@ -257,8 +251,11 @@ class ContentPackResourceMixin(object):
metadata_file = me.StringField(
required=False,
- help_text=('Path to the metadata file (file on disk which contains resource definition) '
- 'relative to the pack directory.'))
+ help_text=(
+ "Path to the metadata file (file on disk which contains resource definition) "
+ "relative to the pack directory."
+ ),
+ )
def get_pack_uid(self):
"""
@@ -276,7 +273,7 @@ def get_reference(self):
:rtype: :class:`ResourceReference`
"""
- if getattr(self, 'ref', None):
+ if getattr(self, "ref", None):
ref = ResourceReference.from_string_reference(ref=self.ref)
else:
ref = ResourceReference(pack=self.pack, name=self.name)
@@ -287,7 +284,7 @@ def get_reference(self):
def get_indexes(cls):
return [
{
- 'fields': ['metadata_file'],
+ "fields": ["metadata_file"],
}
]
@@ -298,9 +295,4 @@ class ChangeRevisionFieldMixin(object):
@classmethod
def get_indexes(cls):
- return [
- {
- 'fields': ['id', 'rev'],
- 'unique': True
- }
- ]
+ return [{"fields": ["id", "rev"], "unique": True}]
diff --git a/st2common/st2common/models/db/timer.py b/st2common/st2common/models/db/timer.py
index 98bb7952e1..652d6a056a 100644
--- a/st2common/st2common/models/db/timer.py
+++ b/st2common/st2common/models/db/timer.py
@@ -30,10 +30,10 @@ class TimerDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin):
"""
RESOURCE_TYPE = ResourceType.TIMER
- UID_FIELDS = ['pack', 'name']
+ UID_FIELDS = ["pack", "name"]
name = me.StringField(required=True)
- pack = me.StringField(required=True, unique_with='name')
+ pack = me.StringField(required=True, unique_with="name")
type = me.StringField()
parameters = me.DictField()
diff --git a/st2common/st2common/models/db/trace.py b/st2common/st2common/models/db/trace.py
index 00b7010d91..fe358e90c9 100644
--- a/st2common/st2common/models/db/trace.py
+++ b/st2common/st2common/models/db/trace.py
@@ -25,25 +25,24 @@
from st2common.models.db import MongoDBAccess
-__all__ = [
- 'TraceDB',
- 'TraceComponentDB'
-]
+__all__ = ["TraceDB", "TraceComponentDB"]
class TraceComponentDB(me.EmbeddedDocument):
- """
- """
+ """"""
+
object_id = me.StringField()
- ref = me.StringField(default='')
+ ref = me.StringField(default="")
updated_at = ComplexDateTimeField(
default=date_utils.get_datetime_utc_now,
- help_text='The timestamp when the TraceComponent was included.')
- caused_by = me.DictField(help_text='Causal component.')
+ help_text="The timestamp when the TraceComponent was included.",
+ )
+ caused_by = me.DictField(help_text="Causal component.")
def __str__(self):
- return 'TraceComponentDB@(object_id:{}, updated_at:{})'.format(
- self.object_id, self.updated_at)
+ return "TraceComponentDB@(object_id:{}, updated_at:{})".format(
+ self.object_id, self.updated_at
+ )
class TraceDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin):
@@ -66,28 +65,37 @@ class TraceDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin):
RESOURCE_TYPE = ResourceType.TRACE
- trace_tag = me.StringField(required=True,
- help_text='A user specified reference to the trace.')
- trigger_instances = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB),
- required=False,
- help_text='Associated TriggerInstances.')
- rules = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB),
- required=False,
- help_text='Associated Rules.')
- action_executions = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB),
- required=False,
- help_text='Associated ActionExecutions.')
- start_timestamp = ComplexDateTimeField(default=date_utils.get_datetime_utc_now,
- help_text='The timestamp when the Trace was created.')
+ trace_tag = me.StringField(
+ required=True, help_text="A user specified reference to the trace."
+ )
+ trigger_instances = me.ListField(
+ field=me.EmbeddedDocumentField(TraceComponentDB),
+ required=False,
+ help_text="Associated TriggerInstances.",
+ )
+ rules = me.ListField(
+ field=me.EmbeddedDocumentField(TraceComponentDB),
+ required=False,
+ help_text="Associated Rules.",
+ )
+ action_executions = me.ListField(
+ field=me.EmbeddedDocumentField(TraceComponentDB),
+ required=False,
+ help_text="Associated ActionExecutions.",
+ )
+ start_timestamp = ComplexDateTimeField(
+ default=date_utils.get_datetime_utc_now,
+ help_text="The timestamp when the Trace was created.",
+ )
meta = {
- 'indexes': [
- {'fields': ['trace_tag']},
- {'fields': ['start_timestamp']},
- {'fields': ['action_executions.object_id']},
- {'fields': ['trigger_instances.object_id']},
- {'fields': ['rules.object_id']},
- {'fields': ['-start_timestamp', 'trace_tag']},
+ "indexes": [
+ {"fields": ["trace_tag"]},
+ {"fields": ["start_timestamp"]},
+ {"fields": ["action_executions.object_id"]},
+ {"fields": ["trigger_instances.object_id"]},
+ {"fields": ["rules.object_id"]},
+ {"fields": ["-start_timestamp", "trace_tag"]},
]
}
diff --git a/st2common/st2common/models/db/trigger.py b/st2common/st2common/models/db/trigger.py
index 0546c3b739..9b749c5241 100644
--- a/st2common/st2common/models/db/trigger.py
+++ b/st2common/st2common/models/db/trigger.py
@@ -24,16 +24,18 @@
from st2common.constants.types import ResourceType
__all__ = [
- 'TriggerTypeDB',
- 'TriggerDB',
- 'TriggerInstanceDB',
+ "TriggerTypeDB",
+ "TriggerDB",
+ "TriggerInstanceDB",
]
-class TriggerTypeDB(stormbase.StormBaseDB,
- stormbase.ContentPackResourceMixin,
- stormbase.UIDFieldMixin,
- stormbase.TagsMixin):
+class TriggerTypeDB(
+ stormbase.StormBaseDB,
+ stormbase.ContentPackResourceMixin,
+ stormbase.UIDFieldMixin,
+ stormbase.TagsMixin,
+):
"""Description of a specific kind/type of a trigger. The
(pack, name) tuple is expected uniquely identify a trigger in
the namespace of all triggers provided by a specific trigger_source.
@@ -45,18 +47,20 @@ class TriggerTypeDB(stormbase.StormBaseDB,
"""
RESOURCE_TYPE = ResourceType.TRIGGER_TYPE
- UID_FIELDS = ['pack', 'name']
+ UID_FIELDS = ["pack", "name"]
ref = me.StringField(required=False)
name = me.StringField(required=True)
- pack = me.StringField(required=True, unique_with='name')
+ pack = me.StringField(required=True, unique_with="name")
payload_schema = me.DictField()
parameters_schema = me.DictField(default={})
meta = {
- 'indexes': (stormbase.ContentPackResourceMixin.get_indexes() +
- stormbase.TagsMixin.get_indexes() +
- stormbase.UIDFieldMixin.get_indexes())
+ "indexes": (
+ stormbase.ContentPackResourceMixin.get_indexes()
+ + stormbase.TagsMixin.get_indexes()
+ + stormbase.UIDFieldMixin.get_indexes()
+ )
}
def __init__(self, *args, **values):
@@ -66,8 +70,9 @@ def __init__(self, *args, **values):
self.uid = self.get_uid()
-class TriggerDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin,
- stormbase.UIDFieldMixin):
+class TriggerDB(
+ stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin
+):
"""
Attribute:
name - Trigger name.
@@ -77,21 +82,22 @@ class TriggerDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin,
"""
RESOURCE_TYPE = ResourceType.TRIGGER
- UID_FIELDS = ['pack', 'name']
+ UID_FIELDS = ["pack", "name"]
ref = me.StringField(required=False)
name = me.StringField(required=True)
- pack = me.StringField(required=True, unique_with='name')
+ pack = me.StringField(required=True, unique_with="name")
type = me.StringField()
parameters = me.DictField()
ref_count = me.IntField(default=0)
meta = {
- 'indexes': [
- {'fields': ['name']},
- {'fields': ['type']},
- {'fields': ['parameters']},
- ] + stormbase.UIDFieldMixin.get_indexes()
+ "indexes": [
+ {"fields": ["name"]},
+ {"fields": ["type"]},
+ {"fields": ["parameters"]},
+ ]
+ + stormbase.UIDFieldMixin.get_indexes()
}
def __init__(self, *args, **values):
@@ -106,7 +112,7 @@ def get_uid(self):
# Note: We sort the resulting JSON object so that the same dictionary always results
# in the same hash
- parameters = getattr(self, 'parameters', {})
+ parameters = getattr(self, "parameters", {})
parameters = json.dumps(parameters, sort_keys=True)
parameters = hashlib.md5(parameters.encode()).hexdigest()
@@ -126,19 +132,20 @@ class TriggerInstanceDB(stormbase.StormFoundationDB):
payload (dict): payload specific to the occurrence.
occurrence_time (datetime): time of occurrence of the trigger.
"""
+
trigger = me.StringField()
payload = stormbase.EscapedDictField()
occurrence_time = me.DateTimeField()
status = me.StringField(
- required=True,
- help_text='Processing status of TriggerInstance.')
+ required=True, help_text="Processing status of TriggerInstance."
+ )
meta = {
- 'indexes': [
- {'fields': ['occurrence_time']},
- {'fields': ['trigger']},
- {'fields': ['-occurrence_time', 'trigger']},
- {'fields': ['status']}
+ "indexes": [
+ {"fields": ["occurrence_time"]},
+ {"fields": ["trigger"]},
+ {"fields": ["-occurrence_time", "trigger"]},
+ {"fields": ["status"]},
]
}
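
TriggerDB.get_uid above goes one step further than the mixin: it appends an md5 hash of the canonicalized parameters, so two triggers that share a ref but are configured differently still get distinct uids. A hedged standalone sketch of that scheme (the literal values are illustrative):

    import hashlib
    import json

    base_uid = "trigger:core:st2.webhook"  # roughly what the mixin alone yields
    parameters = {"url": "sample"}

    # sort_keys=True canonicalizes the JSON so equal dicts always hash equally.
    params_json = json.dumps(parameters, sort_keys=True)
    params_hash = hashlib.md5(params_json.encode()).hexdigest()

    uid = ":".join([base_uid, params_hash])
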
diff --git a/st2common/st2common/models/db/webhook.py b/st2common/st2common/models/db/webhook.py
index 0ef2906b90..b608f6c355 100644
--- a/st2common/st2common/models/db/webhook.py
+++ b/st2common/st2common/models/db/webhook.py
@@ -29,7 +29,7 @@ class WebhookDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin):
"""
RESOURCE_TYPE = ResourceType.WEBHOOK
- UID_FIELDS = ['name']
+ UID_FIELDS = ["name"]
name = me.StringField(required=True)
@@ -40,7 +40,7 @@ def __init__(self, *args, **values):
def _normalize_name(self, name):
# Remove trailing slash if present
- if name.endswith('/'):
+ if name.endswith("/"):
name = name[:-1]
return name
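
_normalize_name only strips a single trailing slash, so "st2/webhook/" and "st2/webhook" address the same webhook. A quick standalone illustration of that behavior:

    def normalize_name(name):
        # Mirrors WebhookDB._normalize_name: drop one trailing slash if present.
        if name.endswith("/"):
            name = name[:-1]
        return name

    assert normalize_name("st2/webhook/") == "st2/webhook"
    assert normalize_name("st2/webhook") == "st2/webhook"
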
diff --git a/st2common/st2common/models/db/workflow.py b/st2common/st2common/models/db/workflow.py
index dc73c1c55c..fd5cdb111e 100644
--- a/st2common/st2common/models/db/workflow.py
+++ b/st2common/st2common/models/db/workflow.py
@@ -24,16 +24,15 @@
from st2common.util import date as date_utils
-__all__ = [
- 'WorkflowExecutionDB',
- 'TaskExecutionDB'
-]
+__all__ = ["WorkflowExecutionDB", "TaskExecutionDB"]
LOG = logging.getLogger(__name__)
-class WorkflowExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin):
+class WorkflowExecutionDB(
+ stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin
+):
RESOURCE_TYPE = types.ResourceType.EXECUTION
action_execution = me.StringField(required=True)
@@ -46,14 +45,12 @@ class WorkflowExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionF
status = me.StringField(required=True)
output = stormbase.EscapedDictField()
errors = stormbase.EscapedDynamicField()
- start_timestamp = db_field_types.ComplexDateTimeField(default=date_utils.get_datetime_utc_now)
+ start_timestamp = db_field_types.ComplexDateTimeField(
+ default=date_utils.get_datetime_utc_now
+ )
end_timestamp = db_field_types.ComplexDateTimeField()
- meta = {
- 'indexes': [
- {'fields': ['action_execution']}
- ]
- }
+ meta = {"indexes": [{"fields": ["action_execution"]}]}
class TaskExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin):
@@ -71,21 +68,20 @@ class TaskExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionField
context = stormbase.EscapedDictField()
status = me.StringField(required=True)
result = stormbase.EscapedDictField()
- start_timestamp = db_field_types.ComplexDateTimeField(default=date_utils.get_datetime_utc_now)
+ start_timestamp = db_field_types.ComplexDateTimeField(
+ default=date_utils.get_datetime_utc_now
+ )
end_timestamp = db_field_types.ComplexDateTimeField()
meta = {
- 'indexes': [
- {'fields': ['workflow_execution']},
- {'fields': ['task_id']},
- {'fields': ['task_id', 'task_route']},
- {'fields': ['workflow_execution', 'task_id']},
- {'fields': ['workflow_execution', 'task_id', 'task_route']}
+ "indexes": [
+ {"fields": ["workflow_execution"]},
+ {"fields": ["task_id"]},
+ {"fields": ["task_id", "task_route"]},
+ {"fields": ["workflow_execution", "task_id"]},
+ {"fields": ["workflow_execution", "task_id", "task_route"]},
]
}
-MODELS = [
- WorkflowExecutionDB,
- TaskExecutionDB
-]
+MODELS = [WorkflowExecutionDB, TaskExecutionDB]
diff --git a/st2common/st2common/models/system/action.py b/st2common/st2common/models/system/action.py
index b5efe124f5..2afcbf649b 100644
--- a/st2common/st2common/models/system/action.py
+++ b/st2common/st2common/models/system/action.py
@@ -35,11 +35,11 @@
from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE
__all__ = [
- 'ShellCommandAction',
- 'ShellScriptAction',
- 'RemoteAction',
- 'RemoteScriptAction',
- 'ResolvedActionParameters'
+ "ShellCommandAction",
+ "ShellScriptAction",
+ "RemoteAction",
+ "RemoteScriptAction",
+ "ResolvedActionParameters",
]
LOG = logging.getLogger(__name__)
@@ -48,21 +48,31 @@
# Flags which are passed to every sudo invocation
SUDO_COMMON_OPTIONS = [
- '-E' # we want to preserve the environment of the user which ran sudo
-]
+ "-E"
+] # we want to preserve the environment of the user which ran sudo
# Flags which are only passed to sudo when not running as current user and when
# -u flag is used
SUDO_DIFFERENT_USER_OPTIONS = [
- '-H' # we want $HOME to reflect the home directory of the requested / target user
+ "-H" # we want $HOME to reflect the home directory of the requested / target user
]
class ShellCommandAction(object):
- EXPORT_CMD = 'export'
-
- def __init__(self, name, action_exec_id, command, user, env_vars=None, sudo=False,
- timeout=None, cwd=None, sudo_password=None):
+ EXPORT_CMD = "export"
+
+ def __init__(
+ self,
+ name,
+ action_exec_id,
+ command,
+ user,
+ env_vars=None,
+ sudo=False,
+ timeout=None,
+ cwd=None,
+ sudo_password=None,
+ ):
self.name = name
self.action_exec_id = action_exec_id
self.command = command
@@ -77,15 +87,15 @@ def get_full_command_string(self):
# Note: We pass -E to sudo because we want to preserve user provided environment variables
if self.sudo:
command = quote_unix(self.command)
- sudo_arguments = ' '.join(self._get_common_sudo_arguments())
- command = 'sudo %s -- bash -c %s' % (sudo_arguments, command)
+ sudo_arguments = " ".join(self._get_common_sudo_arguments())
+ command = "sudo %s -- bash -c %s" % (sudo_arguments, command)
else:
if self.user and self.user != LOGGED_USER_USERNAME:
# Need to use sudo to run as a different (requested) user
user = quote_unix(self.user)
- sudo_arguments = ' '.join(self._get_user_sudo_arguments(user=user))
+ sudo_arguments = " ".join(self._get_user_sudo_arguments(user=user))
command = quote_unix(self.command)
- command = 'sudo %s -- bash -c %s' % (sudo_arguments, command)
+ command = "sudo %s -- bash -c %s" % (sudo_arguments, command)
else:
command = self.command
@@ -103,7 +113,10 @@ def get_sanitized_full_command_string(self):
if self.sudo_password:
# Mask sudo password
- command_string = 'echo -e \'%s\n\' | %s' % (MASKED_ATTRIBUTE_VALUE, command_string)
+ command_string = "echo -e '%s\n' | %s" % (
+ MASKED_ATTRIBUTE_VALUE,
+ command_string,
+ )
return command_string
@@ -124,7 +137,7 @@ def _get_common_sudo_arguments(self):
if self.sudo_password:
# Note: We use subprocess.Popen in local runner so we provide password via subprocess
# stdin (using echo -e won't work when using subprocess.Popen)
- flags.append('-S')
+ flags.append("-S")
flags = flags + SUDO_COMMON_OPTIONS
@@ -139,7 +152,7 @@ def _get_user_sudo_arguments(self, user):
"""
flags = self._get_common_sudo_arguments()
flags += SUDO_DIFFERENT_USER_OPTIONS
- flags += ['-u', user]
+ flags += ["-u", user]
return flags
@@ -150,21 +163,21 @@ def _get_env_vars_export_string(self):
# If sudo_password is provided, explicitly disable bash history to make sure password
# is not logged, because password is provided via command line
if self.sudo and self.sudo_password:
- env_vars['HISTFILE'] = '/dev/null'
- env_vars['HISTSIZE'] = '0'
+ env_vars["HISTFILE"] = "/dev/null"
+ env_vars["HISTSIZE"] = "0"
# Sort the dict to guarantee consistent order
env_vars = collections.OrderedDict(sorted(env_vars.items()))
# Environment variables could contain spaces and open us to shell
# injection attacks. Always quote the key and the value.
- exports = ' '.join(
- '%s=%s' % (quote_unix(k), quote_unix(v))
+ exports = " ".join(
+ "%s=%s" % (quote_unix(k), quote_unix(v))
for k, v in six.iteritems(env_vars)
)
- shell_env_str = '%s %s' % (ShellCommandAction.EXPORT_CMD, exports)
+ shell_env_str = "%s %s" % (ShellCommandAction.EXPORT_CMD, exports)
else:
- shell_env_str = ''
+ shell_env_str = ""
return shell_env_str
@@ -180,8 +193,8 @@ def _get_command_string(self, cmd, args):
assert isinstance(args, (list, tuple))
args = [quote_unix(arg) for arg in args]
- args = ' '.join(args)
- result = '%s %s' % (cmd, args)
+ args = " ".join(args)
+ result = "%s %s" % (cmd, args)
return result
def _get_error_result(self):
@@ -195,24 +208,42 @@ def _get_error_result(self):
_, exc_value, exc_traceback = sys.exc_info()
exc_value = str(exc_value)
- exc_traceback = ''.join(traceback.format_tb(exc_traceback))
+ exc_traceback = "".join(traceback.format_tb(exc_traceback))
result = {}
- result['failed'] = True
- result['succeeded'] = False
- result['error'] = exc_value
- result['traceback'] = exc_traceback
+ result["failed"] = True
+ result["succeeded"] = False
+ result["error"] = exc_value
+ result["traceback"] = exc_traceback
return result
class ShellScriptAction(ShellCommandAction):
- def __init__(self, name, action_exec_id, script_local_path_abs, named_args=None,
- positional_args=None, env_vars=None, user=None, sudo=False, timeout=None,
- cwd=None, sudo_password=None):
- super(ShellScriptAction, self).__init__(name=name, action_exec_id=action_exec_id,
- command=None, user=user, env_vars=env_vars,
- sudo=sudo, timeout=timeout,
- cwd=cwd, sudo_password=sudo_password)
+ def __init__(
+ self,
+ name,
+ action_exec_id,
+ script_local_path_abs,
+ named_args=None,
+ positional_args=None,
+ env_vars=None,
+ user=None,
+ sudo=False,
+ timeout=None,
+ cwd=None,
+ sudo_password=None,
+ ):
+ super(ShellScriptAction, self).__init__(
+ name=name,
+ action_exec_id=action_exec_id,
+ command=None,
+ user=user,
+ env_vars=env_vars,
+ sudo=sudo,
+ timeout=timeout,
+ cwd=cwd,
+ sudo_password=sudo_password,
+ )
self.script_local_path_abs = script_local_path_abs
self.named_args = named_args
self.positional_args = positional_args
@@ -221,33 +252,38 @@ def get_full_command_string(self):
return self._format_command()
def _format_command(self):
- script_arguments = self._get_script_arguments(named_args=self.named_args,
- positional_args=self.positional_args)
+ script_arguments = self._get_script_arguments(
+ named_args=self.named_args, positional_args=self.positional_args
+ )
if self.sudo:
if script_arguments:
- command = quote_unix('%s %s' % (self.script_local_path_abs, script_arguments))
+ command = quote_unix(
+ "%s %s" % (self.script_local_path_abs, script_arguments)
+ )
else:
command = quote_unix(self.script_local_path_abs)
- sudo_arguments = ' '.join(self._get_common_sudo_arguments())
- command = 'sudo %s -- bash -c %s' % (sudo_arguments, command)
+ sudo_arguments = " ".join(self._get_common_sudo_arguments())
+ command = "sudo %s -- bash -c %s" % (sudo_arguments, command)
else:
if self.user and self.user != LOGGED_USER_USERNAME:
# Need to use sudo to run as a different user
user = quote_unix(self.user)
if script_arguments:
- command = quote_unix('%s %s' % (self.script_local_path_abs, script_arguments))
+ command = quote_unix(
+ "%s %s" % (self.script_local_path_abs, script_arguments)
+ )
else:
command = quote_unix(self.script_local_path_abs)
- sudo_arguments = ' '.join(self._get_user_sudo_arguments(user=user))
- command = 'sudo %s -- bash -c %s' % (sudo_arguments, command)
+ sudo_arguments = " ".join(self._get_user_sudo_arguments(user=user))
+ command = "sudo %s -- bash -c %s" % (sudo_arguments, command)
else:
script_path = quote_unix(self.script_local_path_abs)
if script_arguments:
- command = '%s %s' % (script_path, script_arguments)
+ command = "%s %s" % (script_path, script_arguments)
else:
command = script_path
return command
@@ -270,8 +306,10 @@ def _get_script_arguments(self, named_args=None, positional_args=None):
# add all named_args in the format name=value (e.g. --name=value)
if named_args is not None:
for (arg, value) in six.iteritems(named_args):
- if value is None or (isinstance(value, (str, six.text_type)) and len(value) < 1):
- LOG.debug('Ignoring arg %s as its value is %s.', arg, value)
+ if value is None or (
+ isinstance(value, (str, six.text_type)) and len(value) < 1
+ ):
+ LOG.debug("Ignoring arg %s as its value is %s.", arg, value)
continue
if isinstance(value, bool):
@@ -279,24 +317,45 @@ def _get_script_arguments(self, named_args=None, positional_args=None):
command_parts.append(arg)
else:
values = (quote_unix(arg), quote_unix(six.text_type(value)))
- command_parts.append(six.text_type('%s=%s' % values))
+ command_parts.append(six.text_type("%s=%s" % values))
# add the positional args
if positional_args:
quoted_pos_args = [quote_unix(pos_arg) for pos_arg in positional_args]
- pos_args_string = ' '.join(quoted_pos_args)
+ pos_args_string = " ".join(quoted_pos_args)
command_parts.append(pos_args_string)
- return ' '.join(command_parts)
+ return " ".join(command_parts)
class SSHCommandAction(ShellCommandAction):
- def __init__(self, name, action_exec_id, command, env_vars, user, password=None, pkey=None,
- hosts=None, parallel=True, sudo=False, timeout=None, cwd=None, passphrase=None,
- sudo_password=None):
- super(SSHCommandAction, self).__init__(name=name, action_exec_id=action_exec_id,
- command=command, env_vars=env_vars, user=user,
- sudo=sudo, timeout=timeout, cwd=cwd,
- sudo_password=sudo_password)
+ def __init__(
+ self,
+ name,
+ action_exec_id,
+ command,
+ env_vars,
+ user,
+ password=None,
+ pkey=None,
+ hosts=None,
+ parallel=True,
+ sudo=False,
+ timeout=None,
+ cwd=None,
+ passphrase=None,
+ sudo_password=None,
+ ):
+ super(SSHCommandAction, self).__init__(
+ name=name,
+ action_exec_id=action_exec_id,
+ command=command,
+ env_vars=env_vars,
+ user=user,
+ sudo=sudo,
+ timeout=timeout,
+ cwd=cwd,
+ sudo_password=sudo_password,
+ )
self.hosts = hosts
self.parallel = parallel
self.pkey = pkey
@@ -329,25 +388,51 @@ def get_command(self):
def __str__(self):
str_rep = []
- str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name))
- str_rep.append('id: %s' % self.action_exec_id)
- str_rep.append('command: %s' % self.command)
- str_rep.append('user: %s' % self.user)
- str_rep.append('sudo: %s' % str(self.sudo))
- str_rep.append('parallel: %s' % str(self.parallel))
- str_rep.append('hosts: %s)' % str(self.hosts))
- return ', '.join(str_rep)
+ str_rep.append(
+ "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name)
+ )
+ str_rep.append("id: %s" % self.action_exec_id)
+ str_rep.append("command: %s" % self.command)
+ str_rep.append("user: %s" % self.user)
+ str_rep.append("sudo: %s" % str(self.sudo))
+ str_rep.append("parallel: %s" % str(self.parallel))
+ str_rep.append("hosts: %s)" % str(self.hosts))
+ return ", ".join(str_rep)
class RemoteAction(SSHCommandAction):
- def __init__(self, name, action_exec_id, command, env_vars=None, on_behalf_user=None,
- user=None, password=None, private_key=None, hosts=None, parallel=True, sudo=False,
- timeout=None, cwd=None, passphrase=None, sudo_password=None):
- super(RemoteAction, self).__init__(name=name, action_exec_id=action_exec_id,
- command=command, env_vars=env_vars, user=user,
- hosts=hosts, parallel=parallel, sudo=sudo,
- timeout=timeout, cwd=cwd, passphrase=passphrase,
- sudo_password=sudo_password)
+ def __init__(
+ self,
+ name,
+ action_exec_id,
+ command,
+ env_vars=None,
+ on_behalf_user=None,
+ user=None,
+ password=None,
+ private_key=None,
+ hosts=None,
+ parallel=True,
+ sudo=False,
+ timeout=None,
+ cwd=None,
+ passphrase=None,
+ sudo_password=None,
+ ):
+ super(RemoteAction, self).__init__(
+ name=name,
+ action_exec_id=action_exec_id,
+ command=command,
+ env_vars=env_vars,
+ user=user,
+ hosts=hosts,
+ parallel=parallel,
+ sudo=sudo,
+ timeout=timeout,
+ cwd=cwd,
+ passphrase=passphrase,
+ sudo_password=sudo_password,
+ )
self.password = password
self.private_key = private_key
self.passphrase = passphrase
@@ -359,34 +444,61 @@ def get_on_behalf_user(self):
def __str__(self):
str_rep = []
- str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name))
- str_rep.append('id: %s' % self.action_exec_id)
- str_rep.append('command: %s' % self.command)
- str_rep.append('user: %s' % self.user)
- str_rep.append('on_behalf_user: %s' % self.on_behalf_user)
- str_rep.append('sudo: %s' % str(self.sudo))
- str_rep.append('parallel: %s' % str(self.parallel))
- str_rep.append('hosts: %s)' % str(self.hosts))
- str_rep.append('timeout: %s)' % str(self.timeout))
+ str_rep.append(
+ "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name)
+ )
+ str_rep.append("id: %s" % self.action_exec_id)
+ str_rep.append("command: %s" % self.command)
+ str_rep.append("user: %s" % self.user)
+ str_rep.append("on_behalf_user: %s" % self.on_behalf_user)
+ str_rep.append("sudo: %s" % str(self.sudo))
+ str_rep.append("parallel: %s" % str(self.parallel))
+ str_rep.append("hosts: %s)" % str(self.hosts))
+ str_rep.append("timeout: %s)" % str(self.timeout))
- return ', '.join(str_rep)
+ return ", ".join(str_rep)
class RemoteScriptAction(ShellScriptAction):
- def __init__(self, name, action_exec_id, script_local_path_abs, script_local_libs_path_abs,
- named_args=None, positional_args=None, env_vars=None, on_behalf_user=None,
- user=None, password=None, private_key=None, remote_dir=None, hosts=None,
- parallel=True, sudo=False, timeout=None, cwd=None, sudo_password=None):
- super(RemoteScriptAction, self).__init__(name=name, action_exec_id=action_exec_id,
- script_local_path_abs=script_local_path_abs,
- user=user,
- named_args=named_args,
- positional_args=positional_args, env_vars=env_vars,
- sudo=sudo, timeout=timeout, cwd=cwd,
- sudo_password=sudo_password)
+ def __init__(
+ self,
+ name,
+ action_exec_id,
+ script_local_path_abs,
+ script_local_libs_path_abs,
+ named_args=None,
+ positional_args=None,
+ env_vars=None,
+ on_behalf_user=None,
+ user=None,
+ password=None,
+ private_key=None,
+ remote_dir=None,
+ hosts=None,
+ parallel=True,
+ sudo=False,
+ timeout=None,
+ cwd=None,
+ sudo_password=None,
+ ):
+ super(RemoteScriptAction, self).__init__(
+ name=name,
+ action_exec_id=action_exec_id,
+ script_local_path_abs=script_local_path_abs,
+ user=user,
+ named_args=named_args,
+ positional_args=positional_args,
+ env_vars=env_vars,
+ sudo=sudo,
+ timeout=timeout,
+ cwd=cwd,
+ sudo_password=sudo_password,
+ )
self.script_local_libs_path_abs = script_local_libs_path_abs
- self.script_local_dir, self.script_name = os.path.split(self.script_local_path_abs)
- self.remote_dir = remote_dir if remote_dir is not None else '/tmp'
+ self.script_local_dir, self.script_name = os.path.split(
+ self.script_local_path_abs
+ )
+ self.remote_dir = remote_dir if remote_dir is not None else "/tmp"
self.remote_libs_path_abs = os.path.join(self.remote_dir, ACTION_LIBS_DIR)
self.on_behalf_user = on_behalf_user
self.password = password
@@ -395,7 +507,7 @@ def __init__(self, name, action_exec_id, script_local_path_abs, script_local_lib
self.hosts = hosts
self.parallel = parallel
self.command = self._format_command()
- LOG.debug('RemoteScriptAction: command to run on remote box: %s', self.command)
+ LOG.debug("RemoteScriptAction: command to run on remote box: %s", self.command)
def get_remote_script_abs_path(self):
return self.remote_script
@@ -413,11 +525,12 @@ def get_remote_base_dir(self):
return self.remote_dir
def _format_command(self):
- script_arguments = self._get_script_arguments(named_args=self.named_args,
- positional_args=self.positional_args)
+ script_arguments = self._get_script_arguments(
+ named_args=self.named_args, positional_args=self.positional_args
+ )
if script_arguments:
- command = '%s %s' % (self.remote_script, script_arguments)
+ command = "%s %s" % (self.remote_script, script_arguments)
else:
command = self.remote_script
@@ -425,21 +538,23 @@ def _format_command(self):
def __str__(self):
str_rep = []
- str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name))
- str_rep.append('id: %s' % self.action_exec_id)
- str_rep.append('local_script: %s' % self.script_local_path_abs)
- str_rep.append('local_libs: %s' % self.script_local_libs_path_abs)
- str_rep.append('remote_dir: %s' % self.remote_dir)
- str_rep.append('remote_libs: %s' % self.remote_libs_path_abs)
- str_rep.append('named_args: %s' % self.named_args)
- str_rep.append('positional_args: %s' % self.positional_args)
- str_rep.append('user: %s' % self.user)
- str_rep.append('on_behalf_user: %s' % self.on_behalf_user)
- str_rep.append('sudo: %s' % self.sudo)
- str_rep.append('parallel: %s' % self.parallel)
- str_rep.append('hosts: %s)' % self.hosts)
-
- return ', '.join(str_rep)
+ str_rep.append(
+ "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name)
+ )
+ str_rep.append("id: %s" % self.action_exec_id)
+ str_rep.append("local_script: %s" % self.script_local_path_abs)
+ str_rep.append("local_libs: %s" % self.script_local_libs_path_abs)
+ str_rep.append("remote_dir: %s" % self.remote_dir)
+ str_rep.append("remote_libs: %s" % self.remote_libs_path_abs)
+ str_rep.append("named_args: %s" % self.named_args)
+ str_rep.append("positional_args: %s" % self.positional_args)
+ str_rep.append("user: %s" % self.user)
+ str_rep.append("on_behalf_user: %s" % self.on_behalf_user)
+ str_rep.append("sudo: %s" % self.sudo)
+ str_rep.append("parallel: %s" % self.parallel)
+ str_rep.append("hosts: %s)" % self.hosts)
+
+ return ", ".join(str_rep)
class ResolvedActionParameters(DictSerializableClassMixin):
@@ -447,7 +562,9 @@ class ResolvedActionParameters(DictSerializableClassMixin):
Class which contains resolved runner and action parameters for a particular action.
"""
- def __init__(self, action_db, runner_type_db, runner_parameters=None, action_parameters=None):
+ def __init__(
+ self, action_db, runner_type_db, runner_parameters=None, action_parameters=None
+ ):
self._action_db = action_db
self._runner_type_db = runner_type_db
self._runner_parameters = runner_parameters
@@ -456,28 +573,34 @@ def __init__(self, action_db, runner_type_db, runner_parameters=None, action_par
def mask_secrets(self, value):
result = copy.deepcopy(value)
- runner_parameters = result['runner_parameters']
- action_parameters = result['action_parameters']
+ runner_parameters = result["runner_parameters"]
+ action_parameters = result["action_parameters"]
runner_parameters_specs = self._runner_type_db.runner_parameters
         action_parameters_specs = self._action_db.parameters
- secret_runner_parameters = get_secret_parameters(parameters=runner_parameters_specs)
- secret_action_parameters = get_secret_parameters(parameters=action_parameters_sepcs)
-
- runner_parameters = mask_secret_parameters(parameters=runner_parameters,
- secret_parameters=secret_runner_parameters)
- action_parameters = mask_secret_parameters(parameters=action_parameters,
- secret_parameters=secret_action_parameters)
- result['runner_parameters'] = runner_parameters
- result['action_parameters'] = action_parameters
+ secret_runner_parameters = get_secret_parameters(
+ parameters=runner_parameters_specs
+ )
+ secret_action_parameters = get_secret_parameters(
+            parameters=action_parameters_specs
+ )
+
+ runner_parameters = mask_secret_parameters(
+ parameters=runner_parameters, secret_parameters=secret_runner_parameters
+ )
+ action_parameters = mask_secret_parameters(
+ parameters=action_parameters, secret_parameters=secret_action_parameters
+ )
+ result["runner_parameters"] = runner_parameters
+ result["action_parameters"] = action_parameters
return result
def to_serializable_dict(self, mask_secrets=False):
result = {}
- result['runner_parameters'] = self._runner_parameters
- result['action_parameters'] = self._action_parameters
+ result["runner_parameters"] = self._runner_parameters
+ result["action_parameters"] = self._action_parameters
if mask_secrets and cfg.CONF.log.mask_secrets:
result = self.mask_secrets(value=result)
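
For readers following the reformatted get_full_command_string logic, this is the shape of the command the sudo branch assembles. A hedged sketch: shlex.quote stands in for st2's quote_unix helper, and the command text is made up.

    from shlex import quote  # stand-in for st2common's quote_unix

    SUDO_COMMON_OPTIONS = ["-E"]  # preserve the invoking user's environment

    def sudo_command(command, sudo_password=None):
        flags = ["-S"] if sudo_password else []  # -S: read the password from stdin
        flags = flags + SUDO_COMMON_OPTIONS
        # The entire user command is quoted as a single argument to "bash -c".
        return "sudo %s -- bash -c %s" % (" ".join(flags), quote(command))

    print(sudo_command("ls -la /root"))
    # sudo -E -- bash -c 'ls -la /root'
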
diff --git a/st2common/st2common/models/system/actionchain.py b/st2common/st2common/models/system/actionchain.py
index 2c5ce24c3d..24a84cc6b6 100644
--- a/st2common/st2common/models/system/actionchain.py
+++ b/st2common/st2common/models/system/actionchain.py
@@ -31,45 +31,45 @@ class Node(object):
"name": {
"description": "The name of this node.",
"type": "string",
- "required": True
+ "required": True,
},
"ref": {
"type": "string",
"description": "Ref of the action to be executed.",
- "required": True
+ "required": True,
},
"params": {
"type": "object",
- "description": ("Parameter for the execution (old name, here for backward "
- "compatibility reasons)."),
- "default": {}
+ "description": (
+ "Parameter for the execution (old name, here for backward "
+ "compatibility reasons)."
+ ),
+ "default": {},
},
"parameters": {
"type": "object",
"description": "Parameter for the execution.",
- "default": {}
+ "default": {},
},
"on-success": {
"type": "string",
"description": "Name of the node to invoke on successful completion of action"
- " executed for this node.",
- "default": ""
+ " executed for this node.",
+ "default": "",
},
"on-failure": {
"type": "string",
"description": "Name of the node to invoke on failure of action executed for this"
- " node.",
- "default": ""
+ " node.",
+ "default": "",
},
"publish": {
"description": "The variables to publish from the result. Should be of the form"
- " name.foo. o1: {{node_name.foo}} will result in creation of a"
- " variable o1 which is now available for reference through"
- " remainder of the chain as a global variable.",
+ " name.foo. o1: {{node_name.foo}} will result in creation of a"
+ " variable o1 which is now available for reference through"
+ " remainder of the chain as a global variable.",
"type": "object",
- "patternProperties": {
- r"^\w+$": {}
- }
+ "patternProperties": {r"^\w+$": {}},
},
"notify": {
"description": "Notification settings for action.",
@@ -77,43 +77,49 @@ class Node(object):
"properties": {
"on-complete": NotificationSubSchemaAPI,
"on-failure": NotificationSubSchemaAPI,
- "on-success": NotificationSubSchemaAPI
+ "on-success": NotificationSubSchemaAPI,
},
- "additionalProperties": False
- }
+ "additionalProperties": False,
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
def __init__(self, **kw):
- for prop in six.iterkeys(self.schema.get('properties', [])):
+ for prop in six.iterkeys(self.schema.get("properties", [])):
value = kw.get(prop, None)
             # having '-' in the property name leads to challenges in referencing the property.
             # In hindsight the schema property should've been on_success rather than on-success.
- prop = prop.replace('-', '_')
+ prop = prop.replace("-", "_")
setattr(self, prop, value)
def validate(self):
- params = getattr(self, 'params', {})
- parameters = getattr(self, 'parameters', {})
+ params = getattr(self, "params", {})
+ parameters = getattr(self, "parameters", {})
if params and parameters:
- msg = ('Either "params" or "parameters" attribute needs to be provided, but not '
- 'both')
+ msg = (
+ 'Either "params" or "parameters" attribute needs to be provided, but not '
+ "both"
+ )
raise ValueError(msg)
return self
def get_parameters(self):
# Note: "params" is old deprecated attribute which will be removed in a future release
- params = getattr(self, 'params', {})
- parameters = getattr(self, 'parameters', {})
+ params = getattr(self, "params", {})
+ parameters = getattr(self, "parameters", {})
return parameters or params
def __repr__(self):
-        return ('<Node name=%s, ref=%s, on_success=%s, on_failure=%s>' %
-                (self.name, self.ref, self.on_success, self.on_failure))
+        return "<Node name=%s, ref=%s, on_success=%s, on_failure=%s>" % (
+            self.name,
+            self.ref,
+            self.on_success,
+            self.on_failure,
+        )
class ActionChain(object):
@@ -127,31 +133,34 @@ class ActionChain(object):
"description": "The chain.",
"type": "array",
"items": [Node.schema],
- "required": True
+ "required": True,
},
"default": {
"type": "string",
- "description": "name of the action to be executed."
+ "description": "name of the action to be executed.",
},
"vars": {
"description": "",
"type": "object",
- "patternProperties": {
- r"^\w+$": {}
- }
- }
+ "patternProperties": {r"^\w+$": {}},
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
def __init__(self, **kw):
- util_schema.validate(instance=kw, schema=self.schema, cls=util_schema.CustomValidator,
- use_default=False, allow_default_none=True)
-
- for prop in six.iterkeys(self.schema.get('properties', [])):
+ util_schema.validate(
+ instance=kw,
+ schema=self.schema,
+ cls=util_schema.CustomValidator,
+ use_default=False,
+ allow_default_none=True,
+ )
+
+ for prop in six.iterkeys(self.schema.get("properties", [])):
value = kw.get(prop, None)
# special handling for chain property to create the Node object
- if prop == 'chain':
+ if prop == "chain":
nodes = []
for node in value:
ac_node = Node(**node)
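
The dash-to-underscore mapping in Node.__init__ is easy to miss in the reformatting noise; a minimal standalone sketch of it (the schema subset and values are invented for illustration):

    class NodeSketch:
        # A subset of the schema property names, including the dashed ones.
        PROPERTIES = ("name", "ref", "on-success", "on-failure")

        def __init__(self, **kw):
            for prop in self.PROPERTIES:
                value = kw.get(prop, None)
                # "on-success" is not a legal attribute name, so it becomes on_success.
                setattr(self, prop.replace("-", "_"), value)

    node = NodeSketch(name="deploy", ref="core.local", **{"on-success": "notify"})
    print(node.on_success)  # notify
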
diff --git a/st2common/st2common/models/system/common.py b/st2common/st2common/models/system/common.py
index a56f6701ac..72ad6c3f84 100644
--- a/st2common/st2common/models/system/common.py
+++ b/st2common/st2common/models/system/common.py
@@ -14,17 +14,17 @@
# limitations under the License.
__all__ = [
- 'InvalidReferenceError',
- 'InvalidResourceReferenceError',
- 'ResourceReference',
+ "InvalidReferenceError",
+ "InvalidResourceReferenceError",
+ "ResourceReference",
]
-PACK_SEPARATOR = '.'
+PACK_SEPARATOR = "."
class InvalidReferenceError(ValueError):
def __init__(self, ref):
- message = 'Invalid reference: %s' % (ref)
+ message = "Invalid reference: %s" % (ref)
self.ref = ref
self.message = message
super(InvalidReferenceError, self).__init__(message)
@@ -32,7 +32,7 @@ def __init__(self, ref):
class InvalidResourceReferenceError(ValueError):
def __init__(self, ref):
- message = 'Invalid resource reference: %s' % (ref)
+ message = "Invalid resource reference: %s" % (ref)
self.ref = ref
self.message = message
super(InvalidResourceReferenceError, self).__init__(message)
@@ -42,6 +42,7 @@ class ResourceReference(object):
"""
Class used for referring to resources which belong to a content pack.
"""
+
def __init__(self, pack=None, name=None):
self.pack = self.validate_pack_name(pack=pack)
self.name = name
@@ -72,8 +73,10 @@ def to_string_reference(pack=None, name=None):
pack = ResourceReference.validate_pack_name(pack=pack)
return PACK_SEPARATOR.join([pack, name])
else:
- raise ValueError('Both pack and name needed for building ref. pack=%s, name=%s' %
- (pack, name))
+ raise ValueError(
+ "Both pack and name needed for building ref. pack=%s, name=%s"
+ % (pack, name)
+ )
@staticmethod
def validate_pack_name(pack):
@@ -97,5 +100,8 @@ def get_name(ref):
raise InvalidResourceReferenceError(ref=ref)
def __repr__(self):
-        return ('<ResourceReference pack=%s, name=%s, ref=%s>' %
-                (self.pack, self.name, self.ref))
+        return "<ResourceReference pack=%s, name=%s, ref=%s>" % (
+            self.pack,
+            self.name,
+            self.ref,
+        )
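
The separator logic this module centralizes supports a simple round trip; splitting on the first "." is what lets resource names themselves contain dots. A sketch with illustrative values:

    PACK_SEPARATOR = "."

    def to_string_reference(pack, name):
        if not (pack and name):
            raise ValueError(
                "Both pack and name needed for building ref. pack=%s, name=%s"
                % (pack, name)
            )
        return PACK_SEPARATOR.join([pack, name])

    ref = to_string_reference("linux", "check.loadavg")  # "linux.check.loadavg"
    pack, name = ref.split(PACK_SEPARATOR, 1)            # ("linux", "check.loadavg")
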
diff --git a/st2common/st2common/models/system/keyvalue.py b/st2common/st2common/models/system/keyvalue.py
index 0bac5949d8..018df95602 100644
--- a/st2common/st2common/models/system/keyvalue.py
+++ b/st2common/st2common/models/system/keyvalue.py
@@ -17,13 +17,13 @@
from st2common.constants.keyvalue import USER_SEPARATOR
__all__ = [
- 'InvalidUserKeyReferenceError',
+ "InvalidUserKeyReferenceError",
]
class InvalidUserKeyReferenceError(ValueError):
def __init__(self, ref):
- message = 'Invalid resource reference: %s' % (ref)
+ message = "Invalid resource reference: %s" % (ref)
self.ref = ref
self.message = message
super(InvalidUserKeyReferenceError, self).__init__(message)
@@ -38,7 +38,7 @@ class UserKeyReference(object):
def __init__(self, user, name):
self._user = user
self._name = name
- self.ref = ('%s%s%s' % (self._user, USER_SEPARATOR, self._name))
+ self.ref = "%s%s%s" % (self._user, USER_SEPARATOR, self._name)
def __str__(self):
return self.ref
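
And the user-scoped equivalent: UserKeyReference joins user and key name with USER_SEPARATOR, assumed here to be the ":" defined in st2common's keyvalue constants (the constant is imported rather than defined in this file):

    USER_SEPARATOR = ":"  # assumed value of st2common.constants.keyvalue.USER_SEPARATOR

    user, name = "stanley", "api_token"
    ref = "%s%s%s" % (user, USER_SEPARATOR, name)
    print(ref)  # stanley:api_token
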
diff --git a/st2common/st2common/models/system/paramiko_command_action.py b/st2common/st2common/models/system/paramiko_command_action.py
index a96183ef9e..685ffeb67c 100644
--- a/st2common/st2common/models/system/paramiko_command_action.py
+++ b/st2common/st2common/models/system/paramiko_command_action.py
@@ -23,7 +23,7 @@
from st2common.util.shell import quote_unix
__all__ = [
- 'ParamikoRemoteCommandAction',
+ "ParamikoRemoteCommandAction",
]
LOG = logging.getLogger(__name__)
@@ -32,7 +32,6 @@
class ParamikoRemoteCommandAction(RemoteAction):
-
def get_full_command_string(self):
# Note: We pass -E to sudo because we want to preserve user provided environment variables
env_str = self._get_env_vars_export_string()
@@ -40,24 +39,25 @@ def get_full_command_string(self):
if self.sudo:
if env_str:
- command = quote_unix('%s && cd %s && %s' % (env_str, cwd, self.command))
+ command = quote_unix("%s && cd %s && %s" % (env_str, cwd, self.command))
else:
- command = quote_unix('cd %s && %s' % (cwd, self.command))
+ command = quote_unix("cd %s && %s" % (cwd, self.command))
- sudo_arguments = ' '.join(self._get_common_sudo_arguments())
- command = 'sudo %s -- bash -c %s' % (sudo_arguments, command)
+ sudo_arguments = " ".join(self._get_common_sudo_arguments())
+ command = "sudo %s -- bash -c %s" % (sudo_arguments, command)
if self.sudo_password:
- command = ('set +o history ; echo -e %s | %s' %
- (quote_unix('%s\n' % (self.sudo_password)), command))
+ command = "set +o history ; echo -e %s | %s" % (
+ quote_unix("%s\n" % (self.sudo_password)),
+ command,
+ )
else:
if env_str:
- command = '%s && cd %s && %s' % (env_str, cwd,
- self.command)
+ command = "%s && cd %s && %s" % (env_str, cwd, self.command)
else:
- command = 'cd %s && %s' % (cwd, self.command)
+ command = "cd %s && %s" % (cwd, self.command)
- LOG.debug('Command to run on remote host will be: %s', command)
+ LOG.debug("Command to run on remote host will be: %s", command)
return command
def _get_common_sudo_arguments(self):
@@ -69,7 +69,7 @@ def _get_common_sudo_arguments(self):
flags = []
if self.sudo_password:
- flags.append('-S')
+ flags.append("-S")
flags = flags + SUDO_COMMON_OPTIONS
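
Putting the branches of get_full_command_string together, this is roughly the remote command produced for the sudo-with-password case. A hedged sketch: shlex.quote stands in for quote_unix, and the env string, cwd, command and password are illustrative.

    from shlex import quote  # stand-in for st2common's quote_unix

    env_str = "export FOO=bar"
    cwd = "/tmp"
    user_command = "whoami"
    sudo_password = "secret"

    # Inner command: exports, cd into the working directory, then the command.
    inner = quote("%s && cd %s && %s" % (env_str, cwd, user_command))
    command = "sudo %s -- bash -c %s" % (" ".join(["-S", "-E"]), inner)
    # Disable shell history, then feed the password to sudo -S via stdin.
    command = "set +o history ; echo -e %s | %s" % (
        quote("%s\n" % sudo_password),
        command,
    )
    print(command)
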
diff --git a/st2common/st2common/models/system/paramiko_script_action.py b/st2common/st2common/models/system/paramiko_script_action.py
index a6ff26a751..284e87a708 100644
--- a/st2common/st2common/models/system/paramiko_script_action.py
+++ b/st2common/st2common/models/system/paramiko_script_action.py
@@ -20,7 +20,7 @@
from st2common.util.shell import quote_unix
__all__ = [
- 'ParamikoRemoteScriptAction',
+ "ParamikoRemoteScriptAction",
]
@@ -28,10 +28,10 @@
class ParamikoRemoteScriptAction(RemoteScriptAction):
-
def _format_command(self):
- script_arguments = self._get_script_arguments(named_args=self.named_args,
- positional_args=self.positional_args)
+ script_arguments = self._get_script_arguments(
+ named_args=self.named_args, positional_args=self.positional_args
+ )
env_str = self._get_env_vars_export_string()
cwd = quote_unix(self.get_cwd())
script_path = quote_unix(self.remote_script)
@@ -39,36 +39,46 @@ def _format_command(self):
if self.sudo:
if script_arguments:
if env_str:
- command = quote_unix('%s && cd %s && %s %s' % (
- env_str, cwd, script_path, script_arguments))
+ command = quote_unix(
+ "%s && cd %s && %s %s"
+ % (env_str, cwd, script_path, script_arguments)
+ )
else:
- command = quote_unix('cd %s && %s %s' % (
- cwd, script_path, script_arguments))
+ command = quote_unix(
+ "cd %s && %s %s" % (cwd, script_path, script_arguments)
+ )
else:
if env_str:
- command = quote_unix('%s && cd %s && %s' % (
- env_str, cwd, script_path))
+ command = quote_unix(
+ "%s && cd %s && %s" % (env_str, cwd, script_path)
+ )
else:
- command = quote_unix('cd %s && %s' % (cwd, script_path))
+ command = quote_unix("cd %s && %s" % (cwd, script_path))
- sudo_arguments = ' '.join(self._get_common_sudo_arguments())
- command = 'sudo %s -- bash -c %s' % (sudo_arguments, command)
+ sudo_arguments = " ".join(self._get_common_sudo_arguments())
+ command = "sudo %s -- bash -c %s" % (sudo_arguments, command)
if self.sudo_password:
- command = ('set +o history ; echo -e %s | %s' %
- (quote_unix('%s\n' % (self.sudo_password)), command))
+ command = "set +o history ; echo -e %s | %s" % (
+ quote_unix("%s\n" % (self.sudo_password)),
+ command,
+ )
else:
if script_arguments:
if env_str:
- command = '%s && cd %s && %s %s' % (env_str, cwd,
- script_path, script_arguments)
+ command = "%s && cd %s && %s %s" % (
+ env_str,
+ cwd,
+ script_path,
+ script_arguments,
+ )
else:
- command = 'cd %s && %s %s' % (cwd, script_path, script_arguments)
+ command = "cd %s && %s %s" % (cwd, script_path, script_arguments)
else:
if env_str:
- command = '%s && cd %s && %s' % (env_str, cwd, script_path)
+ command = "%s && cd %s && %s" % (env_str, cwd, script_path)
else:
- command = 'cd %s && %s' % (cwd, script_path)
+ command = "cd %s && %s" % (cwd, script_path)
return command
@@ -81,7 +91,7 @@ def _get_common_sudo_arguments(self):
flags = []
if self.sudo_password:
- flags.append('-S')
+ flags.append("-S")
flags = flags + SUDO_COMMON_OPTIONS
diff --git a/st2common/st2common/models/utils/action_alias_utils.py b/st2common/st2common/models/utils/action_alias_utils.py
index 06106a2794..bf6d47c8b4 100644
--- a/st2common/st2common/models/utils/action_alias_utils.py
+++ b/st2common/st2common/models/utils/action_alias_utils.py
@@ -18,9 +18,15 @@
import re
import sys
-from sre_parse import ( # pylint: disable=E0611
- parse, AT, AT_BEGINNING, AT_BEGINNING_STRING,
- AT_END, AT_END_STRING, BRANCH, SUBPATTERN,
+from sre_parse import ( # pylint: disable=E0611
+ parse,
+ AT,
+ AT_BEGINNING,
+ AT_BEGINNING_STRING,
+ AT_END,
+ AT_END_STRING,
+ BRANCH,
+ SUBPATTERN,
)
from st2common.util.jinja import render_values
@@ -30,11 +36,10 @@
from st2common import log
__all__ = [
- 'ActionAliasFormatParser',
-
- 'extract_parameters_for_action_alias_db',
- 'extract_parameters',
- 'search_regex_tokens',
+ "ActionAliasFormatParser",
+ "extract_parameters_for_action_alias_db",
+ "extract_parameters",
+ "search_regex_tokens",
]
@@ -48,10 +53,9 @@
class ActionAliasFormatParser(object):
-
def __init__(self, alias_format=None, param_stream=None):
- self._format = alias_format or ''
- self._original_param_stream = param_stream or ''
+ self._format = alias_format or ""
+ self._original_param_stream = param_stream or ""
self._param_stream = self._original_param_stream
self._snippets = self.generate_snippets()
@@ -76,26 +80,26 @@ def generate_snippets(self):
# Formats for keys and values: key is a non-spaced string,
# value is anything in quotes or curly braces, or a single word.
- snippets['key'] = r'\s*(\S+?)\s*'
- snippets['value'] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(\S+)'
+ snippets["key"] = r"\s*(\S+?)\s*"
+ snippets["value"] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(\S+)'
# Extended value: also matches unquoted text (caution).
- snippets['ext_value'] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(.+?)'
+ snippets["ext_value"] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(.+?)'
# Key-value pair:
- snippets['pairs'] = r'(?:^|\s+){key}=({value})'.format(**snippets)
+ snippets["pairs"] = r"(?:^|\s+){key}=({value})".format(**snippets)
# End of string: multiple space-separated key-value pairs:
- snippets['ending'] = r'.*?(({pairs}\s*)*)$'.format(**snippets)
+ snippets["ending"] = r".*?(({pairs}\s*)*)$".format(**snippets)
# Default value in optional parameters:
- snippets['default'] = r'\s*=\s*(?:{ext_value})\s*'.format(**snippets)
+ snippets["default"] = r"\s*=\s*(?:{ext_value})\s*".format(**snippets)
# Optional parameter (has a default value):
- snippets['optional'] = '{{' + snippets['key'] + snippets['default'] + '}}'
+ snippets["optional"] = "{{" + snippets["key"] + snippets["default"] + "}}"
# Required parameter (no default value):
- snippets['required'] = '{{' + snippets['key'] + '}}'
+ snippets["required"] = "{{" + snippets["key"] + "}}"
return snippets
@@ -105,11 +109,13 @@ def match_kv_pairs_at_end(self):
# 1. Matching the arbitrary key-value pairs at the end of the command
# to support extra parameters (not specified in the format string),
# and cutting them from the command string afterwards.
- ending_pairs = re.match(self._snippets['ending'], param_stream, re.DOTALL)
+ ending_pairs = re.match(self._snippets["ending"], param_stream, re.DOTALL)
has_ending_pairs = ending_pairs and ending_pairs.group(1)
if has_ending_pairs:
- kv_pairs = re.findall(self._snippets['pairs'], ending_pairs.group(1), re.DOTALL)
- param_stream = param_stream.replace(ending_pairs.group(1), '')
+ kv_pairs = re.findall(
+ self._snippets["pairs"], ending_pairs.group(1), re.DOTALL
+ )
+ param_stream = param_stream.replace(ending_pairs.group(1), "")
else:
kv_pairs = []
param_stream = " %s " % (param_stream)
@@ -118,27 +124,36 @@ def match_kv_pairs_at_end(self):
def generate_optional_params_regex(self):
# 2. Matching optional parameters (with default values).
- return re.findall(self._snippets['optional'], self._format, re.DOTALL)
+ return re.findall(self._snippets["optional"], self._format, re.DOTALL)
def transform_format_string_into_regex(self):
# 3. Convert the mangled format string into a regex object
# Transforming our format string into a regular expression,
# substituting {{ ... }} with regex named groups, so that param_stream
# matched against this expression yields a dict of params with values.
- param_match = r'\1["\']?(?P<\2>(?:(?<=\').+?(?=\')|(?<=").+?(?=")|{.+?}|.+?))["\']?'
- reg = re.sub(r'(\s*)' + self._snippets['optional'], r'(?:' + param_match + r')?',
- self._format)
- reg = re.sub(r'(\s*)' + self._snippets['required'], param_match, reg)
+ param_match = (
+ r'\1["\']?(?P<\2>(?:(?<=\').+?(?=\')|(?<=").+?(?=")|{.+?}|.+?))["\']?'
+ )
+ reg = re.sub(
+ r"(\s*)" + self._snippets["optional"],
+ r"(?:" + param_match + r")?",
+ self._format,
+ )
+ reg = re.sub(r"(\s*)" + self._snippets["required"], param_match, reg)
reg_tokens = parse(reg, flags=re.DOTALL)
# Add a beginning anchor if none exists
- if not search_regex_tokens(((AT, AT_BEGINNING), (AT, AT_BEGINNING_STRING)), reg_tokens):
- reg = r'^\s*' + reg
+ if not search_regex_tokens(
+ ((AT, AT_BEGINNING), (AT, AT_BEGINNING_STRING)), reg_tokens
+ ):
+ reg = r"^\s*" + reg
# Add an ending anchor if none exists
- if not search_regex_tokens(((AT, AT_END), (AT, AT_END_STRING)), reg_tokens, backwards=True):
- reg = reg + r'\s*$'
+ if not search_regex_tokens(
+ ((AT, AT_END), (AT, AT_END_STRING)), reg_tokens, backwards=True
+ ):
+ reg = reg + r"\s*$"
return re.compile(reg, re.DOTALL)
@@ -147,8 +162,10 @@ def match_params_in_stream(self, matched_stream):
if not matched_stream:
             # If no match is found we throw since this indicates the provided user string
             # (command) didn't match the provided format string
- raise ParseException('Command "%s" doesn\'t match format string "%s"' %
- (self._original_param_stream, self._format))
+ raise ParseException(
+ 'Command "%s" doesn\'t match format string "%s"'
+ % (self._original_param_stream, self._format)
+ )
# Compiling results from the steps 1-3.
if matched_stream:
@@ -157,16 +174,16 @@ def match_params_in_stream(self, matched_stream):
# Apply optional parameters/add the default parameters
for param in self._optional:
matched_value = result[param[0]] if matched_stream else None
- matched_result = matched_value or ''.join(param[1:])
+ matched_result = matched_value or "".join(param[1:])
if matched_result is not None:
result[param[0]] = matched_result
# Apply given parameters
for pair in self._kv_pairs:
- result[pair[0]] = ''.join(pair[2:])
+ result[pair[0]] = "".join(pair[2:])
if self._format and not (self._param_stream.strip() or any(result.values())):
- raise ParseException('No value supplied and no default value found.')
+ raise ParseException("No value supplied and no default value found.")
return result
@@ -196,8 +213,9 @@ def get_multiple_extracted_param_value(self):
return results
-def extract_parameters_for_action_alias_db(action_alias_db, format_str, param_stream,
- match_multiple=False):
+def extract_parameters_for_action_alias_db(
+ action_alias_db, format_str, param_stream, match_multiple=False
+):
"""
Extract parameters from the user input based on the provided format string.
@@ -208,13 +226,14 @@ def extract_parameters_for_action_alias_db(action_alias_db, format_str, param_st
formats = action_alias_db.get_format_strings()
if format_str not in formats:
- raise ValueError('Format string "%s" is not available on the alias "%s"' %
- (format_str, action_alias_db.name))
+ raise ValueError(
+ 'Format string "%s" is not available on the alias "%s"'
+ % (format_str, action_alias_db.name)
+ )
result = extract_parameters(
- format_str=format_str,
- param_stream=param_stream,
- match_multiple=match_multiple)
+ format_str=format_str, param_stream=param_stream, match_multiple=match_multiple
+ )
return result
@@ -226,7 +245,9 @@ def extract_parameters(format_str, param_stream, match_multiple=False):
return parser.get_extracted_param_value()
-def inject_immutable_parameters(action_alias_db, multiple_execution_parameters, action_context):
+def inject_immutable_parameters(
+ action_alias_db, multiple_execution_parameters, action_context
+):
"""
Inject immutable parameters from the alias definition into the execution parameters.
Jinja expressions will be resolved.
@@ -235,26 +256,34 @@ def inject_immutable_parameters(action_alias_db, multiple_execution_parameters,
if not immutable_parameters:
return multiple_execution_parameters
- user = action_context.get('user', None)
+ user = action_context.get("user", None)
context = {}
- context.update({
- kv_constants.DATASTORE_PARENT_SCOPE: {
- kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
- scope=kv_constants.FULL_SYSTEM_SCOPE),
- kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup(
- scope=kv_constants.FULL_USER_SCOPE, user=user)
+ context.update(
+ {
+ kv_constants.DATASTORE_PARENT_SCOPE: {
+ kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup(
+ scope=kv_constants.FULL_SYSTEM_SCOPE
+ ),
+ kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup(
+ scope=kv_constants.FULL_USER_SCOPE, user=user
+ ),
+ }
}
- })
+ )
context.update(action_context)
rendered_params = render_values(immutable_parameters, context)
for exec_params in multiple_execution_parameters:
- overriden = [param for param in immutable_parameters.keys() if param in exec_params]
+ overriden = [
+ param for param in immutable_parameters.keys() if param in exec_params
+ ]
if overriden:
raise ValueError(
"Immutable arguments cannot be overriden: {}".format(
- ','.join(overriden)))
+ ",".join(overriden)
+ )
+ )
exec_params.update(rendered_params)
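
A small standalone sketch of the override check above, with hypothetical alias and
execution parameters:

    immutable_parameters = {"channel": "#ops"}
    exec_params = {"channel": "#dev", "message": "deploy done"}
    overriden = [p for p in immutable_parameters if p in exec_params]
    print(overriden)  # ['channel'] -- would raise ValueError above
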
diff --git a/st2common/st2common/models/utils/action_param_utils.py b/st2common/st2common/models/utils/action_param_utils.py
index 1ecf6dbbe8..3edbeae6ed 100644
--- a/st2common/st2common/models/utils/action_param_utils.py
+++ b/st2common/st2common/models/utils/action_param_utils.py
@@ -33,7 +33,7 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None):
merged_meta = {}
# ?? Runner immutable param's meta shouldn't be allowed to be modified by the action whatsoever.
- if runner_meta and runner_meta.get('immutable', False):
+ if runner_meta and runner_meta.get("immutable", False):
merged_meta = runner_meta
for key in all_keys:
@@ -42,8 +42,10 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None):
elif key in runner_meta_keys and key not in action_meta_keys:
merged_meta[key] = runner_meta[key]
else:
- if key in ['immutable']:
- merged_meta[key] = runner_meta.get(key, False) or action_meta.get(key, False)
+ if key in ["immutable"]:
+ merged_meta[key] = runner_meta.get(key, False) or action_meta.get(
+ key, False
+ )
else:
merged_meta[key] = action_meta.get(key)
return merged_meta
@@ -51,12 +53,12 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None):
def get_params_view(action_db=None, runner_db=None, merged_only=False):
if runner_db:
- runner_params = fast_deepcopy(getattr(runner_db, 'runner_parameters', {})) or {}
+ runner_params = fast_deepcopy(getattr(runner_db, "runner_parameters", {})) or {}
else:
runner_params = {}
if action_db:
- action_params = fast_deepcopy(getattr(action_db, 'parameters', {})) or {}
+ action_params = fast_deepcopy(getattr(action_db, "parameters", {})) or {}
else:
action_params = {}
@@ -64,19 +66,22 @@ def get_params_view(action_db=None, runner_db=None, merged_only=False):
merged_params = {}
for param in parameters:
- merged_params[param] = _merge_param_meta_values(action_meta=action_params.get(param),
- runner_meta=runner_params.get(param))
+ merged_params[param] = _merge_param_meta_values(
+ action_meta=action_params.get(param), runner_meta=runner_params.get(param)
+ )
if merged_only:
return merged_params
def is_required(param_meta):
- return param_meta.get('required', False)
+ return param_meta.get("required", False)
def is_immutable(param_meta):
- return param_meta.get('immutable', False)
+ return param_meta.get("immutable", False)
- immutable = {param for param in parameters if is_immutable(merged_params.get(param))}
+ immutable = {
+ param for param in parameters if is_immutable(merged_params.get(param))
+ }
required = {param for param in parameters if is_required(merged_params.get(param))}
required = required - immutable
optional = parameters - required - immutable
@@ -89,8 +94,7 @@ def is_immutable(param_meta):
def cast_params(action_ref, params, cast_overrides=None):
- """
- """
+ """"""
params = params or {}
action_db = action_db_util.get_action_by_ref(action_ref)
@@ -98,7 +102,7 @@ def cast_params(action_ref, params, cast_overrides=None):
raise ValueError('Action with ref "%s" doesn\'t exist' % (action_ref))
action_parameters_schema = action_db.parameters
- runnertype_db = action_db_util.get_runnertype_by_name(action_db.runner_type['name'])
+ runnertype_db = action_db_util.get_runnertype_by_name(action_db.runner_type["name"])
runner_parameters_schema = runnertype_db.runner_parameters
# combine into 1 list of parameter schemas
parameters_schema = {}
@@ -110,29 +114,37 @@ def cast_params(action_ref, params, cast_overrides=None):
for k, v in six.iteritems(params):
parameter_schema = parameters_schema.get(k, None)
if not parameter_schema:
- LOG.debug('Will skip cast of param[name: %s, value: %s]. No schema.', k, v)
+ LOG.debug("Will skip cast of param[name: %s, value: %s]. No schema.", k, v)
continue
- parameter_type = parameter_schema.get('type', None)
+ parameter_type = parameter_schema.get("type", None)
if not parameter_type:
- LOG.debug('Will skip cast of param[name: %s, value: %s]. No type.', k, v)
+ LOG.debug("Will skip cast of param[name: %s, value: %s]. No type.", k, v)
continue
# Pick up the cast from the override and then from the system supplied ones.
cast = cast_overrides.get(parameter_type, None) if cast_overrides else None
if not cast:
cast = get_cast(cast_type=parameter_type)
if not cast:
- LOG.debug('Will skip cast of param[name: %s, value: %s]. No cast for %s.', k, v,
- parameter_type)
+ LOG.debug(
+ "Will skip cast of param[name: %s, value: %s]. No cast for %s.",
+ k,
+ v,
+ parameter_type,
+ )
continue
- LOG.debug('Casting param: %s of type %s to type: %s', v, type(v), parameter_type)
+ LOG.debug(
+ "Casting param: %s of type %s to type: %s", v, type(v), parameter_type
+ )
try:
params[k] = cast(v)
except Exception as e:
v_type = type(v).__name__
- msg = ('Failed to cast value "%s" (type: %s) for parameter "%s" of type "%s": %s. '
- 'Perhaps the value is of an invalid type?' %
- (v, v_type, k, parameter_type, six.text_type(e)))
+ msg = (
+ 'Failed to cast value "%s" (type: %s) for parameter "%s" of type "%s": %s. '
+ "Perhaps the value is of an invalid type?"
+ % (v, v_type, k, parameter_type, six.text_type(e))
+ )
raise ValueError(msg)
return params
@@ -145,8 +157,13 @@ def validate_action_parameters(action_ref, inputs):
parameters = action_db_util.get_action_parameters_specs(action_ref)
# Check required parameters that have no default defined.
- required = set([param for param, meta in six.iteritems(parameters)
- if meta.get('required', False) and 'default' not in meta])
+ required = set(
+ [
+ param
+ for param, meta in six.iteritems(parameters)
+ if meta.get("required", False) and "default" not in meta
+ ]
+ )
requires = sorted(required.difference(input_set))
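
A simplified standalone sketch of the casting step above; the CASTS dict stands in
for the real get_cast registry, and the schema and params are hypothetical:

    CASTS = {"integer": int, "boolean": lambda v: str(v).lower() == "true"}
    schema = {"count": {"type": "integer"}, "verbose": {"type": "boolean"}}
    params = {"count": "3", "verbose": "true"}
    for k, v in params.items():
        cast = CASTS.get(schema.get(k, {}).get("type"))
        if cast:
            params[k] = cast(v)  # cast each value per its schema type
    print(params)  # {'count': 3, 'verbose': True}
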
diff --git a/st2common/st2common/models/utils/profiling.py b/st2common/st2common/models/utils/profiling.py
index c9d26636b0..47add2adc3 100644
--- a/st2common/st2common/models/utils/profiling.py
+++ b/st2common/st2common/models/utils/profiling.py
@@ -23,10 +23,10 @@
from st2common import log as logging
__all__ = [
- 'enable_profiling',
- 'disable_profiling',
- 'is_enabled',
- 'log_query_and_profile_data_for_queryset'
+ "enable_profiling",
+ "disable_profiling",
+ "is_enabled",
+ "log_query_and_profile_data_for_queryset",
]
LOG = logging.getLogger(__name__)
@@ -72,13 +72,13 @@ def log_query_and_profile_data_for_queryset(queryset):
# Note: Some mongoengine methods don't return queryset (e.g. count)
return queryset
- query = getattr(queryset, '_query', None)
- mongo_query = getattr(queryset, '_mongo_query', query)
- ordering = getattr(queryset, '_ordering', None)
- limit = getattr(queryset, '_limit', None)
- collection = getattr(queryset, '_collection', None)
- collection_name = getattr(collection, 'name', None)
- only_fields = getattr(queryset, 'only_fields', None)
+ query = getattr(queryset, "_query", None)
+ mongo_query = getattr(queryset, "_mongo_query", query)
+ ordering = getattr(queryset, "_ordering", None)
+ limit = getattr(queryset, "_limit", None)
+ collection = getattr(queryset, "_collection", None)
+ collection_name = getattr(collection, "name", None)
+ only_fields = getattr(queryset, "only_fields", None)
# Note: We need to clone the queryset when using explain because explain advances the cursor
# internally which changes the function result
@@ -86,42 +86,46 @@ def log_query_and_profile_data_for_queryset(queryset):
explain_info = cloned_queryset.explain(format=True)
if mongo_query is not None and collection_name is not None:
- mongo_shell_query = construct_mongo_shell_query(mongo_query=mongo_query,
- collection_name=collection_name,
- ordering=ordering,
- limit=limit,
- only_fields=only_fields)
- extra = {'mongo_query': mongo_query, 'mongo_shell_query': mongo_shell_query}
- LOG.debug('MongoDB query: %s' % (mongo_shell_query), extra=extra)
- LOG.debug('MongoDB explain data: %s' % (explain_info))
+ mongo_shell_query = construct_mongo_shell_query(
+ mongo_query=mongo_query,
+ collection_name=collection_name,
+ ordering=ordering,
+ limit=limit,
+ only_fields=only_fields,
+ )
+ extra = {"mongo_query": mongo_query, "mongo_shell_query": mongo_shell_query}
+ LOG.debug("MongoDB query: %s" % (mongo_shell_query), extra=extra)
+ LOG.debug("MongoDB explain data: %s" % (explain_info))
return queryset
-def construct_mongo_shell_query(mongo_query, collection_name, ordering, limit,
- only_fields=None):
+def construct_mongo_shell_query(
+ mongo_query, collection_name, ordering, limit, only_fields=None
+):
result = []
# Select collection
- part = 'db.{collection}'.format(collection=collection_name)
+ part = "db.{collection}".format(collection=collection_name)
result.append(part)
# Include filters (if any)
if mongo_query:
filter_predicate = mongo_query
else:
- filter_predicate = ''
+ filter_predicate = ""
- part = 'find({filter_predicate})'.format(filter_predicate=filter_predicate)
+ part = "find({filter_predicate})".format(filter_predicate=filter_predicate)
# Include only fields (projection)
if only_fields:
- projection_items = ['\'%s\': 1' % (field) for field in only_fields]
- projection = ', '.join(projection_items)
- part = 'find({filter_predicate}, {{{projection}}})'.format(
- filter_predicate=filter_predicate, projection=projection)
+ projection_items = ["'%s': 1" % (field) for field in only_fields]
+ projection = ", ".join(projection_items)
+ part = "find({filter_predicate}, {{{projection}}})".format(
+ filter_predicate=filter_predicate, projection=projection
+ )
else:
- part = 'find({filter_predicate})'.format(filter_predicate=filter_predicate)
+ part = "find({filter_predicate})".format(filter_predicate=filter_predicate)
result.append(part)
@@ -129,17 +133,18 @@ def construct_mongo_shell_query(mongo_query, collection_name, ordering, limit,
if ordering:
sort_predicate = []
for field_name, direction in ordering:
- sort_predicate.append('{name}: {direction}'.format(name=field_name,
- direction=direction))
+ sort_predicate.append(
+ "{name}: {direction}".format(name=field_name, direction=direction)
+ )
- sort_predicate = ', '.join(sort_predicate)
- part = 'sort({{{sort_predicate}}})'.format(sort_predicate=sort_predicate)
+ sort_predicate = ", ".join(sort_predicate)
+ part = "sort({{{sort_predicate}}})".format(sort_predicate=sort_predicate)
result.append(part)
# Include limit info (if any)
if limit is not None:
- part = 'limit({limit})'.format(limit=limit)
+ part = "limit({limit})".format(limit=limit)
result.append(part)
- result = '.'.join(result) + ';'
+ result = ".".join(result) + ";"
return result
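
For a hypothetical queryset with mongo_query {'status': 'running'}, ordering
[('start_timestamp', -1)] and limit 10, the helper above produces a shell-style
string such as:

    db.action_execution_d_b.find({'status': 'running'}).sort({start_timestamp: -1}).limit(10);
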
diff --git a/st2common/st2common/models/utils/sensor_type_utils.py b/st2common/st2common/models/utils/sensor_type_utils.py
index f67a65e530..cd4b068db9 100644
--- a/st2common/st2common/models/utils/sensor_type_utils.py
+++ b/st2common/st2common/models/utils/sensor_type_utils.py
@@ -21,11 +21,7 @@
from st2common.models.db.sensor import SensorTypeDB
from st2common.services import triggers as trigger_service
-__all__ = [
- 'to_sensor_db_model',
- 'get_sensor_entry_point',
- 'create_trigger_types'
-]
+__all__ = ["to_sensor_db_model", "get_sensor_entry_point", "create_trigger_types"]
def to_sensor_db_model(sensor_api_model=None):
@@ -38,37 +34,40 @@ def to_sensor_db_model(sensor_api_model=None):
:rtype: :class:`SensorTypeDB`
"""
- class_name = getattr(sensor_api_model, 'class_name', None)
- pack = getattr(sensor_api_model, 'pack', None)
+ class_name = getattr(sensor_api_model, "class_name", None)
+ pack = getattr(sensor_api_model, "pack", None)
entry_point = get_sensor_entry_point(sensor_api_model)
- artifact_uri = getattr(sensor_api_model, 'artifact_uri', None)
- description = getattr(sensor_api_model, 'description', None)
- trigger_types = getattr(sensor_api_model, 'trigger_types', [])
- poll_interval = getattr(sensor_api_model, 'poll_interval', None)
- enabled = getattr(sensor_api_model, 'enabled', True)
- metadata_file = getattr(sensor_api_model, 'metadata_file', None)
-
- poll_interval = getattr(sensor_api_model, 'poll_interval', None)
+ artifact_uri = getattr(sensor_api_model, "artifact_uri", None)
+ description = getattr(sensor_api_model, "description", None)
+ trigger_types = getattr(sensor_api_model, "trigger_types", [])
+ poll_interval = getattr(sensor_api_model, "poll_interval", None)
+ enabled = getattr(sensor_api_model, "enabled", True)
+ metadata_file = getattr(sensor_api_model, "metadata_file", None)
+
+ poll_interval = getattr(sensor_api_model, "poll_interval", None)
if poll_interval and (poll_interval < MINIMUM_POLL_INTERVAL):
- raise ValueError('Minimum possible poll_interval is %s seconds' %
- (MINIMUM_POLL_INTERVAL))
+ raise ValueError(
+ "Minimum possible poll_interval is %s seconds" % (MINIMUM_POLL_INTERVAL)
+ )
# Add pack and metadata file to each trigger type item
for trigger_type in trigger_types:
- trigger_type['pack'] = pack
- trigger_type['metadata_file'] = metadata_file
+ trigger_type["pack"] = pack
+ trigger_type["metadata_file"] = metadata_file
trigger_type_refs = create_trigger_types(trigger_types)
- return _create_sensor_type(pack=pack,
- name=class_name,
- description=description,
- artifact_uri=artifact_uri,
- entry_point=entry_point,
- trigger_types=trigger_type_refs,
- poll_interval=poll_interval,
- enabled=enabled,
- metadata_file=metadata_file)
+ return _create_sensor_type(
+ pack=pack,
+ name=class_name,
+ description=description,
+ artifact_uri=artifact_uri,
+ entry_point=entry_point,
+ trigger_types=trigger_type_refs,
+ poll_interval=poll_interval,
+ enabled=enabled,
+ metadata_file=metadata_file,
+ )
def create_trigger_types(trigger_types, metadata_file=None):
@@ -87,29 +86,44 @@ def create_trigger_types(trigger_types, metadata_file=None):
return trigger_type_refs
-def _create_sensor_type(pack=None, name=None, description=None, artifact_uri=None,
- entry_point=None, trigger_types=None, poll_interval=10,
- enabled=True, metadata_file=None):
-
- sensor_type = SensorTypeDB(pack=pack, name=name, description=description,
- artifact_uri=artifact_uri, entry_point=entry_point,
- poll_interval=poll_interval, enabled=enabled,
- trigger_types=trigger_types, metadata_file=metadata_file)
+def _create_sensor_type(
+ pack=None,
+ name=None,
+ description=None,
+ artifact_uri=None,
+ entry_point=None,
+ trigger_types=None,
+ poll_interval=10,
+ enabled=True,
+ metadata_file=None,
+):
+
+ sensor_type = SensorTypeDB(
+ pack=pack,
+ name=name,
+ description=description,
+ artifact_uri=artifact_uri,
+ entry_point=entry_point,
+ poll_interval=poll_interval,
+ enabled=enabled,
+ trigger_types=trigger_types,
+ metadata_file=metadata_file,
+ )
return sensor_type
def get_sensor_entry_point(sensor_api_model):
- file_path = getattr(sensor_api_model, 'artifact_uri', None)
- class_name = getattr(sensor_api_model, 'class_name', None)
- pack = getattr(sensor_api_model, 'pack', None)
+ file_path = getattr(sensor_api_model, "artifact_uri", None)
+ class_name = getattr(sensor_api_model, "class_name", None)
+ pack = getattr(sensor_api_model, "pack", None)
if pack == SYSTEM_PACK_NAME:
# Special case for sensors which come included with the default installation
entry_point = class_name
else:
- module_path = file_path.split('/%s/' % (pack))[1]
- module_path = module_path.replace(os.path.sep, '.')
- module_path = module_path.replace('.py', '')
- entry_point = '%s.%s' % (module_path, class_name)
+ module_path = file_path.split("/%s/" % (pack))[1]
+ module_path = module_path.replace(os.path.sep, ".")
+ module_path = module_path.replace(".py", "")
+ entry_point = "%s.%s" % (module_path, class_name)
return entry_point
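
A standalone sketch of the derivation above, using a hypothetical pack layout
(and '/' in place of os.path.sep for brevity):

    file_path = "/opt/stackstorm/packs/mypack/sensors/my_sensor.py"
    pack = "mypack"
    module_path = file_path.split("/%s/" % (pack))[1]  # 'sensors/my_sensor.py'
    module_path = module_path.replace("/", ".").replace(".py", "")
    print("%s.%s" % (module_path, "MySensor"))  # sensors.my_sensor.MySensor
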
diff --git a/st2common/st2common/operators.py b/st2common/st2common/operators.py
index fc38d63215..6896e87658 100644
--- a/st2common/st2common/operators.py
+++ b/st2common/st2common/operators.py
@@ -24,10 +24,10 @@
from st2common.util.payload import PayloadLookup
__all__ = [
- 'SEARCH',
- 'get_operator',
- 'get_allowed_operators',
- 'UnrecognizedConditionError',
+ "SEARCH",
+ "get_operator",
+ "get_allowed_operators",
+ "UnrecognizedConditionError",
]
@@ -40,7 +40,7 @@ def get_operator(op):
if op in operators:
return operators[op]
else:
- raise Exception('Invalid operator: ' + op)
+ raise Exception("Invalid operator: " + op)
class UnrecognizedConditionError(Exception):
@@ -106,35 +106,57 @@ def search(value, criteria_pattern, criteria_condition, check_function):
type: "equals"
pattern: "Approved"
"""
- if criteria_condition == 'any':
+ if criteria_condition == "any":
# Any item of the list can match all patterns
- rtn = any([
- # Any payload item can match
- all([
- # Match all patterns
- check_function(
- child_criterion_k, child_criterion_v,
- PayloadLookup(child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX))
- for child_criterion_k, child_criterion_v in six.iteritems(criteria_pattern)
- ])
- for child_payload in value
- ])
- elif criteria_condition == 'all':
+ rtn = any(
+ [
+ # Any payload item can match
+ all(
+ [
+ # Match all patterns
+ check_function(
+ child_criterion_k,
+ child_criterion_v,
+ PayloadLookup(
+ child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX
+ ),
+ )
+ for child_criterion_k, child_criterion_v in six.iteritems(
+ criteria_pattern
+ )
+ ]
+ )
+ for child_payload in value
+ ]
+ )
+ elif criteria_condition == "all":
# Every item of the list must match all patterns
- rtn = all([
- # All payload items must match
- all([
- # Match all patterns
- check_function(
- child_criterion_k, child_criterion_v,
- PayloadLookup(child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX))
- for child_criterion_k, child_criterion_v in six.iteritems(criteria_pattern)
- ])
- for child_payload in value
- ])
+ rtn = all(
+ [
+ # All payload items must match
+ all(
+ [
+ # Match all patterns
+ check_function(
+ child_criterion_k,
+ child_criterion_v,
+ PayloadLookup(
+ child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX
+ ),
+ )
+ for child_criterion_k, child_criterion_v in six.iteritems(
+ criteria_pattern
+ )
+ ]
+ )
+ for child_payload in value
+ ]
+ )
else:
- raise UnrecognizedConditionError("The '%s' search condition is not recognized, only 'any' "
- "and 'all' are allowed" % criteria_condition)
+ raise UnrecognizedConditionError(
+ "The '%s' search condition is not recognized, only 'any' "
+ "and 'all' are allowed" % criteria_condition
+ )
return rtn
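
A simplified standalone sketch of the 'any' branch above, with a hypothetical
payload; the real code routes each comparison through check_function and
PayloadLookup:

    value = [{"state": "Approved"}, {"state": "Denied"}]
    criteria_pattern = {"item.state": "Approved"}
    rtn = any(
        all(item.get(k.split(".", 1)[1]) == v for k, v in criteria_pattern.items())
        for item in value
    )
    print(rtn)  # True -- at least one item satisfies every pattern
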
@@ -298,13 +320,17 @@ def _timediff(diff_target, period_seconds, operator):
def timediff_lt(value, criteria_pattern):
if criteria_pattern is None:
return False
- return _timediff(diff_target=value, period_seconds=criteria_pattern, operator=less_than)
+ return _timediff(
+ diff_target=value, period_seconds=criteria_pattern, operator=less_than
+ )
def timediff_gt(value, criteria_pattern):
if criteria_pattern is None:
return False
- return _timediff(diff_target=value, period_seconds=criteria_pattern, operator=greater_than)
+ return _timediff(
+ diff_target=value, period_seconds=criteria_pattern, operator=greater_than
+ )
def exists(value, criteria_pattern):
@@ -344,48 +370,48 @@ def ensure_operators_are_strings(value, criteria_pattern):
:return: tuple(value, criteria_pattern)
"""
if isinstance(value, bytes):
- value = value.decode('utf-8')
+ value = value.decode("utf-8")
if isinstance(criteria_pattern, bytes):
- criteria_pattern = criteria_pattern.decode('utf-8')
+ criteria_pattern = criteria_pattern.decode("utf-8")
return value, criteria_pattern
# operator match strings
-MATCH_WILDCARD = 'matchwildcard'
-MATCH_REGEX = 'matchregex'
-REGEX = 'regex'
-IREGEX = 'iregex'
-EQUALS_SHORT = 'eq'
-EQUALS_LONG = 'equals'
-NEQUALS_LONG = 'nequals'
-NEQUALS_SHORT = 'neq'
-IEQUALS_SHORT = 'ieq'
-IEQUALS_LONG = 'iequals'
-CONTAINS_LONG = 'contains'
-ICONTAINS_LONG = 'icontains'
-NCONTAINS_LONG = 'ncontains'
-INCONTAINS_LONG = 'incontains'
-STARTSWITH_LONG = 'startswith'
-ISTARTSWITH_LONG = 'istartswith'
-ENDSWITH_LONG = 'endswith'
-IENDSWITH_LONG = 'iendswith'
-LESS_THAN_SHORT = 'lt'
-LESS_THAN_LONG = 'lessthan'
-GREATER_THAN_SHORT = 'gt'
-GREATER_THAN_LONG = 'greaterthan'
-TIMEDIFF_LT_SHORT = 'td_lt'
-TIMEDIFF_LT_LONG = 'timediff_lt'
-TIMEDIFF_GT_SHORT = 'td_gt'
-TIMEDIFF_GT_LONG = 'timediff_gt'
-KEY_EXISTS = 'exists'
-KEY_NOT_EXISTS = 'nexists'
-INSIDE_LONG = 'inside'
-INSIDE_SHORT = 'in'
-NINSIDE_LONG = 'ninside'
-NINSIDE_SHORT = 'nin'
-SEARCH = 'search'
+MATCH_WILDCARD = "matchwildcard"
+MATCH_REGEX = "matchregex"
+REGEX = "regex"
+IREGEX = "iregex"
+EQUALS_SHORT = "eq"
+EQUALS_LONG = "equals"
+NEQUALS_LONG = "nequals"
+NEQUALS_SHORT = "neq"
+IEQUALS_SHORT = "ieq"
+IEQUALS_LONG = "iequals"
+CONTAINS_LONG = "contains"
+ICONTAINS_LONG = "icontains"
+NCONTAINS_LONG = "ncontains"
+INCONTAINS_LONG = "incontains"
+STARTSWITH_LONG = "startswith"
+ISTARTSWITH_LONG = "istartswith"
+ENDSWITH_LONG = "endswith"
+IENDSWITH_LONG = "iendswith"
+LESS_THAN_SHORT = "lt"
+LESS_THAN_LONG = "lessthan"
+GREATER_THAN_SHORT = "gt"
+GREATER_THAN_LONG = "greaterthan"
+TIMEDIFF_LT_SHORT = "td_lt"
+TIMEDIFF_LT_LONG = "timediff_lt"
+TIMEDIFF_GT_SHORT = "td_gt"
+TIMEDIFF_GT_LONG = "timediff_gt"
+KEY_EXISTS = "exists"
+KEY_NOT_EXISTS = "nexists"
+INSIDE_LONG = "inside"
+INSIDE_SHORT = "in"
+NINSIDE_LONG = "ninside"
+NINSIDE_SHORT = "nin"
+SEARCH = "search"
# operator lookups
operators = {
diff --git a/st2common/st2common/persistence/action.py b/st2common/st2common/persistence/action.py
index 0a91fc5cef..1f3d17ee01 100644
--- a/st2common/st2common/persistence/action.py
+++ b/st2common/st2common/persistence/action.py
@@ -23,12 +23,12 @@
from st2common.persistence.runner import RunnerType
__all__ = [
- 'Action',
- 'ActionAlias',
- 'ActionExecution',
- 'ActionExecutionState',
- 'LiveAction',
- 'RunnerType'
+ "Action",
+ "ActionAlias",
+ "ActionExecution",
+ "ActionExecutionState",
+ "LiveAction",
+ "RunnerType",
]
diff --git a/st2common/st2common/persistence/auth.py b/st2common/st2common/persistence/auth.py
index f03e3ab4e1..51f0a59ea1 100644
--- a/st2common/st2common/persistence/auth.py
+++ b/st2common/st2common/persistence/auth.py
@@ -14,9 +14,13 @@
# limitations under the License.
from __future__ import absolute_import
-from st2common.exceptions.auth import (TokenNotFoundError, ApiKeyNotFoundError,
- UserNotFoundError, AmbiguousUserError,
- NoNicknameOriginProvidedError)
+from st2common.exceptions.auth import (
+ TokenNotFoundError,
+ ApiKeyNotFoundError,
+ UserNotFoundError,
+ AmbiguousUserError,
+ NoNicknameOriginProvidedError,
+)
from st2common.models.db import MongoDBAccess
from st2common.models.db.auth import UserDB, TokenDB, ApiKeyDB
from st2common.persistence.base import Access
@@ -35,7 +39,7 @@ def get_by_nickname(cls, nickname, origin):
if not origin:
raise NoNicknameOriginProvidedError()
- result = cls.query(**{('nicknames__%s' % origin): nickname})
+ result = cls.query(**{("nicknames__%s" % origin): nickname})
if not result.first():
raise UserNotFoundError()
@@ -51,7 +55,7 @@ def _get_impl(cls):
@classmethod
def _get_by_object(cls, object):
# For User name is unique.
- name = getattr(object, 'name', '')
+ name = getattr(object, "name", "")
return cls.get_by_name(name)
@@ -64,13 +68,15 @@ def _get_impl(cls):
@classmethod
def add_or_update(cls, model_object, publish=True, validate=True):
- if not getattr(model_object, 'user', None):
- raise ValueError('User is not provided in the token.')
- if not getattr(model_object, 'token', None):
- raise ValueError('Token value is not set.')
- if not getattr(model_object, 'expiry', None):
- raise ValueError('Token expiry is not provided in the token.')
- return super(Token, cls).add_or_update(model_object, publish=publish, validate=validate)
+ if not getattr(model_object, "user", None):
+ raise ValueError("User is not provided in the token.")
+ if not getattr(model_object, "token", None):
+ raise ValueError("Token value is not set.")
+ if not getattr(model_object, "expiry", None):
+ raise ValueError("Token expiry is not provided in the token.")
+ return super(Token, cls).add_or_update(
+ model_object, publish=publish, validate=validate
+ )
@classmethod
def get(cls, value):
@@ -96,7 +102,7 @@ def get(cls, value):
result = cls.query(key_hash=value_hash).first()
if not result:
- raise ApiKeyNotFoundError('ApiKey with key_hash=%s not found.' % value_hash)
+ raise ApiKeyNotFoundError("ApiKey with key_hash=%s not found." % value_hash)
return result
@@ -109,4 +115,4 @@ def get_by_key_or_id(cls, value):
try:
return cls.get_by_id(value)
except:
- raise ApiKeyNotFoundError('ApiKey with key or id=%s not found.' % value)
+ raise ApiKeyNotFoundError("ApiKey with key or id=%s not found." % value)
diff --git a/st2common/st2common/persistence/base.py b/st2common/st2common/persistence/base.py
index ea1325762f..a477defe49 100644
--- a/st2common/st2common/persistence/base.py
+++ b/st2common/st2common/persistence/base.py
@@ -23,12 +23,7 @@
from st2common.models.system.common import ResourceReference
-__all__ = [
- 'Access',
-
- 'ContentPackResource',
- 'StatusBasedResource'
-]
+__all__ = ["Access", "ContentPackResource", "StatusBasedResource"]
LOG = logging.getLogger(__name__)
@@ -123,48 +118,60 @@ def aggregate(cls, *args, **kwargs):
return cls._get_impl().aggregate(*args, **kwargs)
@classmethod
- def insert(cls, model_object, publish=True, dispatch_trigger=True,
- log_not_unique_error_as_debug=False):
+ def insert(
+ cls,
+ model_object,
+ publish=True,
+ dispatch_trigger=True,
+ log_not_unique_error_as_debug=False,
+ ):
# Late import to avoid a very expensive indirect import (~1 second) when this function
# is not called / used
from mongoengine import NotUniqueError
if model_object.id:
- raise ValueError('id for object %s was unexpected.' % model_object)
+ raise ValueError("id for object %s was unexpected." % model_object)
try:
model_object = cls._get_impl().insert(model_object)
except NotUniqueError as e:
if log_not_unique_error_as_debug:
- LOG.debug('Conflict while trying to save in DB: %s.', six.text_type(e))
+ LOG.debug("Conflict while trying to save in DB: %s.", six.text_type(e))
else:
- LOG.exception('Conflict while trying to save in DB.')
+ LOG.exception("Conflict while trying to save in DB.")
# On a conflict determine the conflicting object and return its id in
# the raised exception.
conflict_object = cls._get_by_object(model_object)
conflict_id = str(conflict_object.id) if conflict_object else None
message = six.text_type(e)
- raise StackStormDBObjectConflictError(message=message, conflict_id=conflict_id,
- model_object=model_object)
+ raise StackStormDBObjectConflictError(
+ message=message, conflict_id=conflict_id, model_object=model_object
+ )
# Publish internal event on the message bus
if publish:
try:
cls.publish_create(model_object)
except:
- LOG.exception('Publish failed.')
+ LOG.exception("Publish failed.")
# Dispatch trigger
if dispatch_trigger:
try:
cls.dispatch_create_trigger(model_object)
except:
- LOG.exception('Trigger dispatch failed.')
+ LOG.exception("Trigger dispatch failed.")
return model_object
@classmethod
- def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, validate=True,
- log_not_unique_error_as_debug=False):
+ def add_or_update(
+ cls,
+ model_object,
+ publish=True,
+ dispatch_trigger=True,
+ validate=True,
+ log_not_unique_error_as_debug=False,
+ ):
# Late import to avoid a very expensive indirect import (~1 second) when this function
# is not called / used
from mongoengine import NotUniqueError
@@ -174,16 +181,17 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida
model_object = cls._get_impl().add_or_update(model_object, validate=True)
except NotUniqueError as e:
if log_not_unique_error_as_debug:
- LOG.debug('Conflict while trying to save in DB: %s.', six.text_type(e))
+ LOG.debug("Conflict while trying to save in DB: %s.", six.text_type(e))
else:
- LOG.exception('Conflict while trying to save in DB.')
+ LOG.exception("Conflict while trying to save in DB.")
# On a conflict determine the conflicting object and return its id in
# the raised exception.
conflict_object = cls._get_by_object(model_object)
conflict_id = str(conflict_object.id) if conflict_object else None
message = six.text_type(e)
- raise StackStormDBObjectConflictError(message=message, conflict_id=conflict_id,
- model_object=model_object)
+ raise StackStormDBObjectConflictError(
+ message=message, conflict_id=conflict_id, model_object=model_object
+ )
is_update = str(pre_persist_id) == str(model_object.id)
@@ -195,7 +203,7 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida
else:
cls.publish_create(model_object)
except:
- LOG.exception('Publish failed.')
+ LOG.exception("Publish failed.")
# Dispatch trigger
if dispatch_trigger:
@@ -205,7 +213,7 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida
else:
cls.dispatch_create_trigger(model_object)
except:
- LOG.exception('Trigger dispatch failed.')
+ LOG.exception("Trigger dispatch failed.")
return model_object
@@ -227,14 +235,14 @@ def update(cls, model_object, publish=True, dispatch_trigger=True, **kwargs):
try:
cls.publish_update(model_object)
except:
- LOG.exception('Publish failed.')
+ LOG.exception("Publish failed.")
# Dispatch trigger
if dispatch_trigger:
try:
cls.dispatch_update_trigger(model_object)
except:
- LOG.exception('Trigger dispatch failed.')
+ LOG.exception("Trigger dispatch failed.")
return model_object
@@ -247,14 +255,14 @@ def delete(cls, model_object, publish=True, dispatch_trigger=True):
try:
cls.publish_delete(model_object)
except Exception:
- LOG.exception('Publish failed.')
+ LOG.exception("Publish failed.")
# Dispatch trigger
if dispatch_trigger:
try:
cls.dispatch_delete_trigger(model_object)
except Exception:
- LOG.exception('Trigger dispatch failed.')
+ LOG.exception("Trigger dispatch failed.")
return persisted_object
@@ -289,14 +297,18 @@ def dispatch_create_trigger(cls, model_object):
"""
Dispatch a resource-specific trigger which indicates a new resource has been created.
"""
- return cls._dispatch_operation_trigger(operation='create', model_object=model_object)
+ return cls._dispatch_operation_trigger(
+ operation="create", model_object=model_object
+ )
@classmethod
def dispatch_update_trigger(cls, model_object):
"""
Dispatch a resource-specific trigger which indicates an existing resource has been updated.
"""
- return cls._dispatch_operation_trigger(operation='update', model_object=model_object)
+ return cls._dispatch_operation_trigger(
+ operation="update", model_object=model_object
+ )
@classmethod
def dispatch_delete_trigger(cls, model_object):
@@ -304,14 +316,18 @@ def dispatch_delete_trigger(cls, model_object):
Dispatch a resource-specific trigger which indicates an existing resource has been
deleted.
"""
- return cls._dispatch_operation_trigger(operation='delete', model_object=model_object)
+ return cls._dispatch_operation_trigger(
+ operation="delete", model_object=model_object
+ )
@classmethod
def _get_trigger_ref_for_operation(cls, operation):
trigger_ref = cls.operation_to_trigger_ref_map.get(operation, None)
if not trigger_ref:
- raise ValueError('Trigger ref not specified for operation: %s' % (operation))
+ raise ValueError(
+ "Trigger ref not specified for operation: %s" % (operation)
+ )
return trigger_ref
@@ -322,11 +338,13 @@ def _dispatch_operation_trigger(cls, operation, model_object):
trigger = cls._get_trigger_ref_for_operation(operation=operation)
- object_payload = cls.api_model_cls.from_model(model_object, mask_secrets=True).__json__()
- payload = {
- 'object': object_payload
- }
- return cls._dispatch_trigger(operation=operation, trigger=trigger, payload=payload)
+ object_payload = cls.api_model_cls.from_model(
+ model_object, mask_secrets=True
+ ).__json__()
+ payload = {"object": object_payload}
+ return cls._dispatch_trigger(
+ operation=operation, trigger=trigger, payload=payload
+ )
@classmethod
def _dispatch_trigger(cls, operation, trigger, payload):
@@ -338,23 +356,23 @@ def _dispatch_trigger(cls, operation, trigger, payload):
class ContentPackResource(Access):
-
@classmethod
def get_by_ref(cls, ref):
if not ref:
return None
ref_obj = ResourceReference.from_string_reference(ref=ref)
- result = cls.query(name=ref_obj.name,
- pack=ref_obj.pack).first()
+ result = cls.query(name=ref_obj.name, pack=ref_obj.pack).first()
return result
@classmethod
def _get_by_object(cls, object):
# For an object with a resourcepack pack.name is unique.
- name = getattr(object, 'name', '')
- pack = getattr(object, 'pack', '')
- return cls.get_by_ref(ResourceReference.to_string_reference(pack=pack, name=name))
+ name = getattr(object, "name", "")
+ pack = getattr(object, "pack", "")
+ return cls.get_by_ref(
+ ResourceReference.to_string_reference(pack=pack, name=name)
+ )
class StatusBasedResource(Access):
@@ -372,4 +390,4 @@ def publish_status(cls, model_object):
"""
publisher = cls._get_publisher()
if publisher:
- publisher.publish_state(model_object, getattr(model_object, 'status', None))
+ publisher.publish_state(model_object, getattr(model_object, "status", None))
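
The payload dispatched by _dispatch_operation_trigger above has this shape
(values are illustrative):

    payload = {"object": {"pack": "examples", "name": "my_rule"}}
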
diff --git a/st2common/st2common/persistence/cleanup.py b/st2common/st2common/persistence/cleanup.py
index 5831a47cca..06c48dec86 100644
--- a/st2common/st2common/persistence/cleanup.py
+++ b/st2common/st2common/persistence/cleanup.py
@@ -24,11 +24,7 @@
from st2common.script_setup import setup as common_setup
from st2common.script_setup import teardown as common_teardown
-__all__ = [
- 'db_cleanup',
- 'db_cleanup_with_retry',
- 'main'
-]
+__all__ = ["db_cleanup", "db_cleanup_with_retry", "main"]
LOG = logging.getLogger(__name__)
@@ -42,26 +38,47 @@ def db_cleanup():
return connection
-def db_cleanup_with_retry(db_name, db_host, db_port, username=None, password=None,
- ssl=False, ssl_keyfile=None,
- ssl_certfile=None, ssl_cert_reqs=None, ssl_ca_certs=None,
- authentication_mechanism=None, ssl_match_hostname=True):
+def db_cleanup_with_retry(
+ db_name,
+ db_host,
+ db_port,
+ username=None,
+ password=None,
+ ssl=False,
+ ssl_keyfile=None,
+ ssl_certfile=None,
+ ssl_cert_reqs=None,
+ ssl_ca_certs=None,
+ authentication_mechanism=None,
+ ssl_match_hostname=True,
+):
"""
This method is a retry version of db_cleanup.
"""
- return db_func_with_retry(db_cleanup_func,
- db_name, db_host, db_port,
- username=username, password=password,
- ssl=ssl, ssl_keyfile=ssl_keyfile,
- ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs,
- ssl_ca_certs=ssl_ca_certs,
- authentication_mechanism=authentication_mechanism,
- ssl_match_hostname=ssl_match_hostname)
+ return db_func_with_retry(
+ db_cleanup_func,
+ db_name,
+ db_host,
+ db_port,
+ username=username,
+ password=password,
+ ssl=ssl,
+ ssl_keyfile=ssl_keyfile,
+ ssl_certfile=ssl_certfile,
+ ssl_cert_reqs=ssl_cert_reqs,
+ ssl_ca_certs=ssl_ca_certs,
+ authentication_mechanism=authentication_mechanism,
+ ssl_match_hostname=ssl_match_hostname,
+ )
def setup(argv):
- common_setup(config=config, setup_db=False, register_mq_exchanges=False,
- register_internal_trigger_types=False)
+ common_setup(
+ config=config,
+ setup_db=False,
+ register_mq_exchanges=False,
+ register_internal_trigger_types=False,
+ )
def teardown():
@@ -75,5 +92,5 @@ def main(argv):
# Entry point for the database cleanup script.
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv[1:])
diff --git a/st2common/st2common/persistence/db_init.py b/st2common/st2common/persistence/db_init.py
index 04a2a3a753..678ca71ccd 100644
--- a/st2common/st2common/persistence/db_init.py
+++ b/st2common/st2common/persistence/db_init.py
@@ -22,9 +22,7 @@
from st2common import log as logging
from st2common.models.db import db_setup
-__all__ = [
- 'db_setup_with_retry'
-]
+__all__ = ["db_setup_with_retry"]
LOG = logging.getLogger(__name__)
@@ -36,9 +34,11 @@ def _retry_if_connection_error(error):
# Ideally, a special exception or at least some exception code.
# If this does become an issue, look for "Cannot connect to database" at the
# start of error msg.
- is_connection_error = isinstance(error, mongoengine.connection.MongoEngineConnectionError)
+ is_connection_error = isinstance(
+ error, mongoengine.connection.MongoEngineConnectionError
+ )
if is_connection_error:
- LOG.warn('Retry on ConnectionError - %s', error)
+ LOG.warn("Retry on ConnectionError - %s", error)
return is_connection_error
@@ -52,25 +52,45 @@ def db_func_with_retry(db_func, *args, **kwargs):
# reading of config values however this is lesser code.
retrying_obj = retrying.Retrying(
retry_on_exception=_retry_if_connection_error,
- wait_exponential_multiplier=cfg.CONF.database.connection_retry_backoff_mul * 1000,
+ wait_exponential_multiplier=cfg.CONF.database.connection_retry_backoff_mul
+ * 1000,
wait_exponential_max=cfg.CONF.database.connection_retry_backoff_max_s * 1000,
- stop_max_delay=cfg.CONF.database.connection_retry_max_delay_m * 60 * 1000
+ stop_max_delay=cfg.CONF.database.connection_retry_max_delay_m * 60 * 1000,
)
return retrying_obj.call(db_func, *args, **kwargs)
-def db_setup_with_retry(db_name, db_host, db_port, username=None, password=None,
- ensure_indexes=True, ssl=False, ssl_keyfile=None,
- ssl_certfile=None, ssl_cert_reqs=None, ssl_ca_certs=None,
- authentication_mechanism=None, ssl_match_hostname=True):
+def db_setup_with_retry(
+ db_name,
+ db_host,
+ db_port,
+ username=None,
+ password=None,
+ ensure_indexes=True,
+ ssl=False,
+ ssl_keyfile=None,
+ ssl_certfile=None,
+ ssl_cert_reqs=None,
+ ssl_ca_certs=None,
+ authentication_mechanism=None,
+ ssl_match_hostname=True,
+):
"""
This method is a retry version of db_setup.
"""
- return db_func_with_retry(db_setup, db_name, db_host, db_port,
- username=username, password=password,
- ensure_indexes=ensure_indexes,
- ssl=ssl, ssl_keyfile=ssl_keyfile,
- ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs,
- ssl_ca_certs=ssl_ca_certs,
- authentication_mechanism=authentication_mechanism,
- ssl_match_hostname=ssl_match_hostname)
+ return db_func_with_retry(
+ db_setup,
+ db_name,
+ db_host,
+ db_port,
+ username=username,
+ password=password,
+ ensure_indexes=ensure_indexes,
+ ssl=ssl,
+ ssl_keyfile=ssl_keyfile,
+ ssl_certfile=ssl_certfile,
+ ssl_cert_reqs=ssl_cert_reqs,
+ ssl_ca_certs=ssl_ca_certs,
+ authentication_mechanism=authentication_mechanism,
+ ssl_match_hostname=ssl_match_hostname,
+ )
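
A minimal sketch of the retry wrapper above, with literal numbers standing in for
the cfg.CONF.database values:

    import retrying

    retrying_obj = retrying.Retrying(
        retry_on_exception=lambda e: isinstance(e, ConnectionError),
        wait_exponential_multiplier=1 * 1000,  # backoff multiplier, in ms
        wait_exponential_max=10 * 1000,        # cap a single wait at 10 seconds
        stop_max_delay=5 * 60 * 1000,          # give up after 5 minutes
    )
    # retrying_obj.call(db_setup, "st2", "127.0.0.1", 27017)
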
diff --git a/st2common/st2common/persistence/execution.py b/st2common/st2common/persistence/execution.py
index 6af949786d..2073dda17b 100644
--- a/st2common/st2common/persistence/execution.py
+++ b/st2common/st2common/persistence/execution.py
@@ -21,8 +21,8 @@
from st2common.persistence.base import Access
__all__ = [
- 'ActionExecution',
- 'ActionExecutionOutput',
+ "ActionExecution",
+ "ActionExecutionOutput",
]
diff --git a/st2common/st2common/persistence/execution_queue.py b/st2common/st2common/persistence/execution_queue.py
index 2ec5f05924..eaedc22f4c 100644
--- a/st2common/st2common/persistence/execution_queue.py
+++ b/st2common/st2common/persistence/execution_queue.py
@@ -18,9 +18,7 @@
from st2common.models.db.execution_queue import EXECUTION_QUEUE_ACCESS
from st2common.persistence import base as persistence
-__all__ = [
- 'ActionExecutionSchedulingQueue'
-]
+__all__ = ["ActionExecutionSchedulingQueue"]
class ActionExecutionSchedulingQueue(persistence.Access):
diff --git a/st2common/st2common/persistence/executionstate.py b/st2common/st2common/persistence/executionstate.py
index 8e94a714aa..7e2debd138 100644
--- a/st2common/st2common/persistence/executionstate.py
+++ b/st2common/st2common/persistence/executionstate.py
@@ -19,9 +19,7 @@
from st2common.models.db.executionstate import actionexecstate_access
from st2common.persistence import base as persistence
-__all__ = [
- 'ActionExecutionState'
-]
+__all__ = ["ActionExecutionState"]
class ActionExecutionState(persistence.Access):
@@ -35,5 +33,7 @@ def _get_impl(cls):
@classmethod
def _get_publisher(cls):
if not cls.publisher:
- cls.publisher = transport.actionexecutionstate.ActionExecutionStatePublisher()
+ cls.publisher = (
+ transport.actionexecutionstate.ActionExecutionStatePublisher()
+ )
return cls.publisher
diff --git a/st2common/st2common/persistence/keyvalue.py b/st2common/st2common/persistence/keyvalue.py
index 634bd72302..10676998f5 100644
--- a/st2common/st2common/persistence/keyvalue.py
+++ b/st2common/st2common/persistence/keyvalue.py
@@ -34,24 +34,30 @@ class KeyValuePair(Access):
publisher = None
api_model_cls = KeyValuePairAPI
- dispatch_trigger_for_operations = ['create', 'update', 'value_change', 'delete']
+ dispatch_trigger_for_operations = ["create", "update", "value_change", "delete"]
operation_to_trigger_ref_map = {
- 'create': ResourceReference.to_string_reference(
- name=KEY_VALUE_PAIR_CREATE_TRIGGER['name'],
- pack=KEY_VALUE_PAIR_CREATE_TRIGGER['pack']),
- 'update': ResourceReference.to_string_reference(
- name=KEY_VALUE_PAIR_UPDATE_TRIGGER['name'],
- pack=KEY_VALUE_PAIR_UPDATE_TRIGGER['pack']),
- 'value_change': ResourceReference.to_string_reference(
- name=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER['name'],
- pack=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER['pack']),
- 'delete': ResourceReference.to_string_reference(
- name=KEY_VALUE_PAIR_DELETE_TRIGGER['name'],
- pack=KEY_VALUE_PAIR_DELETE_TRIGGER['pack']),
+ "create": ResourceReference.to_string_reference(
+ name=KEY_VALUE_PAIR_CREATE_TRIGGER["name"],
+ pack=KEY_VALUE_PAIR_CREATE_TRIGGER["pack"],
+ ),
+ "update": ResourceReference.to_string_reference(
+ name=KEY_VALUE_PAIR_UPDATE_TRIGGER["name"],
+ pack=KEY_VALUE_PAIR_UPDATE_TRIGGER["pack"],
+ ),
+ "value_change": ResourceReference.to_string_reference(
+ name=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER["name"],
+ pack=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER["pack"],
+ ),
+ "delete": ResourceReference.to_string_reference(
+ name=KEY_VALUE_PAIR_DELETE_TRIGGER["name"],
+ pack=KEY_VALUE_PAIR_DELETE_TRIGGER["pack"],
+ ),
}
@classmethod
- def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, validate=True):
+ def add_or_update(
+ cls, model_object, publish=True, dispatch_trigger=True, validate=True
+ ):
"""
Note: We override add_or_update because we also want to publish high level "value_change"
event for this resource.
@@ -62,32 +68,36 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida
# Not an update
existing_model_object = None
- model_object = super(KeyValuePair, cls).add_or_update(model_object=model_object,
- publish=publish,
- dispatch_trigger=dispatch_trigger)
+ model_object = super(KeyValuePair, cls).add_or_update(
+ model_object=model_object,
+ publish=publish,
+ dispatch_trigger=dispatch_trigger,
+ )
# Dispatch a value_change event which is specific to this resource
if existing_model_object and existing_model_object.value != model_object.value:
- cls.dispatch_value_change_trigger(old_model_object=existing_model_object,
- new_model_object=model_object)
+ cls.dispatch_value_change_trigger(
+ old_model_object=existing_model_object, new_model_object=model_object
+ )
return model_object
@classmethod
def dispatch_value_change_trigger(cls, old_model_object, new_model_object):
- operation = 'value_change'
+ operation = "value_change"
trigger = cls._get_trigger_ref_for_operation(operation=operation)
- old_object_payload = cls.api_model_cls.from_model(old_model_object,
- mask_secrets=True).__json__()
- new_object_payload = cls.api_model_cls.from_model(new_model_object,
- mask_secrets=True).__json__()
- payload = {
- 'old_object': old_object_payload,
- 'new_object': new_object_payload
- }
+ old_object_payload = cls.api_model_cls.from_model(
+ old_model_object, mask_secrets=True
+ ).__json__()
+ new_object_payload = cls.api_model_cls.from_model(
+ new_model_object, mask_secrets=True
+ ).__json__()
+ payload = {"old_object": old_object_payload, "new_object": new_object_payload}
- return cls._dispatch_trigger(operation=operation, trigger=trigger, payload=payload)
+ return cls._dispatch_trigger(
+ operation=operation, trigger=trigger, payload=payload
+ )
@classmethod
def get_by_names(cls, names):
@@ -124,5 +134,5 @@ def _get_impl(cls):
@classmethod
def _get_by_object(cls, object):
# For KeyValuePair name is unique.
- name = getattr(object, 'name', '')
+ name = getattr(object, "name", "")
return cls.get_by_name(name)
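
The value_change payload dispatched above pairs the old and new serialized
objects (values are illustrative):

    payload = {
        "old_object": {"name": "api_url", "value": "http://old.example.com"},
        "new_object": {"name": "api_url", "value": "http://new.example.com"},
    }
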
diff --git a/st2common/st2common/persistence/liveaction.py b/st2common/st2common/persistence/liveaction.py
index 61b16b1878..aa7551592a 100644
--- a/st2common/st2common/persistence/liveaction.py
+++ b/st2common/st2common/persistence/liveaction.py
@@ -19,9 +19,7 @@
from st2common.models.db.liveaction import liveaction_access
from st2common.persistence import base as persistence
-__all__ = [
- 'LiveAction'
-]
+__all__ = ["LiveAction"]
class LiveAction(persistence.StatusBasedResource):
diff --git a/st2common/st2common/persistence/marker.py b/st2common/st2common/persistence/marker.py
index 1f35bbcdf2..6be08a25ec 100644
--- a/st2common/st2common/persistence/marker.py
+++ b/st2common/st2common/persistence/marker.py
@@ -19,9 +19,7 @@
from st2common.models.db.marker import DumperMarkerDB
from st2common.persistence.base import Access
-__all__ = [
- 'Marker'
-]
+__all__ = ["Marker"]
class Marker(Access):
diff --git a/st2common/st2common/persistence/pack.py b/st2common/st2common/persistence/pack.py
index 5b2ff39102..01ca6b20cb 100644
--- a/st2common/st2common/persistence/pack.py
+++ b/st2common/st2common/persistence/pack.py
@@ -19,11 +19,7 @@
from st2common.models.db.pack import config_schema_access
from st2common.models.db.pack import config_access
-__all__ = [
- 'Pack',
- 'ConfigSchema',
- 'Config'
-]
+__all__ = ["Pack", "ConfigSchema", "Config"]
class Pack(base.Access):
diff --git a/st2common/st2common/persistence/policy.py b/st2common/st2common/persistence/policy.py
index 468ce07f69..8b6700c194 100644
--- a/st2common/st2common/persistence/policy.py
+++ b/st2common/st2common/persistence/policy.py
@@ -30,16 +30,20 @@ def _get_impl(cls):
def get_by_ref(cls, ref):
if ref:
ref_obj = PolicyTypeReference.from_string_reference(ref=ref)
- result = cls.query(name=ref_obj.name, resource_type=ref_obj.resource_type).first()
+ result = cls.query(
+ name=ref_obj.name, resource_type=ref_obj.resource_type
+ ).first()
return result
else:
return None
@classmethod
def _get_by_object(cls, object):
- name = getattr(object, 'name', '')
- resource_type = getattr(object, 'resource_type', '')
- ref = PolicyTypeReference.to_string_reference(resource_type=resource_type, name=name)
+ name = getattr(object, "name", "")
+ resource_type = getattr(object, "resource_type", "")
+ ref = PolicyTypeReference.to_string_reference(
+ resource_type=resource_type, name=name
+ )
return cls.get_by_ref(ref)
diff --git a/st2common/st2common/persistence/rbac.py b/st2common/st2common/persistence/rbac.py
index bdac61d888..e14b973aeb 100644
--- a/st2common/st2common/persistence/rbac.py
+++ b/st2common/st2common/persistence/rbac.py
@@ -20,12 +20,7 @@
from st2common.models.db.rbac import permission_grant_access
from st2common.models.db.rbac import group_to_role_mapping_access
-__all__ = [
- 'Role',
- 'UserRoleAssignment',
- 'PermissionGrant',
- 'GroupToRoleMapping'
-]
+__all__ = ["Role", "UserRoleAssignment", "PermissionGrant", "GroupToRoleMapping"]
class Role(base.Access):
diff --git a/st2common/st2common/persistence/reactor.py b/st2common/st2common/persistence/reactor.py
index c060877513..0fa35c6bdf 100644
--- a/st2common/st2common/persistence/reactor.py
+++ b/st2common/st2common/persistence/reactor.py
@@ -16,12 +16,6 @@
from __future__ import absolute_import
from st2common.persistence.rule import Rule
from st2common.persistence.sensor import SensorType
-from st2common.persistence.trigger import (Trigger, TriggerInstance, TriggerType)
+from st2common.persistence.trigger import Trigger, TriggerInstance, TriggerType
-__all__ = [
- 'Rule',
- 'SensorType',
- 'Trigger',
- 'TriggerInstance',
- 'TriggerType'
-]
+__all__ = ["Rule", "SensorType", "Trigger", "TriggerInstance", "TriggerType"]
diff --git a/st2common/st2common/persistence/rule.py b/st2common/st2common/persistence/rule.py
index 741b9d4967..0a64e4bb1f 100644
--- a/st2common/st2common/persistence/rule.py
+++ b/st2common/st2common/persistence/rule.py
@@ -36,5 +36,5 @@ def _get_impl(cls):
@classmethod
def _get_by_object(cls, object):
# For RuleType name is unique.
- name = getattr(object, 'name', '')
+ name = getattr(object, "name", "")
return cls.get_by_name(name)
diff --git a/st2common/st2common/persistence/runner.py b/st2common/st2common/persistence/runner.py
index 77440707f2..63cfa36d9e 100644
--- a/st2common/st2common/persistence/runner.py
+++ b/st2common/st2common/persistence/runner.py
@@ -28,5 +28,5 @@ def _get_impl(cls):
@classmethod
def _get_by_object(cls, object):
# For RunnerType name is unique.
- name = getattr(object, 'name', '')
+ name = getattr(object, "name", "")
return cls.get_by_name(name)
diff --git a/st2common/st2common/persistence/sensor.py b/st2common/st2common/persistence/sensor.py
index 67367c7fc4..1a3a3679da 100644
--- a/st2common/st2common/persistence/sensor.py
+++ b/st2common/st2common/persistence/sensor.py
@@ -19,9 +19,7 @@
from st2common.models.db.sensor import sensor_type_access
from st2common.persistence.base import ContentPackResource
-__all__ = [
- 'SensorType'
-]
+__all__ = ["SensorType"]
class SensorType(ContentPackResource):
diff --git a/st2common/st2common/persistence/trace.py b/st2common/st2common/persistence/trace.py
index 5e7276a1f0..ce5472f2aa 100644
--- a/st2common/st2common/persistence/trace.py
+++ b/st2common/st2common/persistence/trace.py
@@ -26,14 +26,16 @@ def _get_impl(cls):
return cls.impl
@classmethod
- def push_components(cls, instance, action_executions=None, rules=None, trigger_instances=None):
+ def push_components(
+ cls, instance, action_executions=None, rules=None, trigger_instances=None
+ ):
update_kwargs = {}
if action_executions:
- update_kwargs['push_all__action_executions'] = action_executions
+ update_kwargs["push_all__action_executions"] = action_executions
if rules:
- update_kwargs['push_all__rules'] = rules
+ update_kwargs["push_all__rules"] = rules
if trigger_instances:
- update_kwargs['push_all__trigger_instances'] = trigger_instances
+ update_kwargs["push_all__trigger_instances"] = trigger_instances
if update_kwargs:
return cls.update(instance, **update_kwargs)
return instance
diff --git a/st2common/st2common/persistence/trigger.py b/st2common/st2common/persistence/trigger.py
index 1cdc4ef4ac..3567a15829 100644
--- a/st2common/st2common/persistence/trigger.py
+++ b/st2common/st2common/persistence/trigger.py
@@ -18,14 +18,14 @@
from st2common import log as logging
from st2common import transport
from st2common.exceptions.db import StackStormDBObjectNotFoundError
-from st2common.models.db.trigger import triggertype_access, trigger_access, triggerinstance_access
-from st2common.persistence.base import (Access, ContentPackResource)
+from st2common.models.db.trigger import (
+ triggertype_access,
+ trigger_access,
+ triggerinstance_access,
+)
+from st2common.persistence.base import Access, ContentPackResource
-__all__ = [
- 'TriggerType',
- 'Trigger',
- 'TriggerInstance'
-]
+__all__ = ["TriggerType", "Trigger", "TriggerInstance"]
LOG = logging.getLogger(__name__)
@@ -57,7 +57,7 @@ def delete_if_unreferenced(cls, model_object, publish=True, dispatch_trigger=Tru
# Found in the innards of mongoengine.
# e.g. {'pk': ObjectId('5609e91832ed356d04a93cc0')}
delete_query = model_object._object_key
- delete_query['ref_count__lte'] = 0
+ delete_query["ref_count__lte"] = 0
cls._get_impl().delete_by_query(**delete_query)
# Since delete_by_query cannot tell if the delete actually happened, check with a get call
@@ -73,14 +73,14 @@ def delete_if_unreferenced(cls, model_object, publish=True, dispatch_trigger=Tru
try:
cls.publish_delete(model_object)
except Exception:
- LOG.exception('Publish failed.')
+ LOG.exception("Publish failed.")
# Dispatch trigger
if confirmed_delete and dispatch_trigger:
try:
cls.dispatch_delete_trigger(model_object)
except Exception:
- LOG.exception('Trigger dispatch failed.')
+ LOG.exception("Trigger dispatch failed.")
return model_object
diff --git a/st2common/st2common/persistence/workflow.py b/st2common/st2common/persistence/workflow.py
index aa02c320e1..8d993ef4fe 100644
--- a/st2common/st2common/persistence/workflow.py
+++ b/st2common/st2common/persistence/workflow.py
@@ -21,10 +21,7 @@
from st2common.persistence import base as persistence
-__all__ = [
- 'WorkflowExecution',
- 'TaskExecution'
-]
+__all__ = ["WorkflowExecution", "TaskExecution"]
class WorkflowExecution(persistence.StatusBasedResource):
diff --git a/st2common/st2common/policies/__init__.py b/st2common/st2common/policies/__init__.py
index df49fa1f14..ef39e129c9 100644
--- a/st2common/st2common/policies/__init__.py
+++ b/st2common/st2common/policies/__init__.py
@@ -18,7 +18,4 @@
from st2common.policies.base import ResourcePolicyApplicator
-__all__ = [
- 'get_driver',
- 'ResourcePolicyApplicator'
-]
+__all__ = ["get_driver", "ResourcePolicyApplicator"]
diff --git a/st2common/st2common/policies/base.py b/st2common/st2common/policies/base.py
index 5bfc3fa58e..a22fa2fb42 100644
--- a/st2common/st2common/policies/base.py
+++ b/st2common/st2common/policies/base.py
@@ -24,10 +24,7 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'ResourcePolicyApplicator',
- 'get_driver'
-]
+__all__ = ["ResourcePolicyApplicator", "get_driver"]
@six.add_metaclass(abc.ABCMeta)
@@ -72,9 +69,9 @@ def _get_lock_name(self, values):
lock_uid = []
for key, value in six.iteritems(values):
- lock_uid.append('%s=%s' % (key, value))
+ lock_uid.append("%s=%s" % (key, value))
- lock_uid = ','.join(lock_uid)
+ lock_uid = ",".join(lock_uid)
return lock_uid
@@ -88,5 +85,7 @@ def get_driver(policy_ref, policy_type, **parameters):
# interested in
continue
- if (issubclass(obj, ResourcePolicyApplicator) and not obj.__name__.startswith('Base')):
+ if issubclass(obj, ResourcePolicyApplicator) and not obj.__name__.startswith(
+ "Base"
+ ):
return obj(policy_ref, policy_type, **parameters)
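
A standalone sketch of the discovery loop in get_driver above, with hypothetical
applicator classes standing in for ResourcePolicyApplicator subclasses:

    import inspect
    import sys

    class Base(object):  # stand-in for ResourcePolicyApplicator
        pass

    class BaseConcurrency(Base):  # skipped: name starts with 'Base'
        pass

    class ConcurrencyApplicator(Base):
        pass

    for _, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass):
        if issubclass(obj, Base) and not obj.__name__.startswith("Base"):
            print(obj.__name__)  # ConcurrencyApplicator
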
diff --git a/st2common/st2common/policies/concurrency.py b/st2common/st2common/policies/concurrency.py
index a453214b72..fcf96467c3 100644
--- a/st2common/st2common/policies/concurrency.py
+++ b/st2common/st2common/policies/concurrency.py
@@ -18,24 +18,23 @@
from st2common.policies import base
from st2common.services import coordination
-__all__ = [
- 'BaseConcurrencyApplicator'
-]
+__all__ = ["BaseConcurrencyApplicator"]
class BaseConcurrencyApplicator(base.ResourcePolicyApplicator):
- def __init__(self, policy_ref, policy_type, threshold=0, action='delay'):
- super(BaseConcurrencyApplicator, self).__init__(policy_ref=policy_ref,
- policy_type=policy_type)
+ def __init__(self, policy_ref, policy_type, threshold=0, action="delay"):
+ super(BaseConcurrencyApplicator, self).__init__(
+ policy_ref=policy_ref, policy_type=policy_type
+ )
self.threshold = threshold
self.policy_action = action
self.coordinator = coordination.get_coordinator(start_heart=True)
def _get_status_for_policy_action(self, action):
- if action == 'delay':
+ if action == "delay":
status = action_constants.LIVEACTION_STATUS_DELAYED
- elif action == 'cancel':
+ elif action == "cancel":
status = action_constants.LIVEACTION_STATUS_CANCELING
return status
diff --git a/st2common/st2common/rbac/backends/__init__.py b/st2common/st2common/rbac/backends/__init__.py
index cf6429c124..bb7ad3d58f 100644
--- a/st2common/st2common/rbac/backends/__init__.py
+++ b/st2common/st2common/rbac/backends/__init__.py
@@ -22,15 +22,11 @@
from st2common.util import driver_loader
-__all__ = [
- 'get_available_backends',
- 'get_backend_instance',
- 'get_rbac_backend'
-]
+__all__ = ["get_available_backends", "get_backend_instance", "get_rbac_backend"]
LOG = logging.getLogger(__name__)
-BACKENDS_NAMESPACE = 'st2common.rbac.backend'
+BACKENDS_NAMESPACE = "st2common.rbac.backend"
# Cache which maps backend name -> backend class instance
# NOTE: We use a cache to avoid slow stevedore dynamic filesystem introspection on every
@@ -44,7 +40,9 @@ def get_available_backends():
def get_backend_instance(name, use_cache=True):
if name not in BACKENDS_CACHE or not use_cache:
- rbac_backend = driver_loader.get_backend_instance(namespace=BACKENDS_NAMESPACE, name=name)
+ rbac_backend = driver_loader.get_backend_instance(
+ namespace=BACKENDS_NAMESPACE, name=name
+ )
BACKENDS_CACHE[name] = rbac_backend
rbac_backend = BACKENDS_CACHE[name]
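The hunk above is a plain cache-or-load pattern around the stevedore driver loader. A stripped-down sketch of the behaviour, where load_backend is a hypothetical stand-in for driver_loader.get_backend_instance:

BACKENDS_CACHE = {}

def load_backend(namespace, name):
    # Hypothetical stand-in for the stevedore-based driver_loader call.
    return object()

def get_backend_instance(name, use_cache=True):
    if name not in BACKENDS_CACHE or not use_cache:
        BACKENDS_CACHE[name] = load_backend("st2common.rbac.backend", name)
    return BACKENDS_CACHE[name]

first = get_backend_instance("noop")
second = get_backend_instance("noop")
print(first is second)  # True: the second call is served from the cache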
diff --git a/st2common/st2common/rbac/backends/base.py b/st2common/st2common/rbac/backends/base.py
index 8e2c54c4fd..f9661d0b4b 100644
--- a/st2common/st2common/rbac/backends/base.py
+++ b/st2common/st2common/rbac/backends/base.py
@@ -23,17 +23,16 @@
from st2common.exceptions.rbac import AccessDeniedError
__all__ = [
- 'BaseRBACBackend',
- 'BaseRBACPermissionResolver',
- 'BaseRBACService',
- 'BaseRBACUtils',
- 'BaseRBACRemoteGroupToRoleSyncer'
+ "BaseRBACBackend",
+ "BaseRBACPermissionResolver",
+ "BaseRBACService",
+ "BaseRBACUtils",
+ "BaseRBACRemoteGroupToRoleSyncer",
]
@six.add_metaclass(abc.ABCMeta)
class BaseRBACBackend(object):
-
def get_resolver_for_resource_type(self, resource_type):
"""
Method which returns PermissionResolver class for the provided resource type.
@@ -67,7 +66,6 @@ def get_utils_class(self):
@six.add_metaclass(abc.ABCMeta)
class BaseRBACPermissionResolver(object):
-
def user_has_permission(self, user_db, permission_type):
"""
Method for checking user permissions which are not tied to a particular resource.
@@ -177,7 +175,9 @@ def assert_user_has_rule_trigger_and_action_permission(user_db, rule_api):
raise NotImplementedError()
@staticmethod
- def assert_user_is_admin_if_user_query_param_is_provided(user_db, user, require_rbac=False):
+ def assert_user_is_admin_if_user_query_param_is_provided(
+ user_db, user, require_rbac=False
+ ):
"""
Function which asserts that the request user is an administrator if the "user" query parameter is
provided and doesn't match the current user.
@@ -273,12 +273,12 @@ def get_user_db_from_request(request):
"""
Retrieve UserDB object from the provided request.
"""
- auth_context = request.context.get('auth', {})
+ auth_context = request.context.get("auth", {})
if not auth_context:
return None
- user_db = auth_context.get('user', None)
+ user_db = auth_context.get("user", None)
return user_db
diff --git a/st2common/st2common/rbac/backends/noop.py b/st2common/st2common/rbac/backends/noop.py
index 15ca5a3a75..4d3b8fb127 100644
--- a/st2common/st2common/rbac/backends/noop.py
+++ b/st2common/st2common/rbac/backends/noop.py
@@ -25,11 +25,11 @@
from st2common.exceptions.rbac import AccessDeniedError
__all__ = [
- 'NoOpRBACBackend',
- 'NoOpRBACPermissionResolver',
- 'NoOpRBACService',
- 'NoOpRBACUtils',
- 'NoOpRBACRemoteGroupToRoleSyncer'
+ "NoOpRBACBackend",
+ "NoOpRBACPermissionResolver",
+ "NoOpRBACService",
+ "NoOpRBACUtils",
+ "NoOpRBACRemoteGroupToRoleSyncer",
]
@@ -37,6 +37,7 @@ class NoOpRBACBackend(BaseRBACBackend):
"""
NoOp RBAC backend.
"""
+
def get_resolver_for_resource_type(self, resource_type):
return NoOpRBACPermissionResolver()
@@ -79,7 +80,6 @@ def validate_roles_exists(role_names):
class NoOpRBACUtils(BaseRBACUtils):
-
@staticmethod
def assert_user_is_admin(user_db):
"""
@@ -141,7 +141,9 @@ def assert_user_has_rule_trigger_and_action_permission(user_db, rule_api):
return True
@staticmethod
- def assert_user_is_admin_if_user_query_param_is_provided(user_db, user, require_rbac=False):
+ def assert_user_is_admin_if_user_query_param_is_provided(
+ user_db, user, require_rbac=False
+ ):
"""
Function which asserts that the request user is an administrator if the "user" query parameter is
provided and doesn't match the current user.
diff --git a/st2common/st2common/rbac/migrations.py b/st2common/st2common/rbac/migrations.py
index 9e9fc9db18..951bbddf19 100644
--- a/st2common/st2common/rbac/migrations.py
+++ b/st2common/st2common/rbac/migrations.py
@@ -23,11 +23,7 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'run_all',
-
- 'insert_system_roles'
-]
+__all__ = ["run_all", "insert_system_roles"]
def run_all():
@@ -40,7 +36,7 @@ def insert_system_roles():
"""
system_roles = SystemRole.get_valid_values()
- LOG.debug('Inserting system roles (%s)' % (str(system_roles)))
+ LOG.debug("Inserting system roles (%s)" % (str(system_roles)))
for role_name in system_roles:
description = role_name
diff --git a/st2common/st2common/rbac/types.py b/st2common/st2common/rbac/types.py
index 1c6b0ea352..cceb819d7b 100644
--- a/st2common/st2common/rbac/types.py
+++ b/st2common/st2common/rbac/types.py
@@ -21,19 +21,16 @@
from st2common.constants.types import ResourceType as SystemResourceType
__all__ = [
- 'SystemRole',
- 'PermissionType',
- 'ResourceType',
-
- 'RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP',
- 'PERMISION_TYPE_TO_DESCRIPTION_MAP',
-
- 'ALL_PERMISSION_TYPES',
- 'GLOBAL_PERMISSION_TYPES',
- 'GLOBAL_PACK_PERMISSION_TYPES',
- 'LIST_PERMISSION_TYPES',
-
- 'get_resource_permission_types_with_descriptions'
+ "SystemRole",
+ "PermissionType",
+ "ResourceType",
+ "RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP",
+ "PERMISION_TYPE_TO_DESCRIPTION_MAP",
+ "ALL_PERMISSION_TYPES",
+ "GLOBAL_PERMISSION_TYPES",
+ "GLOBAL_PACK_PERMISSION_TYPES",
+ "LIST_PERMISSION_TYPES",
+ "get_resource_permission_types_with_descriptions",
]
@@ -43,120 +40,120 @@ class PermissionType(Enum):
"""
# Note: There is no create endpoint for runner types right now
- RUNNER_LIST = 'runner_type_list'
- RUNNER_VIEW = 'runner_type_view'
- RUNNER_MODIFY = 'runner_type_modify'
- RUNNER_ALL = 'runner_type_all'
+ RUNNER_LIST = "runner_type_list"
+ RUNNER_VIEW = "runner_type_view"
+ RUNNER_MODIFY = "runner_type_modify"
+ RUNNER_ALL = "runner_type_all"
- PACK_LIST = 'pack_list'
- PACK_VIEW = 'pack_view'
- PACK_CREATE = 'pack_create'
- PACK_MODIFY = 'pack_modify'
- PACK_DELETE = 'pack_delete'
+ PACK_LIST = "pack_list"
+ PACK_VIEW = "pack_view"
+ PACK_CREATE = "pack_create"
+ PACK_MODIFY = "pack_modify"
+ PACK_DELETE = "pack_delete"
# Pack-management specific permissions
# Note: Right now those permissions are global and apply to all the packs.
# In the future we plan to support globs.
- PACK_INSTALL = 'pack_install'
- PACK_UNINSTALL = 'pack_uninstall'
- PACK_REGISTER = 'pack_register'
- PACK_CONFIG = 'pack_config'
- PACK_SEARCH = 'pack_search'
- PACK_VIEWS_INDEX_HEALTH = 'pack_views_index_health'
+ PACK_INSTALL = "pack_install"
+ PACK_UNINSTALL = "pack_uninstall"
+ PACK_REGISTER = "pack_register"
+ PACK_CONFIG = "pack_config"
+ PACK_SEARCH = "pack_search"
+ PACK_VIEWS_INDEX_HEALTH = "pack_views_index_health"
- PACK_ALL = 'pack_all'
+ PACK_ALL = "pack_all"
# Note: Right now we only have read endpoints + update for sensors types
- SENSOR_LIST = 'sensor_type_list'
- SENSOR_VIEW = 'sensor_type_view'
- SENSOR_MODIFY = 'sensor_type_modify'
- SENSOR_ALL = 'sensor_type_all'
-
- ACTION_LIST = 'action_list'
- ACTION_VIEW = 'action_view'
- ACTION_CREATE = 'action_create'
- ACTION_MODIFY = 'action_modify'
- ACTION_DELETE = 'action_delete'
- ACTION_EXECUTE = 'action_execute'
- ACTION_ALL = 'action_all'
-
- ACTION_ALIAS_LIST = 'action_alias_list'
- ACTION_ALIAS_VIEW = 'action_alias_view'
- ACTION_ALIAS_CREATE = 'action_alias_create'
- ACTION_ALIAS_MODIFY = 'action_alias_modify'
- ACTION_ALIAS_MATCH = 'action_alias_match'
- ACTION_ALIAS_HELP = 'action_alias_help'
- ACTION_ALIAS_DELETE = 'action_alias_delete'
- ACTION_ALIAS_ALL = 'action_alias_all'
+ SENSOR_LIST = "sensor_type_list"
+ SENSOR_VIEW = "sensor_type_view"
+ SENSOR_MODIFY = "sensor_type_modify"
+ SENSOR_ALL = "sensor_type_all"
+
+ ACTION_LIST = "action_list"
+ ACTION_VIEW = "action_view"
+ ACTION_CREATE = "action_create"
+ ACTION_MODIFY = "action_modify"
+ ACTION_DELETE = "action_delete"
+ ACTION_EXECUTE = "action_execute"
+ ACTION_ALL = "action_all"
+
+ ACTION_ALIAS_LIST = "action_alias_list"
+ ACTION_ALIAS_VIEW = "action_alias_view"
+ ACTION_ALIAS_CREATE = "action_alias_create"
+ ACTION_ALIAS_MODIFY = "action_alias_modify"
+ ACTION_ALIAS_MATCH = "action_alias_match"
+ ACTION_ALIAS_HELP = "action_alias_help"
+ ACTION_ALIAS_DELETE = "action_alias_delete"
+ ACTION_ALIAS_ALL = "action_alias_all"
# Note: Execution create is granted with "action_execute"
- EXECUTION_LIST = 'execution_list'
- EXECUTION_VIEW = 'execution_view'
- EXECUTION_RE_RUN = 'execution_rerun'
- EXECUTION_STOP = 'execution_stop'
- EXECUTION_ALL = 'execution_all'
- EXECUTION_VIEWS_FILTERS_LIST = 'execution_views_filters_list'
-
- RULE_LIST = 'rule_list'
- RULE_VIEW = 'rule_view'
- RULE_CREATE = 'rule_create'
- RULE_MODIFY = 'rule_modify'
- RULE_DELETE = 'rule_delete'
- RULE_ALL = 'rule_all'
-
- RULE_ENFORCEMENT_LIST = 'rule_enforcement_list'
- RULE_ENFORCEMENT_VIEW = 'rule_enforcement_view'
+ EXECUTION_LIST = "execution_list"
+ EXECUTION_VIEW = "execution_view"
+ EXECUTION_RE_RUN = "execution_rerun"
+ EXECUTION_STOP = "execution_stop"
+ EXECUTION_ALL = "execution_all"
+ EXECUTION_VIEWS_FILTERS_LIST = "execution_views_filters_list"
+
+ RULE_LIST = "rule_list"
+ RULE_VIEW = "rule_view"
+ RULE_CREATE = "rule_create"
+ RULE_MODIFY = "rule_modify"
+ RULE_DELETE = "rule_delete"
+ RULE_ALL = "rule_all"
+
+ RULE_ENFORCEMENT_LIST = "rule_enforcement_list"
+ RULE_ENFORCEMENT_VIEW = "rule_enforcement_view"
# TODO - Maybe "datastore_item" / key_value_item ?
- KEY_VALUE_VIEW = 'key_value_pair_view'
- KEY_VALUE_SET = 'key_value_pair_set'
- KEY_VALUE_DELETE = 'key_value_pair_delete'
-
- WEBHOOK_LIST = 'webhook_list'
- WEBHOOK_VIEW = 'webhook_view'
- WEBHOOK_CREATE = 'webhook_create'
- WEBHOOK_SEND = 'webhook_send'
- WEBHOOK_DELETE = 'webhook_delete'
- WEBHOOK_ALL = 'webhook_all'
-
- TIMER_LIST = 'timer_list'
- TIMER_VIEW = 'timer_view'
- TIMER_ALL = 'timer_all'
-
- API_KEY_LIST = 'api_key_list'
- API_KEY_VIEW = 'api_key_view'
- API_KEY_CREATE = 'api_key_create'
- API_KEY_MODIFY = 'api_key_modify'
- API_KEY_DELETE = 'api_key_delete'
- API_KEY_ALL = 'api_key_all'
-
- TRACE_LIST = 'trace_list'
- TRACE_VIEW = 'trace_view'
- TRACE_ALL = 'trace_all'
+ KEY_VALUE_VIEW = "key_value_pair_view"
+ KEY_VALUE_SET = "key_value_pair_set"
+ KEY_VALUE_DELETE = "key_value_pair_delete"
+
+ WEBHOOK_LIST = "webhook_list"
+ WEBHOOK_VIEW = "webhook_view"
+ WEBHOOK_CREATE = "webhook_create"
+ WEBHOOK_SEND = "webhook_send"
+ WEBHOOK_DELETE = "webhook_delete"
+ WEBHOOK_ALL = "webhook_all"
+
+ TIMER_LIST = "timer_list"
+ TIMER_VIEW = "timer_view"
+ TIMER_ALL = "timer_all"
+
+ API_KEY_LIST = "api_key_list"
+ API_KEY_VIEW = "api_key_view"
+ API_KEY_CREATE = "api_key_create"
+ API_KEY_MODIFY = "api_key_modify"
+ API_KEY_DELETE = "api_key_delete"
+ API_KEY_ALL = "api_key_all"
+
+ TRACE_LIST = "trace_list"
+ TRACE_VIEW = "trace_view"
+ TRACE_ALL = "trace_all"
# Note: Trigger permission types are also used for the Timer API endpoint since a timer
# is just a special type of trigger
- TRIGGER_LIST = 'trigger_list'
- TRIGGER_VIEW = 'trigger_view'
- TRIGGER_ALL = 'trigger_all'
+ TRIGGER_LIST = "trigger_list"
+ TRIGGER_VIEW = "trigger_view"
+ TRIGGER_ALL = "trigger_all"
- POLICY_TYPE_LIST = 'policy_type_list'
- POLICY_TYPE_VIEW = 'policy_type_view'
- POLICY_TYPE_ALL = 'policy_type_all'
+ POLICY_TYPE_LIST = "policy_type_list"
+ POLICY_TYPE_VIEW = "policy_type_view"
+ POLICY_TYPE_ALL = "policy_type_all"
- POLICY_LIST = 'policy_list'
- POLICY_VIEW = 'policy_view'
- POLICY_CREATE = 'policy_create'
- POLICY_MODIFY = 'policy_modify'
- POLICY_DELETE = 'policy_delete'
- POLICY_ALL = 'policy_all'
+ POLICY_LIST = "policy_list"
+ POLICY_VIEW = "policy_view"
+ POLICY_CREATE = "policy_create"
+ POLICY_MODIFY = "policy_modify"
+ POLICY_DELETE = "policy_delete"
+ POLICY_ALL = "policy_all"
- STREAM_VIEW = 'stream_view'
+ STREAM_VIEW = "stream_view"
- INQUIRY_LIST = 'inquiry_list'
- INQUIRY_VIEW = 'inquiry_view'
- INQUIRY_RESPOND = 'inquiry_respond'
- INQUIRY_ALL = 'inquiry_all'
+ INQUIRY_LIST = "inquiry_list"
+ INQUIRY_VIEW = "inquiry_view"
+ INQUIRY_RESPOND = "inquiry_respond"
+ INQUIRY_ALL = "inquiry_all"
@classmethod
def get_valid_permissions_for_resource_type(cls, resource_type):
@@ -183,10 +180,10 @@ def get_resource_type(cls, permission_type):
elif permission_type == PermissionType.EXECUTION_VIEWS_FILTERS_LIST:
return ResourceType.EXECUTION
- split = permission_type.split('_')
+ split = permission_type.split("_")
assert len(split) >= 2
- return '_'.join(split[:-1])
+ return "_".join(split[:-1])
@classmethod
def get_permission_name(cls, permission_type):
@@ -195,12 +192,12 @@ def get_permission_name(cls, permission_type):
:rtype: ``str``
"""
- split = permission_type.split('_')
+ split = permission_type.split("_")
assert len(split) >= 2
# Special case for PACK_VIEWS_INDEX_HEALTH
if permission_type == PermissionType.PACK_VIEWS_INDEX_HEALTH:
- split = permission_type.split('_', 1)
+ split = permission_type.split("_", 1)
return split[1]
return split[-1]
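Both classmethods above lean on the naming convention that a permission type is "<resource_type>_<permission_name>". A standalone sketch of the two splits (the special cases handled above, such as PACK_VIEWS_INDEX_HEALTH, are omitted):

def get_resource_type(permission_type):
    # "<resource_type>_<permission_name>" -> resource type
    split = permission_type.split("_")
    assert len(split) >= 2
    return "_".join(split[:-1])

def get_permission_name(permission_type):
    # "<resource_type>_<permission_name>" -> permission name
    split = permission_type.split("_")
    assert len(split) >= 2
    return split[-1]

print(get_resource_type("action_alias_create"))    # -> action_alias
print(get_permission_name("action_alias_create"))  # -> create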
@@ -224,14 +221,16 @@ def get_permission_type(cls, resource_type, permission_name):
"""
# Special case for sensor type (sensor_type -> sensor)
if resource_type == ResourceType.SENSOR:
- resource_type = 'sensor'
+ resource_type = "sensor"
- permission_enum = '%s_%s' % (resource_type.upper(), permission_name.upper())
+ permission_enum = "%s_%s" % (resource_type.upper(), permission_name.upper())
result = getattr(cls, permission_enum, None)
if not result:
- raise ValueError('Unsupported permission type for type "%s" and name "%s"' %
- (resource_type, permission_name))
+ raise ValueError(
+ 'Unsupported permission type for type "%s" and name "%s"'
+ % (resource_type, permission_name)
+ )
return result
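The reverse lookup above simply recomposes the enum attribute name and resolves it with getattr. A minimal sketch, using a one-member toy PermissionType:

class PermissionType(object):
    ACTION_EXECUTE = "action_execute"

def get_permission_type(resource_type, permission_name):
    # Compose "<RESOURCE>_<NAME>" and resolve it as a class attribute.
    permission_enum = "%s_%s" % (resource_type.upper(), permission_name.upper())
    result = getattr(PermissionType, permission_enum, None)
    if not result:
        raise ValueError(
            'Unsupported permission type for type "%s" and name "%s"'
            % (resource_type, permission_name)
        )
    return result

print(get_permission_type("action", "execute"))  # -> action_execute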
@@ -240,6 +239,7 @@ class ResourceType(Enum):
"""
Resource types on which permissions can be granted.
"""
+
RUNNER = SystemResourceType.RUNNER_TYPE
PACK = SystemResourceType.PACK
@@ -266,9 +266,10 @@ class SystemRole(Enum):
"""
Default system roles which can't be manipulated (modified or removed).
"""
- SYSTEM_ADMIN = 'system_admin' # Special role which can't be revoked.
- ADMIN = 'admin'
- OBSERVER = 'observer'
+
+ SYSTEM_ADMIN = "system_admin" # Special role which can't be revoked.
+ ADMIN = "admin"
+ OBSERVER = "observer"
# Maps a list of available permission types for each resource
@@ -292,35 +293,31 @@ class SystemRole(Enum):
PermissionType.PACK_SEARCH,
PermissionType.PACK_VIEWS_INDEX_HEALTH,
PermissionType.PACK_ALL,
-
PermissionType.SENSOR_VIEW,
PermissionType.SENSOR_MODIFY,
PermissionType.SENSOR_ALL,
-
PermissionType.ACTION_VIEW,
PermissionType.ACTION_CREATE,
PermissionType.ACTION_MODIFY,
PermissionType.ACTION_DELETE,
PermissionType.ACTION_EXECUTE,
PermissionType.ACTION_ALL,
-
PermissionType.ACTION_ALIAS_VIEW,
PermissionType.ACTION_ALIAS_CREATE,
PermissionType.ACTION_ALIAS_MODIFY,
PermissionType.ACTION_ALIAS_DELETE,
PermissionType.ACTION_ALIAS_ALL,
-
PermissionType.RULE_VIEW,
PermissionType.RULE_CREATE,
PermissionType.RULE_MODIFY,
PermissionType.RULE_DELETE,
- PermissionType.RULE_ALL
+ PermissionType.RULE_ALL,
],
ResourceType.SENSOR: [
PermissionType.SENSOR_LIST,
PermissionType.SENSOR_VIEW,
PermissionType.SENSOR_MODIFY,
- PermissionType.SENSOR_ALL
+ PermissionType.SENSOR_ALL,
],
ResourceType.ACTION: [
PermissionType.ACTION_LIST,
@@ -329,7 +326,7 @@ class SystemRole(Enum):
PermissionType.ACTION_MODIFY,
PermissionType.ACTION_DELETE,
PermissionType.ACTION_EXECUTE,
- PermissionType.ACTION_ALL
+ PermissionType.ACTION_ALL,
],
ResourceType.ACTION_ALIAS: [
PermissionType.ACTION_ALIAS_LIST,
@@ -339,7 +336,7 @@ class SystemRole(Enum):
PermissionType.ACTION_ALIAS_MATCH,
PermissionType.ACTION_ALIAS_HELP,
PermissionType.ACTION_ALIAS_DELETE,
- PermissionType.ACTION_ALIAS_ALL
+ PermissionType.ACTION_ALIAS_ALL,
],
ResourceType.RULE: [
PermissionType.RULE_LIST,
@@ -347,7 +344,7 @@ class SystemRole(Enum):
PermissionType.RULE_CREATE,
PermissionType.RULE_MODIFY,
PermissionType.RULE_DELETE,
- PermissionType.RULE_ALL
+ PermissionType.RULE_ALL,
],
ResourceType.RULE_ENFORCEMENT: [
PermissionType.RULE_ENFORCEMENT_LIST,
@@ -364,7 +361,7 @@ class SystemRole(Enum):
ResourceType.KEY_VALUE_PAIR: [
PermissionType.KEY_VALUE_VIEW,
PermissionType.KEY_VALUE_SET,
- PermissionType.KEY_VALUE_DELETE
+ PermissionType.KEY_VALUE_DELETE,
],
ResourceType.WEBHOOK: [
PermissionType.WEBHOOK_LIST,
@@ -372,12 +369,12 @@ class SystemRole(Enum):
PermissionType.WEBHOOK_CREATE,
PermissionType.WEBHOOK_SEND,
PermissionType.WEBHOOK_DELETE,
- PermissionType.WEBHOOK_ALL
+ PermissionType.WEBHOOK_ALL,
],
ResourceType.TIMER: [
PermissionType.TIMER_LIST,
PermissionType.TIMER_VIEW,
- PermissionType.TIMER_ALL
+ PermissionType.TIMER_ALL,
],
ResourceType.API_KEY: [
PermissionType.API_KEY_LIST,
@@ -385,17 +382,17 @@ class SystemRole(Enum):
PermissionType.API_KEY_CREATE,
PermissionType.API_KEY_MODIFY,
PermissionType.API_KEY_DELETE,
- PermissionType.API_KEY_ALL
+ PermissionType.API_KEY_ALL,
],
ResourceType.TRACE: [
PermissionType.TRACE_LIST,
PermissionType.TRACE_VIEW,
- PermissionType.TRACE_ALL
+ PermissionType.TRACE_ALL,
],
ResourceType.TRIGGER: [
PermissionType.TRIGGER_LIST,
PermissionType.TRIGGER_VIEW,
- PermissionType.TRIGGER_ALL
+ PermissionType.TRIGGER_ALL,
],
ResourceType.POLICY_TYPE: [
PermissionType.POLICY_TYPE_LIST,
@@ -415,13 +412,16 @@ class SystemRole(Enum):
PermissionType.INQUIRY_VIEW,
PermissionType.INQUIRY_RESPOND,
PermissionType.INQUIRY_ALL,
- ]
+ ],
}
ALL_PERMISSION_TYPES = list(RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP.values())
ALL_PERMISSION_TYPES = list(itertools.chain(*ALL_PERMISSION_TYPES))
-LIST_PERMISSION_TYPES = [permission_type for permission_type in ALL_PERMISSION_TYPES if
- permission_type.endswith('_list')]
+LIST_PERMISSION_TYPES = [
+ permission_type
+ for permission_type in ALL_PERMISSION_TYPES
+ if permission_type.endswith("_list")
+]
# List of global permissions (ones which don't apply to a specific resource)
GLOBAL_PERMISSION_TYPES = [
@@ -433,169 +433,198 @@ class SystemRole(Enum):
PermissionType.PACK_CONFIG,
PermissionType.PACK_SEARCH,
PermissionType.PACK_VIEWS_INDEX_HEALTH,
-
# Action alias global permission types
PermissionType.ACTION_ALIAS_MATCH,
PermissionType.ACTION_ALIAS_HELP,
-
# API key global permission types
PermissionType.API_KEY_CREATE,
-
# Policy global permission types
PermissionType.POLICY_CREATE,
-
# Execution
PermissionType.EXECUTION_VIEWS_FILTERS_LIST,
-
# Stream
PermissionType.STREAM_VIEW,
-
# Inquiry
PermissionType.INQUIRY_LIST,
PermissionType.INQUIRY_RESPOND,
- PermissionType.INQUIRY_VIEW
-
+ PermissionType.INQUIRY_VIEW,
] + LIST_PERMISSION_TYPES
-GLOBAL_PACK_PERMISSION_TYPES = [permission_type for permission_type in GLOBAL_PERMISSION_TYPES if
- permission_type.startswith('pack_')]
+GLOBAL_PACK_PERMISSION_TYPES = [
+ permission_type
+ for permission_type in GLOBAL_PERMISSION_TYPES
+ if permission_type.startswith("pack_")
+]
# Maps a permission type to the corresponding description
PERMISION_TYPE_TO_DESCRIPTION_MAP = {
- PermissionType.PACK_LIST: 'Ability to list (view all) packs.',
- PermissionType.PACK_VIEW: 'Ability to view a pack.',
- PermissionType.PACK_CREATE: 'Ability to create a new pack.',
- PermissionType.PACK_MODIFY: 'Ability to modify (update) an existing pack.',
- PermissionType.PACK_DELETE: 'Ability to delete an existing pack.',
- PermissionType.PACK_INSTALL: 'Ability to install packs.',
- PermissionType.PACK_UNINSTALL: 'Ability to uninstall packs.',
- PermissionType.PACK_REGISTER: 'Ability to register packs and corresponding resources.',
- PermissionType.PACK_CONFIG: 'Ability to configure a pack.',
- PermissionType.PACK_SEARCH: 'Ability to query registry and search packs.',
- PermissionType.PACK_VIEWS_INDEX_HEALTH: 'Ability to query health of pack registries.',
- PermissionType.PACK_ALL: ('Ability to perform all the supported operations on a particular '
- 'pack.'),
-
- PermissionType.SENSOR_LIST: 'Ability to list (view all) sensors.',
- PermissionType.SENSOR_VIEW: 'Ability to view a sensor',
- PermissionType.SENSOR_MODIFY: ('Ability to modify (update) an existing sensor. Also implies '
- '"sensor_type_view" permission.'),
- PermissionType.SENSOR_ALL: ('Ability to perform all the supported operations on a particular '
- 'sensor.'),
-
- PermissionType.ACTION_LIST: 'Ability to list (view all) actions.',
- PermissionType.ACTION_VIEW: 'Ability to view an action.',
- PermissionType.ACTION_CREATE: ('Ability to create a new action. Also implies "action_view" '
- 'permission.'),
- PermissionType.ACTION_MODIFY: ('Ability to modify (update) an existing action. Also implies '
- '"action_view" permission.'),
- PermissionType.ACTION_DELETE: ('Ability to delete an existing action. Also implies '
- '"action_view" permission.'),
- PermissionType.ACTION_EXECUTE: ('Ability to execute (run) an action. Also implies '
- '"action_view" permission.'),
- PermissionType.ACTION_ALL: ('Ability to perform all the supported operations on a particular '
- 'action.'),
-
- PermissionType.ACTION_ALIAS_LIST: 'Ability to list (view all) action aliases.',
- PermissionType.ACTION_ALIAS_VIEW: 'Ability to view an action alias.',
- PermissionType.ACTION_ALIAS_CREATE: ('Ability to create a new action alias. Also implies'
- ' "action_alias_view" permission.'),
- PermissionType.ACTION_ALIAS_MODIFY: ('Ability to modify (update) an existing action alias. '
- 'Also implies "action_alias_view" permission.'),
- PermissionType.ACTION_ALIAS_MATCH: ('Ability to use action alias match API endpoint.'),
- PermissionType.ACTION_ALIAS_HELP: ('Ability to use action alias help API endpoint.'),
- PermissionType.ACTION_ALIAS_DELETE: ('Ability to delete an existing action alias. Also '
- 'implies "action_alias_view" permission.'),
- PermissionType.ACTION_ALIAS_ALL: ('Ability to perform all the supported operations on a '
- 'particular action alias.'),
-
- PermissionType.EXECUTION_LIST: 'Ability to list (view all) executions.',
- PermissionType.EXECUTION_VIEW: 'Ability to view an execution.',
- PermissionType.EXECUTION_RE_RUN: 'Ability to create a new action.',
- PermissionType.EXECUTION_STOP: 'Ability to stop (cancel) a running execution.',
- PermissionType.EXECUTION_ALL: ('Ability to perform all the supported operations on a '
- 'particular execution.'),
- PermissionType.EXECUTION_VIEWS_FILTERS_LIST: ('Ability view all the distinct execution '
- 'filters.'),
-
- PermissionType.RULE_LIST: 'Ability to list (view all) rules.',
- PermissionType.RULE_VIEW: 'Ability to view a rule.',
- PermissionType.RULE_CREATE: ('Ability to create a new rule. Also implies "rule_view" '
- 'permission'),
- PermissionType.RULE_MODIFY: ('Ability to modify (update) an existing rule. Also implies '
- '"rule_view" permission.'),
- PermissionType.RULE_DELETE: ('Ability to delete an existing rule. Also implies "rule_view" '
- 'permission.'),
- PermissionType.RULE_ALL: ('Ability to perform all the supported operations on a particular '
- 'rule.'),
-
- PermissionType.RULE_ENFORCEMENT_LIST: 'Ability to list (view all) rule enforcements.',
- PermissionType.RULE_ENFORCEMENT_VIEW: 'Ability to view a rule enforcement.',
-
- PermissionType.RUNNER_LIST: 'Ability to list (view all) runners.',
- PermissionType.RUNNER_VIEW: 'Ability to view a runner.',
- PermissionType.RUNNER_MODIFY: ('Ability to modify (update) an existing runner. Also implies '
- '"runner_type_view" permission.'),
- PermissionType.RUNNER_ALL: ('Ability to perform all the supported operations on a particular '
- 'runner.'),
-
- PermissionType.WEBHOOK_LIST: 'Ability to list (view all) webhooks.',
- PermissionType.WEBHOOK_VIEW: ('Ability to view a webhook.'),
- PermissionType.WEBHOOK_CREATE: ('Ability to create a new webhook.'),
- PermissionType.WEBHOOK_SEND: ('Ability to send / POST data to an existing webhook.'),
- PermissionType.WEBHOOK_DELETE: ('Ability to delete an existing webhook.'),
- PermissionType.WEBHOOK_ALL: ('Ability to perform all the supported operations on a particular '
- 'webhook.'),
-
- PermissionType.TIMER_LIST: 'Ability to list (view all) timers.',
- PermissionType.TIMER_VIEW: ('Ability to view a timer.'),
- PermissionType.TIMER_ALL: ('Ability to perform all the supported operations on timers'),
-
- PermissionType.API_KEY_LIST: 'Ability to list (view all) API keys.',
- PermissionType.API_KEY_VIEW: ('Ability to view an API Key.'),
- PermissionType.API_KEY_CREATE: ('Ability to create a new API Key.'),
- PermissionType.API_KEY_MODIFY: ('Ability to modify (update) an existing API key. Also implies '
- '"api_key_view" permission.'),
- PermissionType.API_KEY_DELETE: ('Ability to delete an existing API Keys.'),
- PermissionType.API_KEY_ALL: ('Ability to perform all the supported operations on an API Key.'),
-
- PermissionType.KEY_VALUE_VIEW: ('Ability to view Key-Value Pairs.'),
- PermissionType.KEY_VALUE_SET: ('Ability to set a Key-Value Pair.'),
- PermissionType.KEY_VALUE_DELETE: ('Ability to delete an existing Key-Value Pair.'),
-
- PermissionType.TRACE_LIST: ('Ability to list (view all) traces.'),
- PermissionType.TRACE_VIEW: ('Ability to view a trace.'),
- PermissionType.TRACE_ALL: ('Ability to perform all the supported operations on traces.'),
-
- PermissionType.TRIGGER_LIST: ('Ability to list (view all) triggers.'),
- PermissionType.TRIGGER_VIEW: ('Ability to view a trigger.'),
- PermissionType.TRIGGER_ALL: ('Ability to perform all the supported operations on triggers.'),
-
- PermissionType.POLICY_TYPE_LIST: ('Ability to list (view all) policy types.'),
- PermissionType.POLICY_TYPE_VIEW: ('Ability to view a policy types.'),
- PermissionType.POLICY_TYPE_ALL: ('Ability to perform all the supported operations on policy'
- ' types.'),
-
- PermissionType.POLICY_LIST: 'Ability to list (view all) policies.',
- PermissionType.POLICY_VIEW: ('Ability to view a policy.'),
- PermissionType.POLICY_CREATE: ('Ability to create a new policy.'),
- PermissionType.POLICY_MODIFY: ('Ability to modify an existing policy.'),
- PermissionType.POLICY_DELETE: ('Ability to delete an existing policy.'),
- PermissionType.POLICY_ALL: ('Ability to perform all the supported operations on a particular '
- 'policy.'),
-
- PermissionType.STREAM_VIEW: ('Ability to view / listen to the events on the stream API '
- 'endpoint.'),
-
- PermissionType.INQUIRY_LIST: 'Ability to list existing Inquiries',
- PermissionType.INQUIRY_VIEW: 'Ability to view an existing Inquiry. Also implies '
- '"inquiry_respond" permission.',
- PermissionType.INQUIRY_RESPOND: 'Ability to respond to an existing Inquiry (in general - user '
- 'still needs access per specific inquiry parameters). Also '
- 'implies "inquiry_view" permission.',
- PermissionType.INQUIRY_ALL: ('Ability to perform all supported operations on a particular '
- 'Inquiry.')
+ PermissionType.PACK_LIST: "Ability to list (view all) packs.",
+ PermissionType.PACK_VIEW: "Ability to view a pack.",
+ PermissionType.PACK_CREATE: "Ability to create a new pack.",
+ PermissionType.PACK_MODIFY: "Ability to modify (update) an existing pack.",
+ PermissionType.PACK_DELETE: "Ability to delete an existing pack.",
+ PermissionType.PACK_INSTALL: "Ability to install packs.",
+ PermissionType.PACK_UNINSTALL: "Ability to uninstall packs.",
+ PermissionType.PACK_REGISTER: "Ability to register packs and corresponding resources.",
+ PermissionType.PACK_CONFIG: "Ability to configure a pack.",
+ PermissionType.PACK_SEARCH: "Ability to query registry and search packs.",
+ PermissionType.PACK_VIEWS_INDEX_HEALTH: "Ability to query health of pack registries.",
+ PermissionType.PACK_ALL: (
+ "Ability to perform all the supported operations on a particular " "pack."
+ ),
+ PermissionType.SENSOR_LIST: "Ability to list (view all) sensors.",
+ PermissionType.SENSOR_VIEW: "Ability to view a sensor.",
+ PermissionType.SENSOR_MODIFY: (
+ "Ability to modify (update) an existing sensor. Also implies "
+ '"sensor_type_view" permission.'
+ ),
+ PermissionType.SENSOR_ALL: (
+ "Ability to perform all the supported operations on a particular " "sensor."
+ ),
+ PermissionType.ACTION_LIST: "Ability to list (view all) actions.",
+ PermissionType.ACTION_VIEW: "Ability to view an action.",
+ PermissionType.ACTION_CREATE: (
+ 'Ability to create a new action. Also implies "action_view" ' "permission."
+ ),
+ PermissionType.ACTION_MODIFY: (
+ "Ability to modify (update) an existing action. Also implies "
+ '"action_view" permission.'
+ ),
+ PermissionType.ACTION_DELETE: (
+ "Ability to delete an existing action. Also implies "
+ '"action_view" permission.'
+ ),
+ PermissionType.ACTION_EXECUTE: (
+ "Ability to execute (run) an action. Also implies " '"action_view" permission.'
+ ),
+ PermissionType.ACTION_ALL: (
+ "Ability to perform all the supported operations on a particular " "action."
+ ),
+ PermissionType.ACTION_ALIAS_LIST: "Ability to list (view all) action aliases.",
+ PermissionType.ACTION_ALIAS_VIEW: "Ability to view an action alias.",
+ PermissionType.ACTION_ALIAS_CREATE: (
+ "Ability to create a new action alias. Also implies"
+ ' "action_alias_view" permission.'
+ ),
+ PermissionType.ACTION_ALIAS_MODIFY: (
+ "Ability to modify (update) an existing action alias. "
+ 'Also implies "action_alias_view" permission.'
+ ),
+ PermissionType.ACTION_ALIAS_MATCH: (
+ "Ability to use action alias match API endpoint."
+ ),
+ PermissionType.ACTION_ALIAS_HELP: (
+ "Ability to use action alias help API endpoint."
+ ),
+ PermissionType.ACTION_ALIAS_DELETE: (
+ "Ability to delete an existing action alias. Also "
+ 'implies "action_alias_view" permission.'
+ ),
+ PermissionType.ACTION_ALIAS_ALL: (
+ "Ability to perform all the supported operations on a "
+ "particular action alias."
+ ),
+ PermissionType.EXECUTION_LIST: "Ability to list (view all) executions.",
+ PermissionType.EXECUTION_VIEW: "Ability to view an execution.",
+ PermissionType.EXECUTION_RE_RUN: "Ability to re-run an existing execution.",
+ PermissionType.EXECUTION_STOP: "Ability to stop (cancel) a running execution.",
+ PermissionType.EXECUTION_ALL: (
+ "Ability to perform all the supported operations on a " "particular execution."
+ ),
+ PermissionType.EXECUTION_VIEWS_FILTERS_LIST: (
+ "Ability view all the distinct execution " "filters."
+ ),
+ PermissionType.RULE_LIST: "Ability to list (view all) rules.",
+ PermissionType.RULE_VIEW: "Ability to view a rule.",
+ PermissionType.RULE_CREATE: (
+ 'Ability to create a new rule. Also implies "rule_view" ' "permission."
+ ),
+ PermissionType.RULE_MODIFY: (
+ "Ability to modify (update) an existing rule. Also implies "
+ '"rule_view" permission.'
+ ),
+ PermissionType.RULE_DELETE: (
+ 'Ability to delete an existing rule. Also implies "rule_view" ' "permission."
+ ),
+ PermissionType.RULE_ALL: (
+ "Ability to perform all the supported operations on a particular " "rule."
+ ),
+ PermissionType.RULE_ENFORCEMENT_LIST: "Ability to list (view all) rule enforcements.",
+ PermissionType.RULE_ENFORCEMENT_VIEW: "Ability to view a rule enforcement.",
+ PermissionType.RUNNER_LIST: "Ability to list (view all) runners.",
+ PermissionType.RUNNER_VIEW: "Ability to view a runner.",
+ PermissionType.RUNNER_MODIFY: (
+ "Ability to modify (update) an existing runner. Also implies "
+ '"runner_type_view" permission.'
+ ),
+ PermissionType.RUNNER_ALL: (
+ "Ability to perform all the supported operations on a particular " "runner."
+ ),
+ PermissionType.WEBHOOK_LIST: "Ability to list (view all) webhooks.",
+ PermissionType.WEBHOOK_VIEW: ("Ability to view a webhook."),
+ PermissionType.WEBHOOK_CREATE: ("Ability to create a new webhook."),
+ PermissionType.WEBHOOK_SEND: (
+ "Ability to send / POST data to an existing webhook."
+ ),
+ PermissionType.WEBHOOK_DELETE: ("Ability to delete an existing webhook."),
+ PermissionType.WEBHOOK_ALL: (
+ "Ability to perform all the supported operations on a particular " "webhook."
+ ),
+ PermissionType.TIMER_LIST: "Ability to list (view all) timers.",
+ PermissionType.TIMER_VIEW: ("Ability to view a timer."),
+ PermissionType.TIMER_ALL: (
+ "Ability to perform all the supported operations on timers"
+ ),
+ PermissionType.API_KEY_LIST: "Ability to list (view all) API keys.",
+ PermissionType.API_KEY_VIEW: ("Ability to view an API Key."),
+ PermissionType.API_KEY_CREATE: ("Ability to create a new API Key."),
+ PermissionType.API_KEY_MODIFY: (
+ "Ability to modify (update) an existing API key. Also implies "
+ '"api_key_view" permission.'
+ ),
+ PermissionType.API_KEY_DELETE: ("Ability to delete an existing API Key."),
+ PermissionType.API_KEY_ALL: (
+ "Ability to perform all the supported operations on an API Key."
+ ),
+ PermissionType.KEY_VALUE_VIEW: ("Ability to view Key-Value Pairs."),
+ PermissionType.KEY_VALUE_SET: ("Ability to set a Key-Value Pair."),
+ PermissionType.KEY_VALUE_DELETE: ("Ability to delete an existing Key-Value Pair."),
+ PermissionType.TRACE_LIST: ("Ability to list (view all) traces."),
+ PermissionType.TRACE_VIEW: ("Ability to view a trace."),
+ PermissionType.TRACE_ALL: (
+ "Ability to perform all the supported operations on traces."
+ ),
+ PermissionType.TRIGGER_LIST: ("Ability to list (view all) triggers."),
+ PermissionType.TRIGGER_VIEW: ("Ability to view a trigger."),
+ PermissionType.TRIGGER_ALL: (
+ "Ability to perform all the supported operations on triggers."
+ ),
+ PermissionType.POLICY_TYPE_LIST: ("Ability to list (view all) policy types."),
+ PermissionType.POLICY_TYPE_VIEW: ("Ability to view a policy type."),
+ PermissionType.POLICY_TYPE_ALL: (
+ "Ability to perform all the supported operations on policy" " types."
+ ),
+ PermissionType.POLICY_LIST: "Ability to list (view all) policies.",
+ PermissionType.POLICY_VIEW: ("Ability to view a policy."),
+ PermissionType.POLICY_CREATE: ("Ability to create a new policy."),
+ PermissionType.POLICY_MODIFY: ("Ability to modify an existing policy."),
+ PermissionType.POLICY_DELETE: ("Ability to delete an existing policy."),
+ PermissionType.POLICY_ALL: (
+ "Ability to perform all the supported operations on a particular " "policy."
+ ),
+ PermissionType.STREAM_VIEW: (
+ "Ability to view / listen to the events on the stream API " "endpoint."
+ ),
+ PermissionType.INQUIRY_LIST: "Ability to list existing Inquiries.",
+ PermissionType.INQUIRY_VIEW: "Ability to view an existing Inquiry. Also implies "
+ '"inquiry_respond" permission.',
+ PermissionType.INQUIRY_RESPOND: "Ability to respond to an existing Inquiry (in general - user "
+ "still needs access per specific inquiry parameters). Also "
+ 'implies "inquiry_view" permission.',
+ PermissionType.INQUIRY_ALL: (
+ "Ability to perform all supported operations on a particular " "Inquiry."
+ ),
}
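The two derived lists earlier in this file (LIST_PERMISSION_TYPES and GLOBAL_PACK_PERMISSION_TYPES) are plain suffix/prefix filters over their source lists. A tiny sketch of the same comprehensions, filtering a shortened toy list in both cases:

ALL_PERMISSION_TYPES = ["action_list", "action_view", "pack_list", "pack_install"]

LIST_PERMISSION_TYPES = [
    permission_type
    for permission_type in ALL_PERMISSION_TYPES
    if permission_type.endswith("_list")
]
GLOBAL_PACK_PERMISSION_TYPES = [
    permission_type
    for permission_type in ALL_PERMISSION_TYPES
    if permission_type.startswith("pack_")
]
print(LIST_PERMISSION_TYPES)         # -> ['action_list', 'pack_list']
print(GLOBAL_PACK_PERMISSION_TYPES)  # -> ['pack_list', 'pack_install']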
@@ -607,10 +636,13 @@ def get_resource_permission_types_with_descriptions():
"""
result = {}
- for resource_type, permission_types in six.iteritems(RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP):
+ for resource_type, permission_types in six.iteritems(
+ RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP
+ ):
result[resource_type] = {}
for permission_type in permission_types:
- result[resource_type][permission_type] = \
- PERMISION_TYPE_TO_DESCRIPTION_MAP[permission_type]
+ result[resource_type][permission_type] = PERMISION_TYPE_TO_DESCRIPTION_MAP[
+ permission_type
+ ]
return result
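The helper above just zips the resource-to-permissions map together with the description map into a nested dict. A self-contained sketch (assuming six is installed, as it already is for st2common; the two toy maps are hypothetical):

import six

RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP = {"action": ["action_view", "action_execute"]}
DESCRIPTIONS = {
    "action_view": "Ability to view an action.",
    "action_execute": "Ability to execute (run) an action.",
}

def get_resource_permission_types_with_descriptions():
    result = {}
    for resource_type, permission_types in six.iteritems(
        RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP
    ):
        result[resource_type] = {}
        for permission_type in permission_types:
            result[resource_type][permission_type] = DESCRIPTIONS[permission_type]
    return result

print(get_resource_permission_types_with_descriptions())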
diff --git a/st2common/st2common/router.py b/st2common/st2common/router.py
index 29b34031b4..47ef009b98 100644
--- a/st2common/st2common/router.py
+++ b/st2common/st2common/router.py
@@ -43,15 +43,12 @@
from st2common.util.http import parse_content_type_header
__all__ = [
- 'Router',
-
- 'Response',
-
- 'NotFoundException',
-
- 'abort',
- 'abort_unauthorized',
- 'exc'
+ "Router",
+ "Response",
+ "NotFoundException",
+ "abort",
+ "abort_unauthorized",
+ "exc",
]
LOG = logging.getLogger(__name__)
@@ -63,24 +60,24 @@ def op_resolver(op_id):
:rtype: ``tuple``
"""
- module_name, func_name = op_id.split(':', 1)
- controller_name = func_name.split('.')[0]
+ module_name, func_name = op_id.split(":", 1)
+ controller_name = func_name.split(".")[0]
__import__(module_name)
module = sys.modules[module_name]
controller_instance = getattr(module, controller_name)
- method_callable = functools.reduce(getattr, func_name.split('.'), module)
+ method_callable = functools.reduce(getattr, func_name.split("."), module)
return controller_instance, method_callable
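op_resolver above turns an operationId of the form "module.path:controller.method" into a (controller, callable) pair via import plus attribute traversal. A runnable sketch against the stdlib — the json example is purely illustrative; real operation ids point at st2 controllers:

import functools
import sys

def op_resolver(op_id):
    module_name, func_name = op_id.split(":", 1)
    controller_name = func_name.split(".")[0]
    __import__(module_name)
    module = sys.modules[module_name]
    controller_instance = getattr(module, controller_name)
    method_callable = functools.reduce(getattr, func_name.split("."), module)
    return controller_instance, method_callable

controller, method = op_resolver("json:JSONEncoder.encode")
print(controller.__name__, method.__name__)  # -> JSONEncoder encode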
-def abort(status_code=exc.HTTPInternalServerError.code, message='Unhandled exception'):
+def abort(status_code=exc.HTTPInternalServerError.code, message="Unhandled exception"):
raise exc.status_map[status_code](message)
def abort_unauthorized(msg=None):
- raise exc.HTTPUnauthorized('Unauthorized - %s' % msg if msg else 'Unauthorized')
+ raise exc.HTTPUnauthorized("Unauthorized - %s" % msg if msg else "Unauthorized")
def extend_with_default(validator_class):
@@ -92,12 +89,16 @@ def set_defaults(validator, properties, instance, schema):
instance.setdefault(property, subschema["default"])
for error in validate_properties(
- validator, properties, instance, schema,
+ validator,
+ properties,
+ instance,
+ schema,
):
yield error
return jsonschema.validators.extend(
- validator_class, {"properties": set_defaults},
+ validator_class,
+ {"properties": set_defaults},
)
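The hunk above is the standard jsonschema recipe for extending a validator class so that validation also fills in schema defaults on the instance. A runnable sketch of the behaviour:

import jsonschema

def extend_with_default(validator_class):
    validate_properties = validator_class.VALIDATORS["properties"]

    def set_defaults(validator, properties, instance, schema):
        for property, subschema in properties.items():
            if "default" in subschema:
                instance.setdefault(property, subschema["default"])
        for error in validate_properties(validator, properties, instance, schema):
            yield error

    return jsonschema.validators.extend(validator_class, {"properties": set_defaults})

DefaultFillingValidator = extend_with_default(jsonschema.Draft4Validator)
instance = {}
DefaultFillingValidator({"properties": {"limit": {"default": 50}}}).validate(instance)
print(instance)  # -> {'limit': 50}  (default injected during validation)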
@@ -109,7 +110,8 @@ def set_additional_check(validator, properties, instance, schema):
yield error
return jsonschema.validators.extend(
- validator_class, {"x-additional-check": set_additional_check},
+ validator_class,
+ {"x-additional-check": set_additional_check},
)
@@ -126,7 +128,8 @@ def set_type_draft4(validator, types, instance, schema):
yield error
return jsonschema.validators.extend(
- validator_class, {"type": set_type_draft4},
+ validator_class,
+ {"type": set_type_draft4},
)
@@ -141,27 +144,40 @@ class NotFoundException(Exception):
class Response(webob.Response):
- def __init__(self, body=None, status=None, headerlist=None, app_iter=None, content_type=None,
- *args, **kwargs):
+ def __init__(
+ self,
+ body=None,
+ status=None,
+ headerlist=None,
+ app_iter=None,
+ content_type=None,
+ *args,
+ **kwargs,
+ ):
# Do some sanity checking, and turn json_body into an actual body
- if app_iter is None and body is None and ('json_body' in kwargs or 'json' in kwargs):
- if 'json_body' in kwargs:
- json_body = kwargs.pop('json_body')
+ if (
+ app_iter is None
+ and body is None
+ and ("json_body" in kwargs or "json" in kwargs)
+ ):
+ if "json_body" in kwargs:
+ json_body = kwargs.pop("json_body")
else:
- json_body = kwargs.pop('json')
- body = json_encode(json_body).encode('UTF-8')
+ json_body = kwargs.pop("json")
+ body = json_encode(json_body).encode("UTF-8")
if content_type is None:
- content_type = 'application/json'
+ content_type = "application/json"
- super(Response, self).__init__(body, status, headerlist, app_iter, content_type,
- *args, **kwargs)
+ super(Response, self).__init__(
+ body, status, headerlist, app_iter, content_type, *args, **kwargs
+ )
def _json_body__get(self):
return super(Response, self)._json_body__get()
def _json_body__set(self, value):
- self.body = json_encode(value).encode('UTF-8')
+ self.body = json_encode(value).encode("UTF-8")
def _json_body__del(self):
return super(Response, self)._json_body__del()
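The constructor logic above boils down to: if no body was given but a json/json_body keyword was, serialize it and default the content type. A toy stand-in that shows just that piece (the real class defers everything else to webob.Response):

import json

class Response(object):
    # Toy stand-in for the webob-based Response above.
    def __init__(self, body=None, content_type=None, **kwargs):
        if body is None and ("json_body" in kwargs or "json" in kwargs):
            if "json_body" in kwargs:
                json_body = kwargs.pop("json_body")
            else:
                json_body = kwargs.pop("json")
            body = json.dumps(json_body).encode("UTF-8")
            if content_type is None:
                content_type = "application/json"
        self.body = body
        self.content_type = content_type

resp = Response(json={"status": "succeeded"})
print(resp.content_type, resp.body)  # application/json b'{"status": "succeeded"}'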
@@ -182,44 +198,51 @@ def __init__(self, arguments=None, debug=False, auth=True, is_gunicorn=True):
self.routes = routes.Mapper()
def add_spec(self, spec, transforms):
- info = spec.get('info', {})
- LOG.debug('Adding API: %s %s', info.get('title', 'untitled'), info.get('version', '0.0.0'))
+ info = spec.get("info", {})
+ LOG.debug(
+ "Adding API: %s %s",
+ info.get("title", "untitled"),
+ info.get("version", "0.0.0"),
+ )
self.spec = spec
- self.spec_resolver = jsonschema.RefResolver('', self.spec)
+ self.spec_resolver = jsonschema.RefResolver("", self.spec)
validate(copy.deepcopy(self.spec))
for filter in transforms:
- for (path, methods) in six.iteritems(spec['paths']):
+ for (path, methods) in six.iteritems(spec["paths"]):
if not re.search(filter, path):
continue
for (method, endpoint) in six.iteritems(methods):
- conditions = {
- 'method': [method.upper()]
- }
+ conditions = {"method": [method.upper()]}
connect_kw = {}
- if 'x-requirements' in endpoint:
- connect_kw['requirements'] = endpoint['x-requirements']
+ if "x-requirements" in endpoint:
+ connect_kw["requirements"] = endpoint["x-requirements"]
- m = self.routes.submapper(_api_path=path, _api_method=method,
- conditions=conditions)
+ m = self.routes.submapper(
+ _api_path=path, _api_method=method, conditions=conditions
+ )
for transform in transforms[filter]:
m.connect(None, re.sub(filter, transform, path), **connect_kw)
- module_name = endpoint['operationId'].split(':', 1)[0]
+ module_name = endpoint["operationId"].split(":", 1)[0]
__import__(module_name)
for route in sorted(self.routes.matchlist, key=lambda r: r.routepath):
- LOG.debug('Route registered: %+6s %s', route.conditions['method'][0], route.routepath)
+ LOG.debug(
+ "Route registered: %+6s %s",
+ route.conditions["method"][0],
+ route.routepath,
+ )
def match(self, req):
path = url_unquote(req.path)
LOG.debug("Match path: %s", path)
- if len(path) > 1 and path.endswith('/'):
+ if len(path) > 1 and path.endswith("/"):
path = path[:-1]
match = self.routes.match(path, req.environ)
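add_spec above stashes the swagger path and method as routing variables (_api_path, _api_method) so that match can recover the spec endpoint later. A small sketch of that registration with the Routes package st2common already depends on (paths are hypothetical):

import routes

mapper = routes.Mapper()
# Register the public URL while carrying the spec path/method along as
# routing variables, as add_spec does via submapper + connect.
sub = mapper.submapper(
    _api_path="/actions", _api_method="get", conditions={"method": ["GET"]}
)
sub.connect(None, "/api/v1/actions")

print(mapper.match("/api/v1/actions", {"REQUEST_METHOD": "GET"}))
# -> {'_api_path': '/actions', '_api_method': 'get'}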
@@ -235,9 +258,9 @@ def match(self, req):
path_vars = dict(path_vars)
- path = path_vars.pop('_api_path')
- method = path_vars.pop('_api_method')
- endpoint = self.spec['paths'][path][method]
+ path = path_vars.pop("_api_path")
+ method = path_vars.pop("_api_method")
+ endpoint = self.spec["paths"][path][method]
return endpoint, path_vars
@@ -256,127 +279,140 @@ def __call__(self, req):
LOG.debug("Parsed endpoint: %s", endpoint)
LOG.debug("Parsed path_vars: %s", path_vars)
- context = copy.copy(getattr(self, 'mock_context', {}))
+ context = copy.copy(getattr(self, "mock_context", {}))
cookie_token = None
# Handle security
- if 'security' in endpoint:
- security = endpoint.get('security')
+ if "security" in endpoint:
+ security = endpoint.get("security")
else:
- security = self.spec.get('security', [])
+ security = self.spec.get("security", [])
if self.auth and security:
try:
- security_definitions = self.spec.get('securityDefinitions', {})
+ security_definitions = self.spec.get("securityDefinitions", {})
for statement in security:
declaration, options = statement.copy().popitem()
definition = security_definitions[declaration]
- if definition['type'] == 'apiKey':
- if definition['in'] == 'header':
- token = req.headers.get(definition['name'])
- elif definition['in'] == 'query':
- token = req.GET.get(definition['name'])
- elif definition['in'] == 'cookie':
- token = req.cookies.get(definition['name'])
+ if definition["type"] == "apiKey":
+ if definition["in"] == "header":
+ token = req.headers.get(definition["name"])
+ elif definition["in"] == "query":
+ token = req.GET.get(definition["name"])
+ elif definition["in"] == "cookie":
+ token = req.cookies.get(definition["name"])
else:
token = None
if token:
- _, auth_func = op_resolver(definition['x-operationId'])
+ _, auth_func = op_resolver(definition["x-operationId"])
auth_resp = auth_func(token)
# Include information on how user authenticated inside the context
- if 'auth-token' in definition['name'].lower():
- auth_method = 'authentication token'
- elif 'api-key' in definition['name'].lower():
- auth_method = 'API key'
-
- context['user'] = User.get_by_name(auth_resp.user)
- context['auth_info'] = {
- 'method': auth_method,
- 'location': definition['in']
+ if "auth-token" in definition["name"].lower():
+ auth_method = "authentication token"
+ elif "api-key" in definition["name"].lower():
+ auth_method = "API key"
+
+ context["user"] = User.get_by_name(auth_resp.user)
+ context["auth_info"] = {
+ "method": auth_method,
+ "location": definition["in"],
}
# Also include token expiration time when authenticated via auth token
- if 'auth-token' in definition['name'].lower():
- context['auth_info']['token_expire'] = auth_resp.expiry
-
- if 'x-set-cookie' in definition:
- max_age = auth_resp.expiry - date_utils.get_datetime_utc_now()
- cookie_token = cookies.make_cookie(definition['x-set-cookie'],
- token,
- max_age=max_age,
- httponly=True)
+ if "auth-token" in definition["name"].lower():
+ context["auth_info"]["token_expire"] = auth_resp.expiry
+
+ if "x-set-cookie" in definition:
+ max_age = (
+ auth_resp.expiry - date_utils.get_datetime_utc_now()
+ )
+ cookie_token = cookies.make_cookie(
+ definition["x-set-cookie"],
+ token,
+ max_age=max_age,
+ httponly=True,
+ )
break
- if 'user' not in context:
- raise auth_exc.NoAuthSourceProvidedError('One of Token or API key required.')
- except (auth_exc.NoAuthSourceProvidedError,
- auth_exc.MultipleAuthSourcesError) as e:
+ if "user" not in context:
+ raise auth_exc.NoAuthSourceProvidedError(
+ "One of Token or API key required."
+ )
+ except (
+ auth_exc.NoAuthSourceProvidedError,
+ auth_exc.MultipleAuthSourcesError,
+ ) as e:
LOG.error(six.text_type(e))
return abort_unauthorized(six.text_type(e))
except auth_exc.TokenNotProvidedError as e:
- LOG.exception('Token is not provided.')
+ LOG.exception("Token is not provided.")
return abort_unauthorized(six.text_type(e))
except auth_exc.TokenNotFoundError as e:
- LOG.exception('Token is not found.')
+ LOG.exception("Token is not found.")
return abort_unauthorized(six.text_type(e))
except auth_exc.TokenExpiredError as e:
- LOG.exception('Token has expired.')
+ LOG.exception("Token has expired.")
return abort_unauthorized(six.text_type(e))
except auth_exc.ApiKeyNotProvidedError as e:
- LOG.exception('API key is not provided.')
+ LOG.exception("API key is not provided.")
return abort_unauthorized(six.text_type(e))
except auth_exc.ApiKeyNotFoundError as e:
- LOG.exception('API key is not found.')
+ LOG.exception("API key is not found.")
return abort_unauthorized(six.text_type(e))
except auth_exc.ApiKeyDisabledError as e:
- LOG.exception('API key is disabled.')
+ LOG.exception("API key is disabled.")
return abort_unauthorized(six.text_type(e))
if cfg.CONF.rbac.enable:
- user_db = context['user']
+ user_db = context["user"]
- permission_type = endpoint.get('x-permissions', None)
+ permission_type = endpoint.get("x-permissions", None)
if permission_type:
rbac_backend = get_rbac_backend()
- resolver = rbac_backend.get_resolver_for_permission_type(permission_type)
- has_permission = resolver.user_has_permission(user_db, permission_type)
+ resolver = rbac_backend.get_resolver_for_permission_type(
+ permission_type
+ )
+ has_permission = resolver.user_has_permission(
+ user_db, permission_type
+ )
if not has_permission:
- raise rbac_exc.ResourceTypeAccessDeniedError(user_db,
- permission_type)
+ raise rbac_exc.ResourceTypeAccessDeniedError(
+ user_db, permission_type
+ )
# Collect parameters
kw = {}
- for param in endpoint.get('parameters', []) + endpoint.get('x-parameters', []):
- name = param['name']
- argument_name = param.get('x-as', None) or name
- source = param['in']
- default = param.get('default', None)
+ for param in endpoint.get("parameters", []) + endpoint.get("x-parameters", []):
+ name = param["name"]
+ argument_name = param.get("x-as", None) or name
+ source = param["in"]
+ default = param.get("default", None)
# Collecting params from different sources
- if source == 'query':
+ if source == "query":
kw[argument_name] = req.GET.get(name, default)
- elif source == 'path':
+ elif source == "path":
kw[argument_name] = path_vars[name]
- elif source == 'header':
+ elif source == "header":
kw[argument_name] = req.headers.get(name, default)
- elif source == 'formData':
+ elif source == "formData":
kw[argument_name] = req.POST.get(name, default)
- elif source == 'environ':
+ elif source == "environ":
kw[argument_name] = req.environ.get(name.upper(), default)
- elif source == 'context':
+ elif source == "context":
kw[argument_name] = context.get(name, default)
- elif source == 'request':
+ elif source == "request":
kw[argument_name] = getattr(req, name)
- elif source == 'body':
- content_type = req.headers.get('Content-Type', 'application/json')
+ elif source == "body":
+ content_type = req.headers.get("Content-Type", "application/json")
content_type = parse_content_type_header(content_type=content_type)[0]
- schema = param['schema']
+ schema = param["schema"]
# NOTE: HACK: Workaround for eventlet wsgi server which sets Content-Type to
# text/plain if Content-Type is not provided in the request.
@@ -384,65 +420,76 @@ def __call__(self, req):
# expect application/json so we explicitly set it to that
# if not provided (set to text/plain by the base http server) and if it's not
# /v1/workflows/inspection API endpoints.
- if not self.is_gunicorn and content_type == 'text/plain':
- operation_id = endpoint['operationId']
+ if not self.is_gunicorn and content_type == "text/plain":
+ operation_id = endpoint["operationId"]
- if ('workflow_inspection_controller' not in operation_id):
- content_type = 'application/json'
+ if "workflow_inspection_controller" not in operation_id:
+ content_type = "application/json"
# Note: We also want to perform validation if no body is explicitly provided - in a
# lot of POST, PUT scenarios, body is mandatory
- if not req.body and content_type == 'application/json':
- req.body = b'{}'
+ if not req.body and content_type == "application/json":
+ req.body = b"{}"
try:
- if content_type == 'application/json':
+ if content_type == "application/json":
data = req.json
- elif content_type == 'text/plain':
+ elif content_type == "text/plain":
data = req.body
- elif content_type in ['application/x-www-form-urlencoded',
- 'multipart/form-data']:
+ elif content_type in [
+ "application/x-www-form-urlencoded",
+ "multipart/form-data",
+ ]:
data = urlparse.parse_qs(req.body)
else:
- raise ValueError('Unsupported Content-Type: "%s"' % (content_type))
+ raise ValueError(
+ 'Unsupported Content-Type: "%s"' % (content_type)
+ )
except Exception as e:
- detail = 'Failed to parse request body: %s' % six.text_type(e)
+ detail = "Failed to parse request body: %s" % six.text_type(e)
raise exc.HTTPBadRequest(detail=detail)
# Special case for Python 3
- if six.PY3 and content_type == 'text/plain' and isinstance(data, six.binary_type):
+ if (
+ six.PY3
+ and content_type == "text/plain"
+ and isinstance(data, six.binary_type)
+ ):
# Convert bytes to text type (string / unicode)
- data = data.decode('utf-8')
+ data = data.decode("utf-8")
try:
CustomValidator(schema, resolver=self.spec_resolver).validate(data)
except (jsonschema.ValidationError, ValueError) as e:
- raise exc.HTTPBadRequest(detail=getattr(e, 'message', six.text_type(e)),
- comment=traceback.format_exc())
+ raise exc.HTTPBadRequest(
+ detail=getattr(e, "message", six.text_type(e)),
+ comment=traceback.format_exc(),
+ )
- if content_type == 'text/plain':
+ if content_type == "text/plain":
kw[argument_name] = data
else:
+
class Body(object):
def __init__(self, **entries):
self.__dict__.update(entries)
- ref = schema.get('$ref', None)
+ ref = schema.get("$ref", None)
if ref:
with self.spec_resolver.resolving(ref) as resolved:
schema = resolved
- if 'x-api-model' in schema:
- input_type = schema.get('type', [])
- _, Model = op_resolver(schema['x-api-model'])
+ if "x-api-model" in schema:
+ input_type = schema.get("type", [])
+ _, Model = op_resolver(schema["x-api-model"])
if input_type and not isinstance(input_type, (list, tuple)):
input_type = [input_type]
# root attribute is not an object, we need to use wrapper attribute to
# make it work with **kwarg expansion
- if input_type and 'array' in input_type:
- data = {'data': data}
+ if input_type and "array" in input_type:
+ data = {"data": data}
instance = self._get_model_instance(model_cls=Model, data=data)
@@ -451,143 +498,178 @@ def __init__(self, **entries):
try:
instance = instance.validate()
except (jsonschema.ValidationError, ValueError) as e:
- raise exc.HTTPBadRequest(detail=getattr(e, 'message', six.text_type(e)),
- comment=traceback.format_exc())
+ raise exc.HTTPBadRequest(
+ detail=getattr(e, "message", six.text_type(e)),
+ comment=traceback.format_exc(),
+ )
else:
- LOG.debug('Missing x-api-model definition for %s, using generic Body '
- 'model.' % (endpoint['operationId']))
+ LOG.debug(
+ "Missing x-api-model definition for %s, using generic Body "
+ "model." % (endpoint["operationId"])
+ )
model = Body
instance = self._get_model_instance(model_cls=model, data=data)
kw[argument_name] = instance
# Making sure all required params are present
- required = param.get('required', False)
+ required = param.get("required", False)
if required and kw[argument_name] is None:
detail = 'Required parameter "%s" is missing' % name
raise exc.HTTPBadRequest(detail=detail)
# Validating and casting param types
- param_type = param.get('type', None)
+ param_type = param.get("type", None)
if kw[argument_name] is not None:
- if param_type == 'boolean':
- positive = ('true', '1', 'yes', 'y')
- negative = ('false', '0', 'no', 'n')
+ if param_type == "boolean":
+ positive = ("true", "1", "yes", "y")
+ negative = ("false", "0", "no", "n")
if str(kw[argument_name]).lower() not in positive + negative:
detail = 'Parameter "%s" is not of type boolean' % argument_name
raise exc.HTTPBadRequest(detail=detail)
kw[argument_name] = str(kw[argument_name]).lower() in positive
- elif param_type == 'integer':
- regex = r'^-?[0-9]+$'
+ elif param_type == "integer":
+ regex = r"^-?[0-9]+$"
if not re.search(regex, str(kw[argument_name])):
detail = 'Parameter "%s" is not of type integer' % argument_name
raise exc.HTTPBadRequest(detail=detail)
kw[argument_name] = int(kw[argument_name])
- elif param_type == 'number':
- regex = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$'
+ elif param_type == "number":
+ regex = r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$"
if not re.search(regex, str(kw[argument_name])):
detail = 'Parameter "%s" is not of type float' % argument_name
raise exc.HTTPBadRequest(detail=detail)
kw[argument_name] = float(kw[argument_name])
- elif param_type == 'array' and param.get('items', {}).get('type', None) == 'string':
+ elif (
+ param_type == "array"
+ and param.get("items", {}).get("type", None) == "string"
+ ):
if kw[argument_name] is None:
kw[argument_name] = []
elif isinstance(kw[argument_name], (list, tuple)):
# argument is already an array
pass
else:
- kw[argument_name] = kw[argument_name].split(',')
+ kw[argument_name] = kw[argument_name].split(",")
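The parameter-casting block above validates string inputs with regexes before converting them, since query/path/header values always arrive as strings. A condensed, standalone sketch of the same coercion rules:

import re

def coerce(value, param_type):
    # Simplified mirror of the boolean/integer/number/array casting above.
    if param_type == "boolean":
        positive = ("true", "1", "yes", "y")
        negative = ("false", "0", "no", "n")
        if str(value).lower() not in positive + negative:
            raise ValueError("not a boolean")
        return str(value).lower() in positive
    elif param_type == "integer":
        if not re.search(r"^-?[0-9]+$", str(value)):
            raise ValueError("not an integer")
        return int(value)
    elif param_type == "number":
        if not re.search(r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$", str(value)):
            raise ValueError("not a float")
        return float(value)
    elif param_type == "array":
        return value if isinstance(value, (list, tuple)) else value.split(",")
    return value

print(coerce("Yes", "boolean"), coerce("-3", "integer"), coerce("a,b", "array"))
# -> True -3 ['a', 'b']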
# Call the controller
try:
- controller_instance, func = op_resolver(endpoint['operationId'])
+ controller_instance, func = op_resolver(endpoint["operationId"])
except Exception as e:
- LOG.exception('Failed to load controller for operation "%s": %s' %
- (endpoint['operationId'], six.text_type(e)))
+ LOG.exception(
+ 'Failed to load controller for operation "%s": %s'
+ % (endpoint["operationId"], six.text_type(e))
+ )
raise e
try:
resp = func(**kw)
except DataStoreKeyNotFoundError as e:
- LOG.warning('Failed to call controller function "%s" for operation "%s": %s' %
- (func.__name__, endpoint['operationId'], six.text_type(e)))
+ LOG.warning(
+ 'Failed to call controller function "%s" for operation "%s": %s'
+ % (func.__name__, endpoint["operationId"], six.text_type(e))
+ )
raise e
except Exception as e:
- LOG.exception('Failed to call controller function "%s" for operation "%s": %s' %
- (func.__name__, endpoint['operationId'], six.text_type(e)))
+ LOG.exception(
+ 'Failed to call controller function "%s" for operation "%s": %s'
+ % (func.__name__, endpoint["operationId"], six.text_type(e))
+ )
raise e
# Handle response
if resp is None:
resp = Response()
- if not hasattr(resp, '__call__'):
+ if not hasattr(resp, "__call__"):
resp = Response(json=resp)
- operation_id = endpoint['operationId']
+ operation_id = endpoint["operationId"]
# Process the response removing attributes based on the exclude_attribute and
# include_attributes query param filter values (if specified)
- include_attributes = kw.get('include_attributes', None)
- exclude_attributes = kw.get('exclude_attributes', None)
- has_include_or_exclude_attributes = bool(include_attributes) or bool(exclude_attributes)
+ include_attributes = kw.get("include_attributes", None)
+ exclude_attributes = kw.get("exclude_attributes", None)
+ has_include_or_exclude_attributes = bool(include_attributes) or bool(
+ exclude_attributes
+ )
# NOTE: We do NOT want to process stream controller response
- is_streamming_controller = endpoint.get('x-is-streaming-endpoint',
- bool('st2stream' in operation_id))
-
- if not is_streamming_controller and resp.body and has_include_or_exclude_attributes:
+ is_streaming_controller = endpoint.get(
+ "x-is-streaming-endpoint", bool("st2stream" in operation_id)
+ )
+
+ if (
+ not is_streaming_controller
+ and resp.body
+ and has_include_or_exclude_attributes
+ ):
# NOTE: We need to check for response.body attribute since resp.json throws if JSON
# response is not available
- mandatory_include_fields = getattr(controller_instance,
- 'mandatory_include_fields_response', [])
- data = self._process_response(data=resp.json,
- mandatory_include_fields=mandatory_include_fields,
- include_attributes=include_attributes,
- exclude_attributes=exclude_attributes)
+ mandatory_include_fields = getattr(
+ controller_instance, "mandatory_include_fields_response", []
+ )
+ data = self._process_response(
+ data=resp.json,
+ mandatory_include_fields=mandatory_include_fields,
+ include_attributes=include_attributes,
+ exclude_attributes=exclude_attributes,
+ )
resp.json = data
- responses = endpoint.get('responses', {})
+ responses = endpoint.get("responses", {})
response_spec = responses.get(str(resp.status_code), None)
- default_response_spec = responses.get('default', None)
+ default_response_spec = responses.get("default", None)
if not response_spec and default_response_spec:
- LOG.debug('No custom response spec found for endpoint "%s", using a default one' %
- (endpoint['operationId']))
- response_spec_name = 'default'
+ LOG.debug(
+ 'No custom response spec found for endpoint "%s", using a default one'
+ % (endpoint["operationId"])
+ )
+ response_spec_name = "default"
else:
response_spec_name = str(resp.status_code)
response_spec = response_spec or default_response_spec
- if response_spec and 'schema' in response_spec and not has_include_or_exclude_attributes:
+ if (
+ response_spec
+ and "schema" in response_spec
+ and not has_include_or_exclude_attributes
+ ):
# NOTE: We don't perform response validation when include or exclude attributes are
# provided because this means partial response which likely won't pass the validation
- LOG.debug('Using response spec "%s" for endpoint %s and status code %s' %
- (response_spec_name, endpoint['operationId'], resp.status_code))
+ LOG.debug(
+ 'Using response spec "%s" for endpoint %s and status code %s'
+ % (response_spec_name, endpoint["operationId"], resp.status_code)
+ )
try:
- validator = CustomValidator(response_spec['schema'], resolver=self.spec_resolver)
+ validator = CustomValidator(
+ response_spec["schema"], resolver=self.spec_resolver
+ )
- response_type = response_spec['schema'].get('type', 'json')
- if response_type == 'string':
+ response_type = response_spec["schema"].get("type", "json")
+ if response_type == "string":
validator.validate(resp.text)
else:
validator.validate(resp.json)
except (jsonschema.ValidationError, ValueError):
- LOG.exception('Response validation failed.')
- resp.headers.add('Warning', '199 OpenAPI "Response validation failed"')
+ LOG.exception("Response validation failed.")
+ resp.headers.add("Warning", '199 OpenAPI "Response validation failed"')
else:
- LOG.debug('No response spec found for endpoint "%s"' % (endpoint['operationId']))
+ LOG.debug(
+ 'No response spec found for endpoint "%s"' % (endpoint["operationId"])
+ )
if cookie_token:
- resp.headerlist.append(('Set-Cookie', cookie_token))
+ resp.headerlist.append(("Set-Cookie", cookie_token))
return resp
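
The parameter coercion enforced above can be exercised in isolation. The sketch below mirrors the boolean/integer/number rules, assuming raw string values as they arrive from query or path parameters; the cast_param helper name is illustrative and not part of the router:

import re

def cast_param(value, param_type):
    # Mirrors the router's coercion rules for string-typed request parameters.
    if param_type == "boolean":
        positive = ("true", "1", "yes", "y")
        negative = ("false", "0", "no", "n")
        if str(value).lower() not in positive + negative:
            raise ValueError('"%s" is not of type boolean' % value)
        return str(value).lower() in positive
    if param_type == "integer":
        if not re.search(r"^-?[0-9]+$", str(value)):
            raise ValueError('"%s" is not of type integer' % value)
        return int(value)
    if param_type == "number":
        if not re.search(r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$", str(value)):
            raise ValueError('"%s" is not of type float' % value)
        return float(value)
    return value

assert cast_param("Yes", "boolean") is True
assert cast_param("-42", "integer") == -42
assert cast_param("1.5e3", "number") == 1500.0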
@@ -604,17 +686,24 @@ def _get_model_instance(self, model_cls, data):
instance = model_cls(**data)
except TypeError as e:
# Throw a more user-friendly exception when input data is not an object
- if 'type object argument after ** must be a mapping, not' in six.text_type(e):
+ if "type object argument after ** must be a mapping, not" in six.text_type(
+ e
+ ):
type_string = get_json_type_for_python_value(data)
- msg = ('Input body needs to be an object, got: %s' % (type_string))
+ msg = "Input body needs to be an object, got: %s" % (type_string)
raise ValueError(msg)
raise e
return instance
- def _process_response(self, data, mandatory_include_fields=None, include_attributes=None,
- exclude_attributes=None):
+ def _process_response(
+ self,
+ data,
+ mandatory_include_fields=None,
+ include_attributes=None,
+ exclude_attributes=None,
+ ):
"""
Process controller response data such as removing attributes based on the values of
exclude_attributes and include_attributes query param filters and similar.
@@ -628,8 +717,10 @@ def _process_response(self, data, mandatory_include_fields=None, include_attribu
# NOTE: include_attributes and exclude_attributes are mutually exclusive
if include_attributes and exclude_attributes:
- msg = ('exclude_attributes and include_attributes arguments are mutually exclusive. '
- 'You need to provide either one or another, but not both.')
+ msg = (
+ "exclude_attributes and include_attributes arguments are mutually exclusive. "
+ "You need to provide either one or another, but not both."
+ )
raise ValueError(msg)
# Common case - filters are not provided
@@ -637,16 +728,20 @@ def _process_response(self, data, mandatory_include_fields=None, include_attribu
return data
# Skip processing of error responses
- if isinstance(data, dict) and data.get('faultstring', None):
+ if isinstance(data, dict) and data.get("faultstring", None):
return data
# We only care about the first part of the field name since deep filtering happens inside
# MongoDB. Deep filtering here would also be quite expensive and a waste of CPU cycles.
- cleaned_include_attributes = [attribute.split('.')[0] for attribute in include_attributes]
+ cleaned_include_attributes = [
+ attribute.split(".")[0] for attribute in include_attributes
+ ]
# Add in mandatory fields which always need to be present in the response (primary keys)
cleaned_include_attributes += mandatory_include_fields
- cleaned_exclude_attributes = [attribute.split('.')[0] for attribute in exclude_attributes]
+ cleaned_exclude_attributes = [
+ attribute.split(".")[0] for attribute in exclude_attributes
+ ]
# NOTE: Since those parameters are mutually exclusive we could perform more efficient
# filtering when just exclude_attributes is provided. Instead of creating a new dict, we
@@ -675,6 +770,6 @@ def process_item(item):
# get_one response
result = process_item(data)
else:
- raise ValueError('Unsupported type: %s' % (type(data)))
+ raise ValueError("Unsupported type: %s" % (type(data)))
return result
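
A simplified illustration of the filtering semantics implemented by _process_response: only the first dotted segment of each attribute is considered (deep filtering happens inside MongoDB), mandatory fields are always kept, and the two filters are mutually exclusive. The filter_item helper below is a stand-in sketch, not the actual implementation:

def filter_item(item, include=None, exclude=None, mandatory=None):
    # Per-item sketch; deep (dotted) filtering is left to the MongoDB layer.
    mandatory = mandatory or []
    if include and exclude:
        raise ValueError("include and exclude are mutually exclusive")
    if include:
        keep = {attr.split(".")[0] for attr in include} | set(mandatory)
        return {k: v for k, v in item.items() if k in keep}
    if exclude:
        drop = {attr.split(".")[0] for attr in exclude} - set(mandatory)
        return {k: v for k, v in item.items() if k not in drop}
    return item

item = {"id": "1", "status": "succeeded", "result": {"stdout": "ok"}}
print(filter_item(item, include=["status", "result.stdout"], mandatory=["id"]))
# {'id': '1', 'status': 'succeeded', 'result': {'stdout': 'ok'}}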
diff --git a/st2common/st2common/runners/__init__.py b/st2common/st2common/runners/__init__.py
index bcccaaf48d..d6468f78e2 100644
--- a/st2common/st2common/runners/__init__.py
+++ b/st2common/st2common/runners/__init__.py
@@ -19,14 +19,9 @@
from st2common.util import driver_loader
-__all__ = [
- 'BACKENDS_NAMESPACE',
+__all__ = ["BACKENDS_NAMESPACE", "get_available_backends", "get_backend_driver"]
- 'get_available_backends',
- 'get_backend_driver'
-]
-
-BACKENDS_NAMESPACE = 'st2common.runners.runner'
+BACKENDS_NAMESPACE = "st2common.runners.runner"
def get_available_backends():
diff --git a/st2common/st2common/runners/base.py b/st2common/st2common/runners/base.py
index 6a9656b9a1..a5692b66e6 100644
--- a/st2common/st2common/runners/base.py
+++ b/st2common/st2common/runners/base.py
@@ -42,45 +42,43 @@
subprocess = concurrency.get_subprocess_module()
__all__ = [
- 'ActionRunner',
- 'AsyncActionRunner',
- 'PollingAsyncActionRunner',
- 'GitWorktreeActionRunner',
- 'PollingAsyncActionRunner',
- 'ShellRunnerMixin',
-
- 'get_runner_module',
-
- 'get_runner',
- 'get_metadata',
+ "ActionRunner",
+ "AsyncActionRunner",
+ "PollingAsyncActionRunner",
+ "GitWorktreeActionRunner",
+ "PollingAsyncActionRunner",
+ "ShellRunnerMixin",
+ "get_runner_module",
+ "get_runner",
+ "get_metadata",
]
LOG = logging.getLogger(__name__)
# constants to lookup in runner_parameters
-RUNNER_COMMAND = 'cmd'
-RUNNER_CONTENT_VERSION = 'content_version'
-RUNNER_DEBUG = 'debug'
+RUNNER_COMMAND = "cmd"
+RUNNER_CONTENT_VERSION = "content_version"
+RUNNER_DEBUG = "debug"
def get_runner(name, config=None):
"""
Load the module and return an instance of the runner.
"""
- LOG.debug('Runner loading Python module: %s', name)
+ LOG.debug("Runner loading Python module: %s", name)
module = get_runner_module(name=name)
- LOG.debug('Instance of runner module: %s', module)
+ LOG.debug("Instance of runner module: %s", module)
if config:
- runner_kwargs = {'config': config}
+ runner_kwargs = {"config": config}
else:
runner_kwargs = {}
runner = module.get_runner(**runner_kwargs)
- LOG.debug('Instance of runner: %s', runner)
+ LOG.debug("Instance of runner: %s", runner)
return runner
@@ -95,19 +93,21 @@ def get_runner_module(name):
try:
module = get_plugin_instance(RUNNERS_NAMESPACE, name, invoke_on_load=False)
except NoMatches:
- name = name.replace('_', '-')
+ name = name.replace("_", "-")
try:
module = get_plugin_instance(RUNNERS_NAMESPACE, name, invoke_on_load=False)
except Exception as e:
available_runners = get_available_plugins(namespace=RUNNERS_NAMESPACE)
- available_runners = ', '.join(available_runners)
- msg = ('Failed to find runner %s. Make sure that the runner is available and installed '
- 'in StackStorm virtual environment. Available runners are: %s' %
- (name, available_runners))
+ available_runners = ", ".join(available_runners)
+ msg = (
+ "Failed to find runner %s. Make sure that the runner is available and installed "
+ "in StackStorm virtual environment. Available runners are: %s"
+ % (name, available_runners)
+ )
LOG.exception(msg)
- raise exc.ActionRunnerCreateError('%s\n\n%s' % (msg, six.text_type(e)))
+ raise exc.ActionRunnerCreateError("%s\n\n%s" % (msg, six.text_type(e)))
return module
@@ -120,9 +120,9 @@ def get_metadata(package_name):
"""
import pkg_resources
- file_path = pkg_resources.resource_filename(package_name, 'runner.yaml')
+ file_path = pkg_resources.resource_filename(package_name, "runner.yaml")
- with open(file_path, 'r') as fp:
+ with open(file_path, "r") as fp:
content = fp.read()
metadata = yaml.safe_load(content)
@@ -158,14 +158,14 @@ def __init__(self, runner_id):
def pre_run(self):
# Handle runner "enabled" attribute
- runner_enabled = getattr(self.runner_type, 'enabled', True)
- runner_name = getattr(self.runner_type, 'name', 'unknown')
+ runner_enabled = getattr(self.runner_type, "enabled", True)
+ runner_name = getattr(self.runner_type, "name", "unknown")
if not runner_enabled:
msg = 'Runner "%s" has been disabled by the administrator.' % runner_name
raise ValueError(msg)
- runner_parameters = getattr(self, 'runner_parameters', {}) or {}
+ runner_parameters = getattr(self, "runner_parameters", {}) or {}
self._debug = runner_parameters.get(RUNNER_DEBUG, False)
# Run will need to take an action argument
@@ -175,18 +175,20 @@ def run(self, action_parameters):
raise NotImplementedError()
def pause(self):
- runner_name = getattr(self.runner_type, 'name', 'unknown')
- raise NotImplementedError('Pause is not supported for runner %s.' % runner_name)
+ runner_name = getattr(self.runner_type, "name", "unknown")
+ raise NotImplementedError("Pause is not supported for runner %s." % runner_name)
def resume(self):
- runner_name = getattr(self.runner_type, 'name', 'unknown')
- raise NotImplementedError('Resume is not supported for runner %s.' % runner_name)
+ runner_name = getattr(self.runner_type, "name", "unknown")
+ raise NotImplementedError(
+ "Resume is not supported for runner %s." % runner_name
+ )
def cancel(self):
return (
action_constants.LIVEACTION_STATUS_CANCELED,
self.liveaction.result,
- self.liveaction.context
+ self.liveaction.context,
)
def post_run(self, status, result):
@@ -213,8 +215,8 @@ def get_user(self):
:rtype: ``str``
"""
- context = getattr(self, 'context', {}) or {}
- user = context.get('user', cfg.CONF.system_user.user)
+ context = getattr(self, "context", {}) or {}
+ user = context.get("user", cfg.CONF.system_user.user)
return user
@@ -228,18 +230,18 @@ def _get_common_action_env_variables(self):
:rtype: ``dict``
"""
result = {}
- result['ST2_ACTION_PACK_NAME'] = self.get_pack_ref()
- result['ST2_ACTION_EXECUTION_ID'] = str(self.execution_id)
- result['ST2_ACTION_API_URL'] = get_full_public_api_url()
+ result["ST2_ACTION_PACK_NAME"] = self.get_pack_ref()
+ result["ST2_ACTION_EXECUTION_ID"] = str(self.execution_id)
+ result["ST2_ACTION_API_URL"] = get_full_public_api_url()
if self.auth_token:
- result['ST2_ACTION_AUTH_TOKEN'] = self.auth_token.token
+ result["ST2_ACTION_AUTH_TOKEN"] = self.auth_token.token
return result
def __str__(self):
- attrs = ', '.join(['%s=%s' % (k, v) for k, v in six.iteritems(self.__dict__)])
- return '%s@%s(%s)' % (self.__class__.__name__, str(id(self)), attrs)
+ attrs = ", ".join(["%s=%s" % (k, v) for k, v in six.iteritems(self.__dict__)])
+ return "%s@%s(%s)" % (self.__class__.__name__, str(id(self)), attrs)
@six.add_metaclass(abc.ABCMeta)
@@ -248,7 +250,6 @@ class AsyncActionRunner(ActionRunner):
class PollingAsyncActionRunner(AsyncActionRunner):
-
@classmethod
def is_polling_enabled(cls):
return True
@@ -264,7 +265,7 @@ class GitWorktreeActionRunner(ActionRunner):
This revision is specified using "content_version" runner parameter.
"""
- WORKTREE_DIRECTORY_PREFIX = 'st2-git-worktree-'
+ WORKTREE_DIRECTORY_PREFIX = "st2-git-worktree-"
def __init__(self, runner_id):
super(GitWorktreeActionRunner, self).__init__(runner_id=runner_id)
@@ -284,11 +285,13 @@ def pre_run(self):
# Override entry_point so it points to git worktree directory
pack_name = self.get_pack_name()
- entry_point = self._get_entry_point_for_worktree_path(pack_name=pack_name,
- entry_point=self.entry_point,
- worktree_path=self.git_worktree_path)
+ entry_point = self._get_entry_point_for_worktree_path(
+ pack_name=pack_name,
+ entry_point=self.entry_point,
+ worktree_path=self.git_worktree_path,
+ )
- assert(entry_point.startswith(self.git_worktree_path))
+ assert entry_point.startswith(self.git_worktree_path)
self.entry_point = entry_point
@@ -298,9 +301,11 @@ def post_run(self, status, result):
# Remove git worktree directories (if used and available)
if self.git_worktree_path and self.git_worktree_revision:
pack_name = self.get_pack_name()
- self.cleanup_git_worktree(worktree_path=self.git_worktree_path,
- content_version=self.git_worktree_revision,
- pack_name=pack_name)
+ self.cleanup_git_worktree(
+ worktree_path=self.git_worktree_path,
+ content_version=self.git_worktree_revision,
+ pack_name=pack_name,
+ )
def create_git_worktree(self, content_version):
"""
@@ -318,51 +323,59 @@ def create_git_worktree(self, content_version):
self.git_worktree_path = worktree_path
extra = {
- 'pack_name': pack_name,
- 'pack_directory': pack_directory,
- 'content_version': content_version,
- 'worktree_path': worktree_path
+ "pack_name": pack_name,
+ "pack_directory": pack_directory,
+ "content_version": content_version,
+ "worktree_path": worktree_path,
}
if not os.path.isdir(pack_directory):
- msg = ('Failed to create git worktree for pack "%s". Pack directory "%s" doesn\'t '
- 'exist.' % (pack_name, pack_directory))
+ msg = (
+ 'Failed to create git worktree for pack "%s". Pack directory "%s" doesn\'t '
+ "exist." % (pack_name, pack_directory)
+ )
raise ValueError(msg)
args = [
- 'git',
- '-C',
+ "git",
+ "-C",
pack_directory,
- 'worktree',
- 'add',
+ "worktree",
+ "add",
worktree_path,
- content_version
+ content_version,
]
cmd = list2cmdline(args)
- LOG.debug('Creating git worktree for pack "%s", content version "%s" and execution '
- 'id "%s" in "%s"' % (pack_name, content_version, self.execution_id,
- worktree_path), extra=extra)
- LOG.debug('Command: %s' % (cmd))
- exit_code, stdout, stderr, timed_out = run_command(cmd=cmd,
- cwd=pack_directory,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- shell=True)
+ LOG.debug(
+ 'Creating git worktree for pack "%s", content version "%s" and execution '
+ 'id "%s" in "%s"'
+ % (pack_name, content_version, self.execution_id, worktree_path),
+ extra=extra,
+ )
+ LOG.debug("Command: %s" % (cmd))
+ exit_code, stdout, stderr, timed_out = run_command(
+ cmd=cmd,
+ cwd=pack_directory,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True,
+ )
if exit_code != 0:
- self._handle_git_worktree_error(pack_name=pack_name, pack_directory=pack_directory,
- content_version=content_version,
- exit_code=exit_code, stdout=stdout, stderr=stderr)
+ self._handle_git_worktree_error(
+ pack_name=pack_name,
+ pack_directory=pack_directory,
+ content_version=content_version,
+ exit_code=exit_code,
+ stdout=stdout,
+ stderr=stderr,
+ )
else:
LOG.debug('Git worktree created in "%s"' % (worktree_path), extra=extra)
# Make sure system / action runner user can access that directory
- args = [
- 'chmod',
- '777',
- worktree_path
- ]
+ args = ["chmod", "777", worktree_path]
cmd = list2cmdline(args)
run_command(cmd=cmd, shell=True)
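
For reference, the command assembled above is an ordinary git worktree invocation; a sketch with illustrative paths and content version (shlex.quote stands in for the runner's own quoting):

from shlex import quote

# Illustrative pack directory, worktree path and content version.
args = [
    "git",
    "-C",
    "/opt/stackstorm/packs/examples",
    "worktree",
    "add",
    "/tmp/st2-git-worktree-abc123",
    "v1.0.0",
]
print(" ".join(quote(arg) for arg in args))
# git -C /opt/stackstorm/packs/examples worktree add /tmp/st2-git-worktree-abc123 v1.0.0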
@@ -375,15 +388,19 @@ def cleanup_git_worktree(self, worktree_path, pack_name, content_version):
:rtype: ``bool``
"""
# Safety check to make sure we don't remove something outside /tmp
- assert(worktree_path.startswith('/tmp'))
- assert(worktree_path.startswith('/tmp/%s' % (self.WORKTREE_DIRECTORY_PREFIX)))
+ assert worktree_path.startswith("/tmp")
+ assert worktree_path.startswith("/tmp/%s" % (self.WORKTREE_DIRECTORY_PREFIX))
if self._debug:
- LOG.debug('Not removing git worktree "%s" because debug mode is enabled' %
- (worktree_path))
+ LOG.debug(
+ 'Not removing git worktree "%s" because debug mode is enabled'
+ % (worktree_path)
+ )
else:
- LOG.debug('Removing git worktree "%s" for pack "%s" and content version "%s"' %
- (worktree_path, pack_name, content_version))
+ LOG.debug(
+ 'Removing git worktree "%s" for pack "%s" and content version "%s"'
+ % (worktree_path, pack_name, content_version)
+ )
try:
shutil.rmtree(worktree_path, ignore_errors=True)
@@ -392,36 +409,43 @@ def cleanup_git_worktree(self, worktree_path, pack_name, content_version):
return True
- def _handle_git_worktree_error(self, pack_name, pack_directory, content_version, exit_code,
- stdout, stderr):
+ def _handle_git_worktree_error(
+ self, pack_name, pack_directory, content_version, exit_code, stdout, stderr
+ ):
"""
Handle "git worktree" related errors and throw a more user-friendly exception.
"""
error_prefix = 'Failed to create git worktree for pack "%s": ' % (pack_name)
if isinstance(stdout, six.binary_type):
- stdout = stdout.decode('utf-8')
+ stdout = stdout.decode("utf-8")
if isinstance(stderr, six.binary_type):
- stderr = stderr.decode('utf-8')
+ stderr = stderr.decode("utf-8")
# 1. Installed version of git which doesn't support worktree command
if "git: 'worktree' is not a git command." in stderr:
- msg = ('Installed git version doesn\'t support git worktree command. '
- 'To be able to utilize this functionality you need to use git '
- '>= 2.5.0.')
+ msg = (
+ "Installed git version doesn't support git worktree command. "
+ "To be able to utilize this functionality you need to use git "
+ ">= 2.5.0."
+ )
raise ValueError(error_prefix + msg)
# 2. Provided pack directory is not a git repository
if "Not a git repository" in stderr:
- msg = ('Pack directory "%s" is not a git repository. To utilize this functionality, '
- 'pack directory needs to be a git repository.' % (pack_directory))
+ msg = (
+ 'Pack directory "%s" is not a git repository. To utilize this functionality, '
+ "pack directory needs to be a git repository." % (pack_directory)
+ )
raise ValueError(error_prefix + msg)
# 3. Invalid revision provided
if "invalid reference" in stderr:
- msg = ('Invalid content_version "%s" provided. Make sure that git repository is up '
- 'to date and contains that revision.' % (content_version))
+ msg = (
+ 'Invalid content_version "%s" provided. Make sure that git repository is up '
+ "to date and contains that revision." % (content_version)
+ )
raise ValueError(error_prefix + msg)
def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_path):
@@ -433,10 +457,10 @@ def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_pa
"""
pack_base_path = get_pack_base_path(pack_name=pack_name)
- new_entry_point = entry_point.replace(pack_base_path, '')
+ new_entry_point = entry_point.replace(pack_base_path, "")
# Remove leading slash (if any)
- if new_entry_point.startswith('/'):
+ if new_entry_point.startswith("/"):
new_entry_point = new_entry_point[1:]
new_entry_point = os.path.join(worktree_path, new_entry_point)
@@ -444,7 +468,7 @@ def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_pa
# Check to prevent directory traversal
common_prefix = os.path.commonprefix([worktree_path, new_entry_point])
if common_prefix != worktree_path:
- raise ValueError('entry_point is not located inside the pack directory')
+ raise ValueError("entry_point is not located inside the pack directory")
return new_entry_point
@@ -483,11 +507,11 @@ def _get_script_args(self, action_parameters):
is_script_run_as_cmd = self.runner_parameters.get(RUNNER_COMMAND, None)
- pos_args = ''
+ pos_args = ""
named_args = {}
if is_script_run_as_cmd:
- pos_args = self.runner_parameters.get(RUNNER_COMMAND, '')
+ pos_args = self.runner_parameters.get(RUNNER_COMMAND, "")
named_args = action_parameters
else:
pos_args, named_args = action_utils.get_args(action_parameters, self.action)
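
As a usage sketch, a runner registered under the RUNNERS_NAMESPACE entry point can be loaded by name; the runner and package names below are illustrative and assume the corresponding runner package is installed:

from st2common.runners.base import get_runner, get_metadata

# Underscores in the name fall back to dashes if the first lookup fails.
runner = get_runner("local-shell-cmd")

# Runner metadata is parsed from the package's runner.yaml.
metadata = get_metadata("local_runner")
print(metadata)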
diff --git a/st2common/st2common/runners/base_action.py b/st2common/st2common/runners/base_action.py
index bc915d2b4f..244a4235c9 100644
--- a/st2common/st2common/runners/base_action.py
+++ b/st2common/st2common/runners/base_action.py
@@ -21,9 +21,7 @@
from st2common.runners.utils import get_logger_for_python_runner_action
from st2common.runners.utils import PackConfigDict
-__all__ = [
- 'Action'
-]
+__all__ = ["Action"]
@six.add_metaclass(abc.ABCMeta)
@@ -45,16 +43,17 @@ def __init__(self, config=None, action_service=None):
self.config = config or {}
self.action_service = action_service
- if action_service and getattr(action_service, '_action_wrapper', None):
- log_level = getattr(action_service._action_wrapper, '_log_level', 'debug')
- pack_name = getattr(action_service._action_wrapper, '_pack', 'unknown')
+ if action_service and getattr(action_service, "_action_wrapper", None):
+ log_level = getattr(action_service._action_wrapper, "_log_level", "debug")
+ pack_name = getattr(action_service._action_wrapper, "_pack", "unknown")
else:
- log_level = 'debug'
- pack_name = 'unknown'
+ log_level = "debug"
+ pack_name = "unknown"
self.config = PackConfigDict(pack_name, self.config)
- self.logger = get_logger_for_python_runner_action(action_name=self.__class__.__name__,
- log_level=log_level)
+ self.logger = get_logger_for_python_runner_action(
+ action_name=self.__class__.__name__, log_level=log_level
+ )
@abc.abstractmethod
def run(self, **kwargs):
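
A minimal pack action built on this base class might look as follows (action and parameter names illustrative); self.config and self.logger are provided by the constructor above:

from st2common.runners.base_action import Action

class EchoAction(Action):
    def run(self, message):
        # self.config is the pack config (wrapped in PackConfigDict) and
        # self.logger comes from get_logger_for_python_runner_action().
        self.logger.info("Echoing message: %s", message)
        return (True, {"message": message})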
diff --git a/st2common/st2common/runners/parallel_ssh.py b/st2common/st2common/runners/parallel_ssh.py
index 28f8756415..c41175c02c 100644
--- a/st2common/st2common/runners/parallel_ssh.py
+++ b/st2common/st2common/runners/parallel_ssh.py
@@ -35,13 +35,26 @@
class ParallelSSHClient(object):
- KEYS_TO_TRANSFORM = ['stdout', 'stderr']
- CONNECT_ERROR = 'Cannot connect to host.'
-
- def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_material=None, port=22,
- bastion_host=None, concurrency=10, raise_on_any_error=False, connect=True,
- passphrase=None, handle_stdout_line_func=None, handle_stderr_line_func=None,
- sudo_password=False):
+ KEYS_TO_TRANSFORM = ["stdout", "stderr"]
+ CONNECT_ERROR = "Cannot connect to host."
+
+ def __init__(
+ self,
+ hosts,
+ user=None,
+ password=None,
+ pkey_file=None,
+ pkey_material=None,
+ port=22,
+ bastion_host=None,
+ concurrency=10,
+ raise_on_any_error=False,
+ connect=True,
+ passphrase=None,
+ handle_stdout_line_func=None,
+ handle_stderr_line_func=None,
+ sudo_password=False,
+ ):
"""
:param handle_stdout_line_func: Callback function which is called dynamically each time a
new stdout line is received.
@@ -65,7 +78,7 @@ def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_materia
self._sudo_password = sudo_password
if not hosts:
- raise Exception('Need an non-empty list of hosts to talk to.')
+ raise Exception("Need an non-empty list of hosts to talk to.")
self._pool = concurrency_lib.get_green_pool_class()(concurrency)
self._hosts_client = {}
@@ -74,8 +87,8 @@ def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_materia
if connect:
connect_results = self.connect(raise_on_any_error=raise_on_any_error)
- extra = {'_connect_results': connect_results}
- LOG.debug('Connect to hosts complete.', extra=extra)
+ extra = {"_connect_results": connect_results}
+ LOG.debug("Connect to hosts complete.", extra=extra)
def connect(self, raise_on_any_error=False):
"""
@@ -92,17 +105,28 @@ def connect(self, raise_on_any_error=False):
for host in self._hosts:
while not concurrency_lib.is_green_pool_free(self._pool):
concurrency_lib.sleep(self._scan_interval)
- self._pool.spawn(self._connect, host=host, results=results,
- raise_on_any_error=raise_on_any_error)
+ self._pool.spawn(
+ self._connect,
+ host=host,
+ results=results,
+ raise_on_any_error=raise_on_any_error,
+ )
concurrency_lib.green_pool_wait_all(self._pool)
if self._successful_connects < 1:
# We definitely have to raise an exception in this case.
- LOG.error('Unable to connect to any of the hosts.',
- extra={'connect_results': results})
- msg = ('Unable to connect to any one of the hosts: %s.\n\n connect_errors=%s' %
- (self._hosts, json.dumps(results, indent=2)))
+ LOG.error(
+ "Unable to connect to any of the hosts.",
+ extra={"connect_results": results},
+ )
+ msg = (
+ "Unable to connect to any one of the hosts: %s.\n\n connect_errors=%s"
+ % (
+ self._hosts,
+ json.dumps(results, indent=2),
+ )
+ )
raise NoHostsConnectedToException(msg)
return results
@@ -124,10 +148,7 @@ def run(self, cmd, timeout=None):
:rtype: ``dict`` of ``str`` to ``dict``
"""
- options = {
- 'cmd': cmd,
- 'timeout': timeout
- }
+ options = {"cmd": cmd, "timeout": timeout}
results = self._execute_in_pool(self._run_command, **options)
return results
@@ -152,13 +173,13 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False):
"""
if not os.path.exists(local_path):
- raise Exception('Local path %s does not exist.' % local_path)
+ raise Exception("Local path %s does not exist." % local_path)
options = {
- 'local_path': local_path,
- 'remote_path': remote_path,
- 'mode': mode,
- 'mirror_local_mode': mirror_local_mode
+ "local_path": local_path,
+ "remote_path": remote_path,
+ "mode": mode,
+ "mirror_local_mode": mirror_local_mode,
}
return self._execute_in_pool(self._put_files, **options)
@@ -173,9 +194,7 @@ def mkdir(self, path):
:rtype path: ``dict`` of ``str`` to ``dict``
"""
- options = {
- 'path': path
- }
+ options = {"path": path}
return self._execute_in_pool(self._mkdir, **options)
def delete_file(self, path):
@@ -188,9 +207,7 @@ def delete_file(self, path):
:rtype path: ``dict`` of ``str`` to ``dict``
"""
- options = {
- 'path': path
- }
+ options = {"path": path}
return self._execute_in_pool(self._delete_file, **options)
def delete_dir(self, path, force=False, timeout=None):
@@ -203,10 +220,7 @@ def delete_dir(self, path, force=False, timeout=None):
:rtype path: ``dict`` of ``str`` to ``dict``
"""
- options = {
- 'path': path,
- 'force': force
- }
+ options = {"path": path, "force": force}
return self._execute_in_pool(self._delete_dir, **options)
def close(self):
@@ -218,7 +232,7 @@ def close(self):
try:
self._hosts_client[host].close()
except:
- LOG.exception('Failed shutting down SSH connection to host: %s', host)
+ LOG.exception("Failed shutting down SSH connection to host: %s", host)
def _execute_in_pool(self, execute_method, **kwargs):
results = {}
@@ -237,36 +251,41 @@ def _execute_in_pool(self, execute_method, **kwargs):
def _connect(self, host, results, raise_on_any_error=False):
(hostname, port) = self._get_host_port_info(host)
- extra = {'host': host, 'port': port, 'user': self._ssh_user}
+ extra = {"host": host, "port": port, "user": self._ssh_user}
if self._ssh_password:
- extra['password'] = ''
+ extra["password"] = ""
elif self._ssh_key_file:
- extra['key_file_path'] = self._ssh_key_file
+ extra["key_file_path"] = self._ssh_key_file
else:
- extra['private_key'] = ''
-
- LOG.debug('Connecting to host.', extra=extra)
-
- client = ParamikoSSHClient(hostname=hostname, port=port,
- username=self._ssh_user,
- password=self._ssh_password,
- bastion_host=self._bastion_host,
- key_files=self._ssh_key_file,
- key_material=self._ssh_key_material,
- passphrase=self._passphrase,
- handle_stdout_line_func=self._handle_stdout_line_func,
- handle_stderr_line_func=self._handle_stderr_line_func)
+ extra["private_key"] = ""
+
+ LOG.debug("Connecting to host.", extra=extra)
+
+ client = ParamikoSSHClient(
+ hostname=hostname,
+ port=port,
+ username=self._ssh_user,
+ password=self._ssh_password,
+ bastion_host=self._bastion_host,
+ key_files=self._ssh_key_file,
+ key_material=self._ssh_key_material,
+ passphrase=self._passphrase,
+ handle_stdout_line_func=self._handle_stdout_line_func,
+ handle_stderr_line_func=self._handle_stderr_line_func,
+ )
try:
client.connect()
except SSHException as ex:
LOG.exception(ex)
if raise_on_any_error:
raise
- error_dict = self._generate_error_result(exc=ex, message='Connection error.')
+ error_dict = self._generate_error_result(
+ exc=ex, message="Connection error."
+ )
self._bad_hosts[hostname] = error_dict
results[hostname] = error_dict
except Exception as ex:
- error = 'Failed connecting to host %s.' % hostname
+ error = "Failed connecting to host %s." % hostname
LOG.exception(error)
if raise_on_any_error:
raise
@@ -276,16 +295,19 @@ def _connect(self, host, results, raise_on_any_error=False):
else:
self._successful_connects += 1
self._hosts_client[hostname] = client
- results[hostname] = {'message': 'Connected to host.'}
+ results[hostname] = {"message": "Connected to host."}
def _run_command(self, host, cmd, results, timeout=None):
try:
- LOG.debug('Running command: %s on host: %s.', cmd, host)
+ LOG.debug("Running command: %s on host: %s.", cmd, host)
client = self._hosts_client[host]
- (stdout, stderr, exit_code) = client.run(cmd, timeout=timeout,
- call_line_handler_func=True)
+ (stdout, stderr, exit_code) = client.run(
+ cmd, timeout=timeout, call_line_handler_func=True
+ )
- result = self._handle_command_result(stdout=stdout, stderr=stderr, exit_code=exit_code)
+ result = self._handle_command_result(
+ stdout=stdout, stderr=stderr, exit_code=exit_code
+ )
results[host] = result
except Exception as ex:
cmd = self._sanitize_command_string(cmd=cmd)
@@ -293,20 +315,24 @@ def _run_command(self, host, cmd, results, timeout=None):
LOG.exception(error)
results[host] = self._generate_error_result(exc=ex, message=error)
- def _put_files(self, local_path, remote_path, host, results, mode=None,
- mirror_local_mode=False):
+ def _put_files(
+ self, local_path, remote_path, host, results, mode=None, mirror_local_mode=False
+ ):
try:
- LOG.debug('Copying file to host: %s' % host)
+ LOG.debug("Copying file to host: %s" % host)
if os.path.isdir(local_path):
result = self._hosts_client[host].put_dir(local_path, remote_path)
else:
- result = self._hosts_client[host].put(local_path, remote_path,
- mirror_local_mode=mirror_local_mode,
- mode=mode)
- LOG.debug('Result of copy: %s' % result)
+ result = self._hosts_client[host].put(
+ local_path,
+ remote_path,
+ mirror_local_mode=mirror_local_mode,
+ mode=mode,
+ )
+ LOG.debug("Result of copy: %s" % result)
results[host] = result
except Exception as ex:
- error = 'Failed sending file(s) in path %s to host %s' % (local_path, host)
+ error = "Failed sending file(s) in path %s to host %s" % (local_path, host)
LOG.exception(error)
results[host] = self._generate_error_result(exc=ex, message=error)
@@ -324,16 +350,18 @@ def _delete_file(self, host, path, results):
result = self._hosts_client[host].delete_file(path)
results[host] = result
except Exception as ex:
- error = 'Failed deleting file %s on host %s.' % (path, host)
+ error = "Failed deleting file %s on host %s." % (path, host)
LOG.exception(error)
results[host] = self._generate_error_result(exc=ex, message=error)
def _delete_dir(self, host, path, results, force=False, timeout=None):
try:
- result = self._hosts_client[host].delete_dir(path, force=force, timeout=timeout)
+ result = self._hosts_client[host].delete_dir(
+ path, force=force, timeout=timeout
+ )
results[host] = result
except Exception as ex:
- error = 'Failed deleting dir %s on host %s.' % (path, host)
+ error = "Failed deleting dir %s on host %s." % (path, host)
LOG.exception(error)
results[host] = self._generate_error_result(exc=ex, message=error)
@@ -347,20 +375,27 @@ def _get_host_port_info(self, host_str):
def _handle_command_result(self, stdout, stderr, exit_code):
# Detect if user provided an invalid sudo password or sudo is not configured for that user
if self._sudo_password:
- if re.search(r'sudo: \d+ incorrect password attempts', stderr):
- match = re.search(r'\[sudo\] password for (.+?)\:', stderr)
+ if re.search(r"sudo: \d+ incorrect password attempts", stderr):
+ match = re.search(r"\[sudo\] password for (.+?)\:", stderr)
if match:
username = match.groups()[0]
else:
- username = 'unknown'
+ username = "unknown"
- error = ('Invalid sudo password provided or sudo is not configured for this user '
- '(%s)' % (username))
+ error = (
+ "Invalid sudo password provided or sudo is not configured for this user "
+ "(%s)" % (username)
+ )
raise ValueError(error)
- is_succeeded = (exit_code == 0)
- result_dict = {'stdout': stdout, 'stderr': stderr, 'return_code': exit_code,
- 'succeeded': is_succeeded, 'failed': not is_succeeded}
+ is_succeeded = exit_code == 0
+ result_dict = {
+ "stdout": stdout,
+ "stderr": stderr,
+ "return_code": exit_code,
+ "succeeded": is_succeeded,
+ "failed": not is_succeeded,
+ }
result = jsonify.json_loads(result_dict, ParallelSSHClient.KEYS_TO_TRANSFORM)
return result
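
The sudo failure detection above keys off two stderr patterns; a small, self-contained demonstration with an illustrative stderr string:

import re

stderr = "[sudo] password for stanley: sudo: 3 incorrect password attempts"

if re.search(r"sudo: \d+ incorrect password attempts", stderr):
    match = re.search(r"\[sudo\] password for (.+?)\:", stderr)
    username = match.groups()[0] if match else "unknown"
    print("Invalid sudo password or sudo not configured for user: %s" % username)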
@@ -375,8 +410,11 @@ def _sanitize_command_string(cmd):
if not cmd:
return cmd
- result = re.sub(r'ST2_ACTION_AUTH_TOKEN=(.+?)\s+?', 'ST2_ACTION_AUTH_TOKEN=%s ' %
- (MASKED_ATTRIBUTE_VALUE), cmd)
+ result = re.sub(
+ r"ST2_ACTION_AUTH_TOKEN=(.+?)\s+?",
+ "ST2_ACTION_AUTH_TOKEN=%s " % (MASKED_ATTRIBUTE_VALUE),
+ cmd,
+ )
return result
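
The substitution above masks any auth token embedded in the command string before it is logged; a standalone sketch (the masked value stands in for MASKED_ATTRIBUTE_VALUE):

import re

MASKED = "********"  # illustrative stand-in for MASKED_ATTRIBUTE_VALUE

cmd = "ST2_ACTION_AUTH_TOKEN=abc123 st2 run core.local cmd=date"
masked = re.sub(
    r"ST2_ACTION_AUTH_TOKEN=(.+?)\s+?",
    "ST2_ACTION_AUTH_TOKEN=%s " % MASKED,
    cmd,
)
print(masked)  # ST2_ACTION_AUTH_TOKEN=******** st2 run core.local cmd=date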
@staticmethod
@@ -388,8 +426,8 @@ def _generate_error_result(exc, message):
:param message: Error message which will be prefixed to the exception message.
:type message: ``str``
"""
- exc_message = getattr(exc, 'message', str(exc))
- error_message = '%s %s' % (message, exc_message)
+ exc_message = getattr(exc, "message", str(exc))
+ error_message = "%s %s" % (message, exc_message)
traceback_message = traceback.format_exc()
if isinstance(exc, SSHCommandTimeoutError):
@@ -399,21 +437,24 @@ def _generate_error_result(exc, message):
timeout = False
return_code = 255
- stdout = getattr(exc, 'stdout', None) or ''
- stderr = getattr(exc, 'stderr', None) or ''
+ stdout = getattr(exc, "stdout", None) or ""
+ stderr = getattr(exc, "stderr", None) or ""
error_dict = {
- 'failed': True,
- 'succeeded': False,
- 'timeout': timeout,
- 'return_code': return_code,
- 'stdout': stdout,
- 'stderr': stderr,
- 'error': error_message,
- 'traceback': traceback_message,
+ "failed": True,
+ "succeeded": False,
+ "timeout": timeout,
+ "return_code": return_code,
+ "stdout": stdout,
+ "stderr": stderr,
+ "error": error_message,
+ "traceback": traceback_message,
}
return error_dict
def __repr__(self):
- return ('<ParallelSSHClient hosts=%s,user=%s,id=%s>' %
- (repr(self._hosts), self._ssh_user, id(self)))
+ return "" % (
+ repr(self._hosts),
+ self._ssh_user,
+ id(self),
+ )
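
Putting the pieces together, a hedged usage sketch of the client; hosts and credentials are illustrative, and the constructor, run() and close() signatures follow the definitions shown above:

from st2common.runners.parallel_ssh import ParallelSSHClient

client = ParallelSSHClient(
    hosts=["box1.example.com", "box2.example.com"],
    user="stanley",
    pkey_file="/home/stanley/.ssh/stanley_rsa",
    concurrency=10,
    connect=True,
)

# results is a dict keyed by host with stdout/stderr/return_code/succeeded/failed.
results = client.run("uptime", timeout=60)
client.close()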
diff --git a/st2common/st2common/runners/paramiko_ssh.py b/st2common/st2common/runners/paramiko_ssh.py
index c42c4eb89f..7530a532d9 100644
--- a/st2common/st2common/runners/paramiko_ssh.py
+++ b/st2common/st2common/runners/paramiko_ssh.py
@@ -35,14 +35,13 @@
from st2common.util.misc import strip_shell_chars
from st2common.util.misc import sanitize_output
from st2common.util.shell import quote_unix
-from st2common.constants.runners import DEFAULT_SSH_PORT, REMOTE_RUNNER_PRIVATE_KEY_HEADER
+from st2common.constants.runners import (
+ DEFAULT_SSH_PORT,
+ REMOTE_RUNNER_PRIVATE_KEY_HEADER,
+)
from st2common.util import concurrency
-__all__ = [
- 'ParamikoSSHClient',
-
- 'SSHCommandTimeoutError'
-]
+__all__ = ["ParamikoSSHClient", "SSHCommandTimeoutError"]
class SSHCommandTimeoutError(Exception):
@@ -63,13 +62,21 @@ def __init__(self, cmd, timeout, ssh_connect_timeout, stdout=None, stderr=None):
self.ssh_connect_timeout = ssh_connect_timeout
self.stdout = stdout
self.stderr = stderr
- self.message = ('Command didn\'t finish in %s seconds or the SSH connection '
- 'did not succeed in %s seconds' % (timeout, ssh_connect_timeout))
+ self.message = (
+ "Command didn't finish in %s seconds or the SSH connection "
+ "did not succeed in %s seconds" % (timeout, ssh_connect_timeout)
+ )
super(SSHCommandTimeoutError, self).__init__(self.message)
def __repr__(self):
- return ('<SSHCommandTimeoutError: cmd="%s", timeout=%s, ssh_connect_timeout=%s>' %
- (self.cmd, self.timeout, self.ssh_connect_timeout))
+ return (
+ '<SSHCommandTimeoutError: cmd="%s", timeout=%s, ssh_connect_timeout=%s>'
+ % (
+ self.cmd,
+ self.timeout,
+ self.ssh_connect_timeout,
+ )
+ )
def __str__(self):
return self.message
@@ -86,9 +93,20 @@ class ParamikoSSHClient(object):
# How long to sleep while waiting for command to finish to prevent busy waiting
SLEEP_DELAY = 0.2
- def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None,
- bastion_host=None, key_files=None, key_material=None, timeout=None,
- passphrase=None, handle_stdout_line_func=None, handle_stderr_line_func=None):
+ def __init__(
+ self,
+ hostname,
+ port=DEFAULT_SSH_PORT,
+ username=None,
+ password=None,
+ bastion_host=None,
+ key_files=None,
+ key_material=None,
+ timeout=None,
+ passphrase=None,
+ handle_stdout_line_func=None,
+ handle_stderr_line_func=None,
+ ):
"""
Authentication is always attempted in the following order:
@@ -114,8 +132,7 @@ def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None
self._handle_stderr_line_func = handle_stderr_line_func
self.ssh_config_file = os.path.expanduser(
- cfg.CONF.ssh_runner.ssh_config_file_path or
- '~/.ssh/config'
+ cfg.CONF.ssh_runner.ssh_config_file_path or "~/.ssh/config"
)
if self.timeout and int(self.ssh_connect_timeout) > int(self.timeout) - 2:
@@ -140,14 +157,16 @@ def connect(self):
:rtype: ``bool``
"""
if self.bastion_host:
- self.logger.debug('Bastion host specified, connecting')
+ self.logger.debug("Bastion host specified, connecting")
self.bastion_client = self._connect(host=self.bastion_host)
transport = self.bastion_client.get_transport()
real_addr = (self.hostname, self.port)
# fabric uses ('', 0) for direct-tcpip, this duplicates that behaviour
# see https://github.com/fabric/fabric/commit/c2a9bbfd50f560df6c6f9675603fb405c4071cad
- local_addr = ('', 0)
- self.bastion_socket = transport.open_channel('direct-tcpip', real_addr, local_addr)
+ local_addr = ("", 0)
+ self.bastion_socket = transport.open_channel(
+ "direct-tcpip", real_addr, local_addr
+ )
self.client = self._connect(host=self.hostname, socket=self.bastion_socket)
return True
@@ -173,17 +192,24 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False):
"""
if not local_path or not remote_path:
- raise Exception('Need both local_path and remote_path. local: %s, remote: %s' %
- local_path, remote_path)
+ raise Exception(
+ "Need both local_path and remote_path. local: %s, remote: %s"
+ % (local_path, remote_path)
+ )
local_path = quote_unix(local_path)
remote_path = quote_unix(remote_path)
- extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode,
- '_mirror_local_mode': mirror_local_mode}
- self.logger.debug('Uploading file', extra=extra)
+ extra = {
+ "_local_path": local_path,
+ "_remote_path": remote_path,
+ "_mode": mode,
+ "_mirror_local_mode": mirror_local_mode,
+ }
+ self.logger.debug("Uploading file", extra=extra)
if not os.path.exists(local_path):
- raise Exception('Path %s does not exist locally.' % local_path)
+ raise Exception("Path %s does not exist locally." % local_path)
rattrs = self.sftp.put(local_path, remote_path)
@@ -199,7 +225,7 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False):
remote_mode = rattrs.st_mode
# Only apply the bitmask if we actually got a remote_mode
if remote_mode is not None:
- remote_mode = (remote_mode & 0o7777)
+ remote_mode = remote_mode & 0o7777
if local_mode != remote_mode:
self.sftp.chmod(remote_path, local_mode)
@@ -225,9 +251,13 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False):
:rtype: ``list`` of ``str``
"""
- extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode,
- '_mirror_local_mode': mirror_local_mode}
- self.logger.debug('Uploading dir', extra=extra)
+ extra = {
+ "_local_path": local_path,
+ "_remote_path": remote_path,
+ "_mode": mode,
+ "_mirror_local_mode": mirror_local_mode,
+ }
+ self.logger.debug("Uploading dir", extra=extra)
if os.path.basename(local_path):
strip = os.path.dirname(local_path)
@@ -237,10 +267,10 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False):
remote_paths = []
for context, dirs, files in os.walk(local_path):
- rcontext = context.replace(strip, '', 1)
+ rcontext = context.replace(strip, "", 1)
# normalize pathname separators with POSIX separator
- rcontext = rcontext.replace(os.sep, '/')
- rcontext = rcontext.lstrip('/')
+ rcontext = rcontext.replace(os.sep, "/")
+ rcontext = rcontext.lstrip("/")
rcontext = posixpath.join(remote_path, rcontext)
if not self.exists(rcontext):
@@ -255,8 +285,12 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False):
local_path = os.path.join(context, f)
n = posixpath.join(rcontext, f)
# Note that quote_unix is done by put anyways.
- p = self.put(local_path=local_path, remote_path=n,
- mirror_local_mode=mirror_local_mode, mode=mode)
+ p = self.put(
+ local_path=local_path,
+ remote_path=n,
+ mirror_local_mode=mirror_local_mode,
+ mode=mode,
+ )
remote_paths.append(p)
return remote_paths
@@ -290,8 +324,8 @@ def mkdir(self, dir_path):
"""
dir_path = quote_unix(dir_path)
- extra = {'_dir_path': dir_path}
- self.logger.debug('mkdir', extra=extra)
+ extra = {"_dir_path": dir_path}
+ self.logger.debug("mkdir", extra=extra)
return self.sftp.mkdir(dir_path)
def delete_file(self, path):
@@ -307,8 +341,8 @@ def delete_file(self, path):
"""
path = quote_unix(path)
- extra = {'_path': path}
- self.logger.debug('Deleting file', extra=extra)
+ extra = {"_path": path}
+ self.logger.debug("Deleting file", extra=extra)
self.sftp.unlink(path)
return True
@@ -331,15 +365,15 @@ def delete_dir(self, path, force=False, timeout=None):
"""
path = quote_unix(path)
- extra = {'_path': path}
+ extra = {"_path": path}
if force:
- command = 'rm -rf %s' % path
- extra['_command'] = command
- extra['_force'] = force
- self.logger.debug('Deleting dir', extra=extra)
+ command = "rm -rf %s" % path
+ extra["_command"] = command
+ extra["_force"] = force
+ self.logger.debug("Deleting dir", extra=extra)
return self.run(command, timeout=timeout)
- self.logger.debug('Deleting dir', extra=extra)
+ self.logger.debug("Deleting dir", extra=extra)
return self.sftp.rmdir(path)
def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False):
@@ -359,8 +393,8 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False):
if quote:
cmd = quote_unix(cmd)
- extra = {'_cmd': cmd}
- self.logger.info('Executing command', extra=extra)
+ extra = {"_cmd": cmd}
+ self.logger.info("Executing command", extra=extra)
# Use the system default buffer size
bufsize = -1
@@ -369,7 +403,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False):
chan = transport.open_session()
start_time = time.time()
- if cmd.startswith('sudo'):
+ if cmd.startswith("sudo"):
# Note that fabric does this as well. If you set pty, stdout and stderr
# streams will be combined into one.
# NOTE: If pty is used, every new line character \n will be converted to \r\n which
@@ -386,7 +420,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False):
# Create a stdin file and immediately close it to prevent any
# interactive script from hanging the process.
- stdin = chan.makefile('wb', bufsize)
+ stdin = chan.makefile("wb", bufsize)
stdin.close()
# Receive all the output
@@ -400,12 +434,14 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False):
exit_status_ready = chan.exit_status_ready()
if exit_status_ready:
- stdout_data = self._consume_stdout(chan=chan,
- call_line_handler_func=call_line_handler_func)
+ stdout_data = self._consume_stdout(
+ chan=chan, call_line_handler_func=call_line_handler_func
+ )
stdout_data = stdout_data.getvalue()
- stderr_data = self._consume_stderr(chan=chan,
- call_line_handler_func=call_line_handler_func)
+ stderr_data = self._consume_stderr(
+ chan=chan, call_line_handler_func=call_line_handler_func
+ )
stderr_data = stderr_data.getvalue()
stdout.write(stdout_data)
@@ -413,7 +449,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False):
while not exit_status_ready:
current_time = time.time()
- elapsed_time = (current_time - start_time)
+ elapsed_time = current_time - start_time
if timeout and (elapsed_time > timeout):
# TODO: Is this the right way to clean up?
@@ -421,16 +457,22 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False):
stdout = sanitize_output(stdout.getvalue(), uses_pty=uses_pty)
stderr = sanitize_output(stderr.getvalue(), uses_pty=uses_pty)
- raise SSHCommandTimeoutError(cmd=cmd, timeout=timeout,
- ssh_connect_timeout=self.ssh_connect_timeout,
- stdout=stdout, stderr=stderr)
-
- stdout_data = self._consume_stdout(chan=chan,
- call_line_handler_func=call_line_handler_func)
+ raise SSHCommandTimeoutError(
+ cmd=cmd,
+ timeout=timeout,
+ ssh_connect_timeout=self.ssh_connect_timeout,
+ stdout=stdout,
+ stderr=stderr,
+ )
+
+ stdout_data = self._consume_stdout(
+ chan=chan, call_line_handler_func=call_line_handler_func
+ )
stdout_data = stdout_data.getvalue()
- stderr_data = self._consume_stderr(chan=chan,
- call_line_handler_func=call_line_handler_func)
+ stderr_data = self._consume_stderr(
+ chan=chan, call_line_handler_func=call_line_handler_func
+ )
stderr_data = stderr_data.getvalue()
stdout.write(stdout_data)
@@ -453,8 +495,8 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False):
stdout = sanitize_output(stdout.getvalue(), uses_pty=uses_pty)
stderr = sanitize_output(stderr.getvalue(), uses_pty=uses_pty)
- extra = {'_status': status, '_stdout': stdout, '_stderr': stderr}
- self.logger.debug('Command finished', extra=extra)
+ extra = {"_status": status, "_stdout": stdout, "_stderr": stderr}
+ self.logger.debug("Command finished", extra=extra)
return [stdout, stderr, status]
@@ -499,7 +541,7 @@ def _consume_stdout(self, chan, call_line_handler_func=False):
data = chan.recv(self.CHUNK_SIZE)
if six.PY3 and isinstance(data, six.text_type):
- data = data.encode('utf-8')
+ data = data.encode("utf-8")
out += data
@@ -512,7 +554,7 @@ def _consume_stdout(self, chan, call_line_handler_func=False):
data = chan.recv(self.CHUNK_SIZE)
if six.PY3 and isinstance(data, six.text_type):
- data = data.encode('utf-8')
+ data = data.encode("utf-8")
out += data
@@ -520,14 +562,14 @@ def _consume_stdout(self, chan, call_line_handler_func=False):
if self._handle_stdout_line_func and call_line_handler_func:
data = strip_shell_chars(stdout.getvalue())
- lines = data.split('\n')
+ lines = data.split("\n")
lines = [line for line in lines if line]
for line in lines:
# Note: If this function performs network operations no sleep is
# needed, otherwise if a long blocking operation is performed,
# sleep is recommended to yield and prevent from busy looping
- self._handle_stdout_line_func(line=line + '\n')
+ self._handle_stdout_line_func(line=line + "\n")
stdout.seek(0)
@@ -545,7 +587,7 @@ def _consume_stderr(self, chan, call_line_handler_func=False):
data = chan.recv_stderr(self.CHUNK_SIZE)
if six.PY3 and isinstance(data, six.text_type):
- data = data.encode('utf-8')
+ data = data.encode("utf-8")
out += data
@@ -558,7 +600,7 @@ def _consume_stderr(self, chan, call_line_handler_func=False):
data = chan.recv_stderr(self.CHUNK_SIZE)
if six.PY3 and isinstance(data, six.text_type):
- data = data.encode('utf-8')
+ data = data.encode("utf-8")
out += data
@@ -566,14 +608,14 @@ def _consume_stderr(self, chan, call_line_handler_func=False):
if self._handle_stderr_line_func and call_line_handler_func:
data = strip_shell_chars(stderr.getvalue())
- lines = data.split('\n')
+ lines = data.split("\n")
lines = [line for line in lines if line]
for line in lines:
# Note: If this function performs network operations no sleep is
# needed, otherwise if a long blocking operation is performed,
# sleep is recommended to yield and prevent from busy looping
- self._handle_stderr_line_func(line=line + '\n')
+ self._handle_stderr_line_func(line=line + "\n")
stderr.seek(0)
@@ -581,9 +623,9 @@ def _consume_stderr(self, chan, call_line_handler_func=False):
def _get_decoded_data(self, data):
try:
- return data.decode('utf-8')
+ return data.decode("utf-8")
except:
- self.logger.exception('Non UTF-8 character found in data: %s', data)
+ self.logger.exception("Non UTF-8 character found in data: %s", data)
raise
def _get_pkey_object(self, key_material, passphrase):
@@ -604,13 +646,17 @@ def _get_pkey_object(self, key_material, passphrase):
# exception letting the user know we expect the contents and not a path.
# Note: We do it here and not up the stack to avoid false positives.
contains_header = REMOTE_RUNNER_PRIVATE_KEY_HEADER in key_material.lower()
- if not contains_header and (key_material.count('/') >= 1 or key_material.count('\\') >= 1):
- msg = ('"private_key" parameter needs to contain private key data / content and not '
- 'a path')
+ if not contains_header and (
+ key_material.count("/") >= 1 or key_material.count("\\") >= 1
+ ):
+ msg = (
+ '"private_key" parameter needs to contain private key data / content and not '
+ "a path"
+ )
elif passphrase:
- msg = 'Invalid passphrase or invalid/unsupported key type'
+ msg = "Invalid passphrase or invalid/unsupported key type"
else:
- msg = 'Invalid or unsupported key type'
+ msg = "Invalid or unsupported key type"
raise paramiko.ssh_exception.SSHException(msg)
@@ -636,19 +682,23 @@ def _connect(self, host, socket=None):
:rtype: :class:`paramiko.SSHClient`
"""
- conninfo = {'hostname': host,
- 'allow_agent': False,
- 'look_for_keys': False,
- 'timeout': self.ssh_connect_timeout}
+ conninfo = {
+ "hostname": host,
+ "allow_agent": False,
+ "look_for_keys": False,
+ "timeout": self.ssh_connect_timeout,
+ }
ssh_config_file_info = {}
if cfg.CONF.ssh_runner.use_ssh_config:
ssh_config_file_info = self._get_ssh_config_for_host(host)
- ssh_config_username = ssh_config_file_info.get('user', None)
- ssh_config_port = ssh_config_file_info.get('port', None)
+ ssh_config_username = ssh_config_file_info.get("user", None)
+ ssh_config_port = ssh_config_file_info.get("port", None)
- self.username = (self.username or ssh_config_username or cfg.CONF.system_user.user)
+ self.username = (
+ self.username or ssh_config_username or cfg.CONF.system_user.user
+ )
# If a custom non-default port is provided in the SSH config file we use that over the
# default port value provided via runner parameter
@@ -660,78 +710,92 @@ def _connect(self, host, socket=None):
# If both key file and key material are provided as action parameters,
# throw an error informing user only one is required.
if self.key_files and self.key_material:
- msg = ('key_files and key_material arguments are mutually exclusive. Supply only one.')
+ msg = "key_files and key_material arguments are mutually exclusive. Supply only one."
raise ValueError(msg)
# If neither key material nor password is provided, only then we look at key file and decide
# if we want to use the user supplied one or the one in SSH config.
if not self.key_material and not self.password:
- self.key_files = (self.key_files or ssh_config_file_info.get('identityfile', None) or
- cfg.CONF.system_user.ssh_key_file)
+ self.key_files = (
+ self.key_files
+ or ssh_config_file_info.get("identityfile", None)
+ or cfg.CONF.system_user.ssh_key_file
+ )
if self.passphrase and not (self.key_files or self.key_material):
- raise ValueError('passphrase should accompany private key material')
+ raise ValueError("passphrase should accompany private key material")
credentials_provided = self.password or self.key_files or self.key_material
if not credentials_provided:
- msg = ('Either password or key file location or key material should be supplied ' +
- 'for action. You can also add an entry for host %s in SSH config file %s.' %
- (host, self.ssh_config_file))
+ msg = (
+ "Either password or key file location or key material should be supplied "
+ + "for action. You can also add an entry for host %s in SSH config file %s."
+ % (host, self.ssh_config_file)
+ )
raise ValueError(msg)
- conninfo['username'] = self.username
- conninfo['port'] = self.port
+ conninfo["username"] = self.username
+ conninfo["port"] = self.port
if self.password:
- conninfo['password'] = self.password
+ conninfo["password"] = self.password
if self.key_files:
- conninfo['key_filename'] = self.key_files
+ conninfo["key_filename"] = self.key_files
passphrase_reqd = self._is_key_file_needs_passphrase(self.key_files)
if passphrase_reqd and not self.passphrase:
- msg = ('Private key file %s is passphrase protected. Supply a passphrase.' %
- self.key_files)
+ msg = (
+ "Private key file %s is passphrase protected. Supply a passphrase."
+ % self.key_files
+ )
raise paramiko.ssh_exception.PasswordRequiredException(msg)
if self.passphrase:
# Optional passphrase for unlocking the private key
- conninfo['password'] = self.passphrase
+ conninfo["password"] = self.passphrase
if self.key_material:
- conninfo['pkey'] = self._get_pkey_object(key_material=self.key_material,
- passphrase=self.passphrase)
+ conninfo["pkey"] = self._get_pkey_object(
+ key_material=self.key_material, passphrase=self.passphrase
+ )
if not self.password and not (self.key_files or self.key_material):
- conninfo['allow_agent'] = True
- conninfo['look_for_keys'] = True
-
- extra = {'_hostname': host, '_port': self.port,
- '_username': self.username, '_timeout': self.ssh_connect_timeout}
- self.logger.debug('Connecting to server', extra=extra)
-
- self.socket = socket or ssh_config_file_info.get('sock', None)
+ conninfo["allow_agent"] = True
+ conninfo["look_for_keys"] = True
+
+ extra = {
+ "_hostname": host,
+ "_port": self.port,
+ "_username": self.username,
+ "_timeout": self.ssh_connect_timeout,
+ }
+ self.logger.debug("Connecting to server", extra=extra)
+
+ self.socket = socket or ssh_config_file_info.get("sock", None)
if self.socket:
- conninfo['sock'] = socket
+ conninfo["sock"] = socket
client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
- extra = {'_conninfo': conninfo}
- self.logger.debug('Connection info', extra=extra)
+ extra = {"_conninfo": conninfo}
+ self.logger.debug("Connection info", extra=extra)
try:
client.connect(**conninfo)
except SSHException as e:
paramiko_msg = six.text_type(e)
- if conninfo.get('password', None):
- conninfo['password'] = ''
+ if conninfo.get("password", None):
+ conninfo["password"] = ""
- msg = ('Error connecting to host %s ' % host +
- 'with connection parameters %s.' % conninfo +
- 'Paramiko error: %s.' % paramiko_msg)
+ msg = (
+ "Error connecting to host %s " % host
+ + "with connection parameters %s." % conninfo
+ + "Paramiko error: %s." % paramiko_msg
+ )
raise SSHException(msg)
return client
@@ -744,25 +808,29 @@ def _get_ssh_config_for_host(self, host):
with open(self.ssh_config_file) as f:
ssh_config_parser.parse(f)
except IOError as e:
- raise Exception('Error accessing ssh config file %s. Code: %s Reason %s' %
- (self.ssh_config_file, e.errno, e.strerror))
+ raise Exception(
+ "Error accessing ssh config file %s. Code: %s Reason %s"
+ % (self.ssh_config_file, e.errno, e.strerror)
+ )
ssh_config = ssh_config_parser.lookup(host)
- self.logger.info('Parsed SSH config file contents: %s', ssh_config)
+ self.logger.info("Parsed SSH config file contents: %s", ssh_config)
if ssh_config:
- for k in ('hostname', 'user', 'port'):
+ for k in ("hostname", "user", "port"):
if k in ssh_config:
ssh_config_info[k] = ssh_config[k]
- if 'identityfile' in ssh_config:
- key_file = ssh_config['identityfile']
+ if "identityfile" in ssh_config:
+ key_file = ssh_config["identityfile"]
if type(key_file) is list:
key_file = key_file[0]
- ssh_config_info['identityfile'] = key_file
+ ssh_config_info["identityfile"] = key_file
- if 'proxycommand' in ssh_config:
- ssh_config_info['sock'] = paramiko.ProxyCommand(ssh_config['proxycommand'])
+ if "proxycommand" in ssh_config:
+ ssh_config_info["sock"] = paramiko.ProxyCommand(
+ ssh_config["proxycommand"]
+ )
return ssh_config_info
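For readers unfamiliar with the API being reformatted here, the method above wraps paramiko's `SSHConfig`. A minimal sketch of that lookup pattern follows; the config path and host name are placeholders, not values from this changeset:

```python
import paramiko

# Minimal sketch of the SSHConfig lookup wrapped by _get_ssh_config_for_host();
# the config path and host name are placeholder values.
ssh_config_parser = paramiko.SSHConfig()
with open("/home/user/.ssh/config") as f:
    ssh_config_parser.parse(f)

entry = ssh_config_parser.lookup("example-host")
# When present, entry carries keys such as "hostname", "user", "port",
# "identityfile" (possibly a list) and "proxycommand", which are exactly
# the keys the method above copies into ssh_config_info.
```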
@@ -779,5 +847,9 @@ def _is_key_file_needs_passphrase(file):
return False
def __repr__(self):
- return ('<ParamikoSSHClient hostname=%s,port=%s,username=%s,id=%s>' %
- (self.hostname, self.port, self.username, id(self)))
+ return "<ParamikoSSHClient hostname=%s,port=%s,username=%s,id=%s>" % (
+ self.hostname,
+ self.port,
+ self.username,
+ id(self),
+ )
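Taken together, the `_connect` changes above only re-wrap how the `conninfo` dict is assembled before it is splatted into `paramiko.SSHClient.connect`. A minimal sketch of that pattern, with placeholder host and credentials:

```python
import paramiko

# Placeholder connection parameters; only the **kwargs expansion pattern
# mirrors the _connect() code above.
conninfo = {
    "hostname": "example.invalid",
    "port": 22,
    "username": "stanley",
    "key_filename": "/home/stanley/.ssh/stanley_rsa",
    "timeout": 60,
    "allow_agent": False,
    "look_for_keys": False,
}

client = paramiko.SSHClient()
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
client.connect(**conninfo)
```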
diff --git a/st2common/st2common/runners/paramiko_ssh_runner.py b/st2common/st2common/runners/paramiko_ssh_runner.py
index 6c1ab053e9..f41882935f 100644
--- a/st2common/st2common/runners/paramiko_ssh_runner.py
+++ b/st2common/st2common/runners/paramiko_ssh_runner.py
@@ -29,34 +29,31 @@
from st2common.exceptions.actionrunner import ActionRunnerPreRunError
from st2common.services.action import store_execution_output_data
-__all__ = [
- 'BaseParallelSSHRunner'
-]
+__all__ = ["BaseParallelSSHRunner"]
LOG = logging.getLogger(__name__)
# constants to lookup in runner_parameters.
-RUNNER_HOSTS = 'hosts'
-RUNNER_USERNAME = 'username'
-RUNNER_PASSWORD = 'password'
-RUNNER_PRIVATE_KEY = 'private_key'
-RUNNER_PARALLEL = 'parallel'
-RUNNER_SUDO = 'sudo'
-RUNNER_SUDO_PASSWORD = 'sudo_password'
-RUNNER_ON_BEHALF_USER = 'user'
-RUNNER_REMOTE_DIR = 'dir'
-RUNNER_COMMAND = 'cmd'
-RUNNER_CWD = 'cwd'
-RUNNER_ENV = 'env'
-RUNNER_KWARG_OP = 'kwarg_op'
-RUNNER_TIMEOUT = 'timeout'
-RUNNER_SSH_PORT = 'port'
-RUNNER_BASTION_HOST = 'bastion_host'
-RUNNER_PASSPHRASE = 'passphrase'
+RUNNER_HOSTS = "hosts"
+RUNNER_USERNAME = "username"
+RUNNER_PASSWORD = "password"
+RUNNER_PRIVATE_KEY = "private_key"
+RUNNER_PARALLEL = "parallel"
+RUNNER_SUDO = "sudo"
+RUNNER_SUDO_PASSWORD = "sudo_password"
+RUNNER_ON_BEHALF_USER = "user"
+RUNNER_REMOTE_DIR = "dir"
+RUNNER_COMMAND = "cmd"
+RUNNER_CWD = "cwd"
+RUNNER_ENV = "env"
+RUNNER_KWARG_OP = "kwarg_op"
+RUNNER_TIMEOUT = "timeout"
+RUNNER_SSH_PORT = "port"
+RUNNER_BASTION_HOST = "bastion_host"
+RUNNER_PASSPHRASE = "passphrase"
class BaseParallelSSHRunner(ActionRunner, ShellRunnerMixin):
-
def __init__(self, runner_id):
super(BaseParallelSSHRunner, self).__init__(runner_id=runner_id)
self._hosts = None
@@ -68,7 +65,7 @@ def __init__(self, runner_id):
self._password = None
self._private_key = None
self._passphrase = None
- self._kwarg_op = '--'
+ self._kwarg_op = "--"
self._cwd = None
self._env = None
self._ssh_port = None
@@ -83,13 +80,16 @@ def __init__(self, runner_id):
def pre_run(self):
super(BaseParallelSSHRunner, self).pre_run()
- LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
- self.liveaction_id)
- hosts = self.runner_parameters.get(RUNNER_HOSTS, '').split(',')
+ LOG.debug(
+ 'Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"',
+ self.liveaction_id,
+ )
+ hosts = self.runner_parameters.get(RUNNER_HOSTS, "").split(",")
self._hosts = [h.strip() for h in hosts if len(h) > 0]
if len(self._hosts) < 1:
- raise ActionRunnerPreRunError('No hosts specified to run action for action %s.'
- % self.liveaction_id)
+ raise ActionRunnerPreRunError(
+ "No hosts specified to run action for action %s." % self.liveaction_id
+ )
self._username = self.runner_parameters.get(RUNNER_USERNAME, None)
self._password = self.runner_parameters.get(RUNNER_PASSWORD, None)
self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, None)
@@ -103,85 +103,105 @@ def pre_run(self):
self._sudo_password = self.runner_parameters.get(RUNNER_SUDO_PASSWORD, None)
if self.context:
- self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user)
+ self._on_behalf_user = self.context.get(
+ RUNNER_ON_BEHALF_USER, self._on_behalf_user
+ )
self._cwd = self.runner_parameters.get(RUNNER_CWD, None)
self._env = self.runner_parameters.get(RUNNER_ENV, {})
- self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--')
- self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT,
- REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT)
+ self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, "--")
+ self._timeout = self.runner_parameters.get(
+ RUNNER_TIMEOUT, REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT
+ )
self._bastion_host = self.runner_parameters.get(RUNNER_BASTION_HOST, None)
- LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
- self.runner_id, self.liveaction_id)
+ LOG.info(
+ '[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.',
+ self.runner_id,
+ self.liveaction_id,
+ )
concurrency = int(len(self._hosts) / 3) + 1 if self._parallel else 1
if concurrency > self._max_concurrency:
- LOG.debug('Limiting parallel SSH concurrency to %d.', concurrency)
+ LOG.debug("Limiting parallel SSH concurrency to %d.", concurrency)
concurrency = self._max_concurrency
client_kwargs = {
- 'hosts': self._hosts,
- 'user': self._username,
- 'port': self._ssh_port,
- 'concurrency': concurrency,
- 'bastion_host': self._bastion_host,
- 'raise_on_any_error': False,
- 'connect': True
+ "hosts": self._hosts,
+ "user": self._username,
+ "port": self._ssh_port,
+ "concurrency": concurrency,
+ "bastion_host": self._bastion_host,
+ "raise_on_any_error": False,
+ "connect": True,
}
def make_store_stdout_line_func(execution_db, action_db):
def store_stdout_line(line):
if cfg.CONF.actionrunner.stream_output:
- store_execution_output_data(execution_db=execution_db, action_db=action_db,
- data=line, output_type='stdout')
+ store_execution_output_data(
+ execution_db=execution_db,
+ action_db=action_db,
+ data=line,
+ output_type="stdout",
+ )
return store_stdout_line
def make_store_stderr_line_func(execution_db, action_db):
def store_stderr_line(line):
if cfg.CONF.actionrunner.stream_output:
- store_execution_output_data(execution_db=execution_db, action_db=action_db,
- data=line, output_type='stderr')
+ store_execution_output_data(
+ execution_db=execution_db,
+ action_db=action_db,
+ data=line,
+ output_type="stderr",
+ )
return store_stderr_line
- handle_stdout_line_func = make_store_stdout_line_func(execution_db=self.execution,
- action_db=self.action)
- handle_stderr_line_func = make_store_stderr_line_func(execution_db=self.execution,
- action_db=self.action)
+ handle_stdout_line_func = make_store_stdout_line_func(
+ execution_db=self.execution, action_db=self.action
+ )
+ handle_stderr_line_func = make_store_stderr_line_func(
+ execution_db=self.execution, action_db=self.action
+ )
if len(self._hosts) == 1:
# We only support streaming output when running an action on one host. That is because
# the action output is tied to a particular execution. The user can still achieve output
# streaming for multiple hosts by running one execution per host.
- client_kwargs['handle_stdout_line_func'] = handle_stdout_line_func
- client_kwargs['handle_stderr_line_func'] = handle_stderr_line_func
+ client_kwargs["handle_stdout_line_func"] = handle_stdout_line_func
+ client_kwargs["handle_stderr_line_func"] = handle_stderr_line_func
else:
- LOG.debug('Real-time action output streaming is disabled, because action is running '
- 'on more than one host')
+ LOG.debug(
+ "Real-time action output streaming is disabled, because action is running "
+ "on more than one host"
+ )
if self._password:
- client_kwargs['password'] = self._password
+ client_kwargs["password"] = self._password
elif self._private_key:
# Determine if the private_key is a path to the key file or the raw key material
- is_key_material = self._is_private_key_material(private_key=self._private_key)
+ is_key_material = self._is_private_key_material(
+ private_key=self._private_key
+ )
if is_key_material:
# Raw key material
- client_kwargs['pkey_material'] = self._private_key
+ client_kwargs["pkey_material"] = self._private_key
else:
# Assume it's a path to the key file, verify the file exists
- client_kwargs['pkey_file'] = self._private_key
+ client_kwargs["pkey_file"] = self._private_key
if self._passphrase:
- client_kwargs['passphrase'] = self._passphrase
+ client_kwargs["passphrase"] = self._passphrase
else:
# Default to stanley key file specified in the config
- client_kwargs['pkey_file'] = self._ssh_key_file
+ client_kwargs["pkey_file"] = self._ssh_key_file
if self._sudo_password:
- client_kwargs['sudo_password'] = True
+ client_kwargs["sudo_password"] = True
self._parallel_ssh_client = ParallelSSHClient(**client_kwargs)
@@ -213,21 +233,22 @@ def _get_env_vars(self):
@staticmethod
def _get_result_status(result, allow_partial_failure):
- if 'error' in result and 'traceback' in result:
+ if "error" in result and "traceback" in result:
# Assume this is a global failure where the result dictionary doesn't contain an entry
# per host
timeout = False
- success = result.get('succeeded', False)
- status = BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success,
- timeout=timeout)
+ success = result.get("succeeded", False)
+ status = BaseParallelSSHRunner._get_status_for_success_and_timeout(
+ success=success, timeout=timeout
+ )
return status
success = not allow_partial_failure
timeout = True
for r in six.itervalues(result):
- r_succeess = r.get('succeeded', False) if r else False
- r_timeout = r.get('timeout', False) if r else False
+ r_succeess = r.get("succeeded", False) if r else False
+ r_timeout = r.get("timeout", False) if r else False
timeout &= r_timeout
@@ -240,8 +261,9 @@ def _get_result_status(result, allow_partial_failure):
if not success:
break
- status = BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success,
- timeout=timeout)
+ status = BaseParallelSSHRunner._get_status_for_success_and_timeout(
+ success=success, timeout=timeout
+ )
return status
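The `_get_result_status` hunks above reflow the per-host aggregation without changing its semantics. A compact, self-contained restatement of that logic, using illustrative dict shapes rather than the runner's actual types:

```python
# Restatement of the aggregation in _get_result_status(): the run succeeds
# if every host succeeded, or if any host succeeded when partial failure
# is allowed. Hosts may map to None when no result was collected.
def aggregate_success(result, allow_partial_failure):
    per_host = [bool(r and r.get("succeeded", False)) for r in result.values()]
    return any(per_host) if allow_partial_failure else all(per_host)

result = {
    "host-1": {"succeeded": True, "timeout": False},
    "host-2": {"succeeded": False, "timeout": False},
}
assert aggregate_success(result, allow_partial_failure=True) is True
assert aggregate_success(result, allow_partial_failure=False) is False
```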
diff --git a/st2common/st2common/runners/utils.py b/st2common/st2common/runners/utils.py
index 82f1a3477c..70f7139f3f 100644
--- a/st2common/st2common/runners/utils.py
+++ b/st2common/st2common/runners/utils.py
@@ -27,14 +27,11 @@
__all__ = [
- 'PackConfigDict',
-
- 'get_logger_for_python_runner_action',
- 'get_action_class_instance',
-
- 'make_read_and_store_stream_func',
-
- 'invoke_post_run',
+ "PackConfigDict",
+ "get_logger_for_python_runner_action",
+ "get_action_class_instance",
+ "make_read_and_store_stream_func",
+ "invoke_post_run",
]
LOG = logging.getLogger(__name__)
@@ -61,6 +58,7 @@ class PackConfigDict(dict):
This class throws a user-friendly exception in case the user tries to access a config item which
doesn't exist in the dict.
"""
+
def __init__(self, pack_name, *args):
super(PackConfigDict, self).__init__(*args)
self._pack_name = pack_name
@@ -72,8 +70,8 @@ def __getitem__(self, key):
# Note: We use late import to avoid performance overhead
from oslo_config import cfg
- configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/')
- config_path = os.path.join(configs_path, self._pack_name + '.yaml')
+ configs_path = os.path.join(cfg.CONF.system.base_path, "configs/")
+ config_path = os.path.join(configs_path, self._pack_name + ".yaml")
msg = CONFIG_MISSING_ITEM_ERROR % (self._pack_name, key, config_path)
raise ValueError(msg)
@@ -83,11 +81,11 @@ def __setitem__(self, key, value):
super(PackConfigDict, self).__setitem__(key, value)
-def get_logger_for_python_runner_action(action_name, log_level='debug'):
+def get_logger_for_python_runner_action(action_name, log_level="debug"):
"""
Set up a logger which logs all the messages with level DEBUG and above to stderr.
"""
- logger_name = 'actions.python.%s' % (action_name)
+ logger_name = "actions.python.%s" % (action_name)
if logger_name not in LOGGERS:
level_name = log_level.upper()
@@ -97,7 +95,7 @@ def get_logger_for_python_runner_action(action_name, log_level='debug'):
console = stdlib_logging.StreamHandler()
console.setLevel(log_level_constant)
- formatter = stdlib_logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
+ formatter = stdlib_logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
console.setFormatter(formatter)
logger.addHandler(console)
logger.setLevel(log_level_constant)
@@ -123,8 +121,8 @@ def get_action_class_instance(action_cls, config=None, action_service=None):
:type action_service: :class:`ActionService`
"""
kwargs = {}
- kwargs['config'] = config
- kwargs['action_service'] = action_service
+ kwargs["config"] = config
+ kwargs["action_service"] = action_service
# Note: This is done for backward compatibility reasons. We first try to pass
# "action_service" argument to the action class constructor, but if that doesn't work (e.g. old
@@ -133,13 +131,15 @@ def get_action_class_instance(action_cls, config=None, action_service=None):
try:
action_instance = action_cls(**kwargs)
except TypeError as e:
- if 'unexpected keyword argument \'action_service\'' not in six.text_type(e):
+ if "unexpected keyword argument 'action_service'" not in six.text_type(e):
raise e
- LOG.debug('Action class (%s) constructor doesn\'t take "action_service" argument, '
- 'falling back to late assignment...' % (action_cls.__class__.__name__))
+ LOG.debug(
+ 'Action class (%s) constructor doesn\'t take "action_service" argument, '
+ "falling back to late assignment..." % (action_cls.__class__.__name__)
+ )
- action_service = kwargs.pop('action_service', None)
+ action_service = kwargs.pop("action_service", None)
action_instance = action_cls(**kwargs)
action_instance.action_service = action_service
@@ -166,7 +166,7 @@ def read_and_store_stream(stream, buff):
break
if isinstance(line, six.binary_type):
- line = line.decode('utf-8')
+ line = line.decode("utf-8")
buff.write(line)
@@ -175,7 +175,9 @@ def read_and_store_stream(stream, buff):
continue
if cfg.CONF.actionrunner.stream_output:
- store_data_func(execution_db=execution_db, action_db=action_db, data=line)
+ store_data_func(
+ execution_db=execution_db, action_db=action_db, data=line
+ )
except RuntimeError:
# process was terminated abruptly
pass
@@ -193,31 +195,40 @@ def invoke_post_run(liveaction_db, action_db=None):
from st2common.util import action_db as action_db_utils
from st2common.content import utils as content_utils
- LOG.info('Invoking post run for action execution %s.', liveaction_db.id)
+ LOG.info("Invoking post run for action execution %s.", liveaction_db.id)
# Identify action and runner.
if not action_db:
action_db = action_db_utils.get_action_by_ref(liveaction_db.action)
if not action_db:
- LOG.error('Unable to invoke post run. Action %s no longer exists.', liveaction_db.action)
+ LOG.error(
+ "Unable to invoke post run. Action %s no longer exists.",
+ liveaction_db.action,
+ )
return
- LOG.info('Action execution %s runs %s of runner type %s.',
- liveaction_db.id, action_db.name, action_db.runner_type['name'])
+ LOG.info(
+ "Action execution %s runs %s of runner type %s.",
+ liveaction_db.id,
+ action_db.name,
+ action_db.runner_type["name"],
+ )
# Get instance of the action runner and related configuration.
- runner_type_db = action_db_utils.get_runnertype_by_name(action_db.runner_type['name'])
+ runner_type_db = action_db_utils.get_runnertype_by_name(
+ action_db.runner_type["name"]
+ )
runner = runners.get_runner(name=runner_type_db.name)
entry_point = content_utils.get_entry_point_abs_path(
- pack=action_db.pack,
- entry_point=action_db.entry_point)
+ pack=action_db.pack, entry_point=action_db.entry_point
+ )
libs_dir_path = content_utils.get_action_libs_abs_path(
- pack=action_db.pack,
- entry_point=action_db.entry_point)
+ pack=action_db.pack, entry_point=action_db.entry_point
+ )
# Configure the action runner.
runner.runner_type_db = runner_type_db
@@ -226,8 +237,8 @@ def invoke_post_run(liveaction_db, action_db=None):
runner.liveaction = liveaction_db
runner.liveaction_id = str(liveaction_db.id)
runner.entry_point = entry_point
- runner.context = getattr(liveaction_db, 'context', dict())
- runner.callback = getattr(liveaction_db, 'callback', dict())
+ runner.context = getattr(liveaction_db, "context", dict())
+ runner.callback = getattr(liveaction_db, "callback", dict())
runner.libs_dir_path = libs_dir_path
# Invoke the post_run method.
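The `get_action_class_instance` hunk above reformats a try/except fallback for action classes that predate the `action_service` constructor argument. A small standalone sketch of that pattern (the class below is hypothetical):

```python
# Hypothetical old-style action class without an action_service parameter.
class OldStyleAction(object):
    def __init__(self, config=None):
        self.config = config

kwargs = {"config": {}, "action_service": object()}
try:
    instance = OldStyleAction(**kwargs)
except TypeError:
    # Same fallback as above: drop the kwarg and assign the attribute late.
    action_service = kwargs.pop("action_service", None)
    instance = OldStyleAction(**kwargs)
    instance.action_service = action_service
```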
diff --git a/st2common/st2common/script_setup.py b/st2common/st2common/script_setup.py
index 03be4b4427..0abb7e8269 100644
--- a/st2common/st2common/script_setup.py
+++ b/st2common/st2common/script_setup.py
@@ -32,13 +32,7 @@
from st2common.logging.filters import LogLevelFilter
from st2common.transport.bootstrap_utils import register_exchanges_with_retry
-__all__ = [
- 'setup',
- 'teardown',
-
- 'db_setup',
- 'db_teardown'
-]
+__all__ = ["setup", "teardown", "db_setup", "db_teardown"]
LOG = logging.getLogger(__name__)
@@ -47,11 +41,15 @@ def register_common_cli_options():
"""
Register common CLI options.
"""
- cfg.CONF.register_cli_opt(cfg.BoolOpt('verbose', short='v', default=False))
+ cfg.CONF.register_cli_opt(cfg.BoolOpt("verbose", short="v", default=False))
-def setup(config, setup_db=True, register_mq_exchanges=True,
- register_internal_trigger_types=False):
+def setup(
+ config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_internal_trigger_types=False,
+):
"""
Common setup function.
@@ -76,7 +74,9 @@ def setup(config, setup_db=True, register_mq_exchanges=True,
# Set up logging
log_level = stdlib_logging.DEBUG
- stdlib_logging.basicConfig(format='%(asctime)s %(levelname)s [-] %(message)s', level=log_level)
+ stdlib_logging.basicConfig(
+ format="%(asctime)s %(levelname)s [-] %(message)s", level=log_level
+ )
if not cfg.CONF.verbose:
# Note: We still want to print things at the following log levels: INFO, ERROR, CRITICAL
diff --git a/st2common/st2common/service_setup.py b/st2common/st2common/service_setup.py
index 14cd708cca..bd01d205e7 100644
--- a/st2common/st2common/service_setup.py
+++ b/st2common/st2common/service_setup.py
@@ -53,22 +53,29 @@
__all__ = [
- 'setup',
- 'teardown',
-
- 'db_setup',
- 'db_teardown',
-
- 'register_service_in_service_registry'
+ "setup",
+ "teardown",
+ "db_setup",
+ "db_teardown",
+ "register_service_in_service_registry",
]
LOG = logging.getLogger(__name__)
-def setup(service, config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True, register_internal_trigger_types=False,
- run_migrations=True, register_runners=True, service_registry=False,
- capabilities=None, config_args=None):
+def setup(
+ service,
+ config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ register_internal_trigger_types=False,
+ run_migrations=True,
+ register_runners=True,
+ service_registry=False,
+ capabilities=None,
+ config_args=None,
+):
"""
Common setup function.
@@ -99,29 +106,38 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True,
else:
config.parse_args()
- version = '%s.%s.%s' % (sys.version_info[0], sys.version_info[1], sys.version_info[2])
- LOG.debug('Using Python: %s (%s)' % (version, sys.executable))
+ version = "%s.%s.%s" % (
+ sys.version_info[0],
+ sys.version_info[1],
+ sys.version_info[2],
+ )
+ LOG.debug("Using Python: %s (%s)" % (version, sys.executable))
config_file_paths = cfg.CONF.config_file
config_file_paths = [os.path.abspath(path) for path in config_file_paths]
- LOG.debug('Using config files: %s', ','.join(config_file_paths))
+ LOG.debug("Using config files: %s", ",".join(config_file_paths))
# Setup logging.
logging_config_path = config.get_logging_config_path()
logging_config_path = os.path.abspath(logging_config_path)
- LOG.debug('Using logging config: %s', logging_config_path)
+ LOG.debug("Using logging config: %s", logging_config_path)
- is_debug_enabled = (cfg.CONF.debug or cfg.CONF.system.debug)
+ is_debug_enabled = cfg.CONF.debug or cfg.CONF.system.debug
try:
- logging.setup(logging_config_path, redirect_stderr=cfg.CONF.log.redirect_stderr,
- excludes=cfg.CONF.log.excludes)
+ logging.setup(
+ logging_config_path,
+ redirect_stderr=cfg.CONF.log.redirect_stderr,
+ excludes=cfg.CONF.log.excludes,
+ )
except KeyError as e:
tb_msg = traceback.format_exc()
- if 'log.setLevel' in tb_msg:
- msg = 'Invalid log level selected. Log level names need to be all uppercase.'
- msg += '\n\n' + getattr(e, 'message', six.text_type(e))
+ if "log.setLevel" in tb_msg:
+ msg = (
+ "Invalid log level selected. Log level names need to be all uppercase."
+ )
+ msg += "\n\n" + getattr(e, "message", six.text_type(e))
raise KeyError(msg)
else:
raise e
@@ -134,10 +150,14 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True,
# duplicate "AUDIT" messages in production deployments where default service log level is
# set to "INFO" and we already log messages with level AUDIT to a special dedicated log
# file.
- ignore_audit_log_messages = (handler.level >= stdlib_logging.INFO and
- handler.level < stdlib_logging.AUDIT)
+ ignore_audit_log_messages = (
+ handler.level >= stdlib_logging.INFO
+ and handler.level < stdlib_logging.AUDIT
+ )
if not is_debug_enabled and ignore_audit_log_messages:
- LOG.debug('Excluding log messages with level "AUDIT" for handler "%s"' % (handler))
+ LOG.debug(
+ 'Excluding log messages with level "AUDIT" for handler "%s"' % (handler)
+ )
handler.addFilter(LogLevelFilter(log_levels=exclude_log_levels))
if not is_debug_enabled:
@@ -184,8 +204,9 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True,
# Register service in the service registry
if cfg.CONF.coordination.service_registry and service_registry:
# NOTE: It's important that we pass start_heart=True to start the heartbeat process
- register_service_in_service_registry(service=service, capabilities=capabilities,
- start_heart=True)
+ register_service_in_service_registry(
+ service=service, capabilities=capabilities, start_heart=True
+ )
if sys.version_info[0] == 2:
LOG.warning(PYTHON2_DEPRECATION)
@@ -220,7 +241,7 @@ def register_service_in_service_registry(service, capabilities=None, start_heart
# 1. Create a group with the name of the service
if not isinstance(service, six.binary_type):
- group_id = service.encode('utf-8')
+ group_id = service.encode("utf-8")
else:
group_id = service
@@ -231,10 +252,12 @@ def register_service_in_service_registry(service, capabilities=None, start_heart
# Include common capabilities such as hostname and process ID
proc_info = system_info.get_process_info()
- capabilities['hostname'] = proc_info['hostname']
- capabilities['pid'] = proc_info['pid']
+ capabilities["hostname"] = proc_info["hostname"]
+ capabilities["pid"] = proc_info["pid"]
# 1. Join the group as a member
- LOG.debug('Joining service registry group "%s" as member_id "%s" with capabilities "%s"' %
- (group_id, member_id, capabilities))
+ LOG.debug(
+ 'Joining service registry group "%s" as member_id "%s" with capabilities "%s"'
+ % (group_id, member_id, capabilities)
+ )
return coordinator.join_group(group_id, capabilities=capabilities).get()
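For context on the AUDIT filtering reformatted above, here is a stdlib-only sketch of the handler-level condition; the numeric AUDIT value is a stand-in, since the real constant is registered by st2's logging module:

```python
import logging as stdlib_logging

AUDIT = 25  # stand-in value; st2 registers its own AUDIT log level

handler = stdlib_logging.StreamHandler()
handler.setLevel(stdlib_logging.INFO)

# Mirrors the condition above: suppress AUDIT records on handlers whose
# level falls in the [INFO, AUDIT) range, so they do not duplicate the
# dedicated audit log file.
ignore_audit_log_messages = stdlib_logging.INFO <= handler.level < AUDIT
```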
diff --git a/st2common/st2common/services/access.py b/st2common/st2common/services/access.py
index 72f7f192bb..9d88c39c42 100644
--- a/st2common/st2common/services/access.py
+++ b/st2common/st2common/services/access.py
@@ -27,15 +27,14 @@
from st2common.persistence.auth import Token, User
from st2common import log as logging
-__all__ = [
- 'create_token',
- 'delete_token'
-]
+__all__ = ["create_token", "delete_token"]
LOG = logging.getLogger(__name__)
-def create_token(username, ttl=None, metadata=None, add_missing_user=True, service=False):
+def create_token(
+ username, ttl=None, metadata=None, add_missing_user=True, service=False
+):
"""
:param username: Username of the user to create the token for. If the account for this user
doesn't exist yet it will be created.
@@ -57,8 +56,10 @@ def create_token(username, ttl=None, metadata=None, add_missing_user=True, servi
if ttl:
# Note: We allow arbitrary large TTLs for service tokens.
if not service and ttl > cfg.CONF.auth.token_ttl:
- msg = ('TTL specified %s is greater than max allowed %s.' % (ttl,
- cfg.CONF.auth.token_ttl))
+ msg = "TTL specified %s is greater than max allowed %s." % (
+ ttl,
+ cfg.CONF.auth.token_ttl,
+ )
raise TTLTooLargeException(msg)
else:
ttl = cfg.CONF.auth.token_ttl
@@ -71,22 +72,27 @@ def create_token(username, ttl=None, metadata=None, add_missing_user=True, servi
user_db = UserDB(name=username)
User.add_or_update(user_db)
- extra = {'username': username, 'user': user_db}
+ extra = {"username": username, "user": user_db}
LOG.audit('Registered new user "%s".' % (username), extra=extra)
else:
raise UserNotFoundError()
token = uuid.uuid4().hex
expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl)
- token = TokenDB(user=username, token=token, expiry=expiry, metadata=metadata, service=service)
+ token = TokenDB(
+ user=username, token=token, expiry=expiry, metadata=metadata, service=service
+ )
Token.add_or_update(token)
- username_string = username if username else 'an anonymous user'
+ username_string = username if username else "an anonymous user"
token_expire_string = isotime.format(expiry, offset=False)
- extra = {'username': username, 'token_expiration': token_expire_string}
+ extra = {"username": username, "token_expiration": token_expire_string}
- LOG.audit('Access granted to "%s" with the token set to expire at "%s".' %
- (username_string, token_expire_string), extra=extra)
+ LOG.audit(
+ 'Access granted to "%s" with the token set to expire at "%s".'
+ % (username_string, token_expire_string),
+ extra=extra,
+ )
return token
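The `create_token` hunks above only re-wrap the TTL guard. A simplified sketch of that guard, where ValueError stands in for st2's TTLTooLargeException and the maximum is a placeholder:

```python
MAX_TTL = 86400  # placeholder for cfg.CONF.auth.token_ttl

def check_ttl(ttl, service):
    # Service tokens may use arbitrarily large TTLs; user tokens may not.
    if ttl and not service and ttl > MAX_TTL:
        raise ValueError(
            "TTL specified %s is greater than max allowed %s." % (ttl, MAX_TTL)
        )

check_ttl(3600, service=False)       # accepted
check_ttl(10 * 86400, service=True)  # accepted: service tokens are exempt
```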
diff --git a/st2common/st2common/services/action.py b/st2common/st2common/services/action.py
index c7e7495d69..46e44800cc 100644
--- a/st2common/st2common/services/action.py
+++ b/st2common/st2common/services/action.py
@@ -34,15 +34,13 @@
__all__ = [
- 'request',
- 'create_request',
- 'publish_request',
- 'is_action_canceled_or_canceling',
-
- 'request_pause',
- 'request_resume',
-
- 'store_execution_output_data',
+ "request",
+ "create_request",
+ "publish_request",
+ "is_action_canceled_or_canceling",
+ "request_pause",
+ "request_resume",
+ "store_execution_output_data",
]
LOG = logging.getLogger(__name__)
@@ -51,7 +49,7 @@
def _get_immutable_params(parameters):
if not parameters:
return []
- return [k for k, v in six.iteritems(parameters) if v.get('immutable', False)]
+ return [k for k, v in six.iteritems(parameters) if v.get("immutable", False)]
def create_request(liveaction, action_db=None, runnertype_db=None):
@@ -77,10 +75,10 @@ def create_request(liveaction, action_db=None, runnertype_db=None):
# action can be invoked by a system user and so we want to use the user context
# from the original workflow action.
parent_context = executions.get_parent_context(liveaction) or {}
- parent_user = parent_context.get('user', None)
+ parent_user = parent_context.get("user", None)
if parent_user:
- liveaction.context['user'] = parent_user
+ liveaction.context["user"] = parent_user
# Validate action
if not action_db:
@@ -89,31 +87,44 @@ def create_request(liveaction, action_db=None, runnertype_db=None):
if not action_db:
raise ValueError('Action "%s" cannot be found.' % liveaction.action)
if not action_db.enabled:
- raise ValueError('Unable to execute. Action "%s" is disabled.' % liveaction.action)
+ raise ValueError(
+ 'Unable to execute. Action "%s" is disabled.' % liveaction.action
+ )
if not runnertype_db:
- runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name'])
+ runnertype_db = action_utils.get_runnertype_by_name(
+ action_db.runner_type["name"]
+ )
- if not hasattr(liveaction, 'parameters'):
+ if not hasattr(liveaction, "parameters"):
liveaction.parameters = dict()
# For consistency add pack to the context here in addition to RunnerContainer.dispatch() method
- liveaction.context['pack'] = action_db.pack
+ liveaction.context["pack"] = action_db.pack
# Validate action parameters.
schema = util_schema.get_schema_for_action_parameters(action_db, runnertype_db)
validator = util_schema.get_validator()
- util_schema.validate(liveaction.parameters, schema, validator, use_default=True,
- allow_default_none=True)
+ util_schema.validate(
+ liveaction.parameters,
+ schema,
+ validator,
+ use_default=True,
+ allow_default_none=True,
+ )
# Validate that no immutable params are being overridden. Although it is possible to
# ignore the override, it is safer to inform the user to avoid surprises.
immutables = _get_immutable_params(action_db.parameters)
immutables.extend(_get_immutable_params(runnertype_db.runner_parameters))
- overridden_immutables = [p for p in six.iterkeys(liveaction.parameters) if p in immutables]
+ overridden_immutables = [
+ p for p in six.iterkeys(liveaction.parameters) if p in immutables
+ ]
if len(overridden_immutables) > 0:
- raise ValueError('Override of immutable parameter(s) %s is unsupported.'
- % str(overridden_immutables))
+ raise ValueError(
+ "Override of immutable parameter(s) %s is unsupported."
+ % str(overridden_immutables)
+ )
# Set notification settings for action.
# XXX: There are cases when we don't want notifications to be sent for a particular
@@ -140,17 +151,24 @@ def create_request(liveaction, action_db=None, runnertype_db=None):
_cleanup_liveaction(liveaction)
raise trace_exc.TraceNotFoundException(six.text_type(e))
- execution = executions.create_execution_object(liveaction=liveaction, action_db=action_db,
- runnertype_db=runnertype_db, publish=False)
+ execution = executions.create_execution_object(
+ liveaction=liveaction,
+ action_db=action_db,
+ runnertype_db=runnertype_db,
+ publish=False,
+ )
if trace_db:
trace_service.add_or_update_given_trace_db(
trace_db=trace_db,
action_executions=[
- trace_service.get_trace_component_for_action_execution(execution, liveaction)
- ])
+ trace_service.get_trace_component_for_action_execution(
+ execution, liveaction
+ )
+ ],
+ )
- get_driver().inc_counter('action.executions.%s' % (liveaction.status))
+ get_driver().inc_counter("action.executions.%s" % (liveaction.status))
return liveaction, execution
@@ -170,8 +188,11 @@ def publish_request(liveaction, execution):
# TODO: This results in two queries, optimize it
# extra = {'liveaction_db': liveaction, 'execution_db': execution}
extra = {}
- LOG.audit('Action execution requested. LiveAction.id=%s, ActionExecution.id=%s' %
- (liveaction.id, execution.id), extra=extra)
+ LOG.audit(
+ "Action execution requested. LiveAction.id=%s, ActionExecution.id=%s"
+ % (liveaction.id, execution.id),
+ extra=extra,
+ )
return liveaction, execution
@@ -190,33 +211,34 @@ def update_status(liveaction, new_status, result=None, publish=True):
old_status = liveaction.status
updates = {
- 'liveaction_id': liveaction.id,
- 'status': new_status,
- 'result': result,
- 'publish': False
+ "liveaction_id": liveaction.id,
+ "status": new_status,
+ "result": result,
+ "publish": False,
}
if new_status in action_constants.LIVEACTION_COMPLETED_STATES:
- updates['end_timestamp'] = date_utils.get_datetime_utc_now()
+ updates["end_timestamp"] = date_utils.get_datetime_utc_now()
liveaction = action_utils.update_liveaction_status(**updates)
action_execution = executions.update_execution(liveaction)
- msg = ('The status of action execution is changed from %s to %s. '
- '<LiveAction.id=%s, ActionExecution.id=%s>' % (old_status,
- new_status, liveaction.id, action_execution.id))
+ msg = (
+ "The status of action execution is changed from %s to %s. "
+ "<LiveAction.id=%s, ActionExecution.id=%s>"
+ % (old_status, new_status, liveaction.id, action_execution.id)
+ )
- extra = {
- 'action_execution_db': action_execution,
- 'liveaction_db': liveaction
- }
+ extra = {"action_execution_db": action_execution, "liveaction_db": liveaction}
LOG.audit(msg, extra=extra)
LOG.info(msg)
# Invoke post run if liveaction status is completed or paused.
- if (new_status in action_constants.LIVEACTION_COMPLETED_STATES or
- new_status == action_constants.LIVEACTION_STATUS_PAUSED):
+ if (
+ new_status in action_constants.LIVEACTION_COMPLETED_STATES
+ or new_status == action_constants.LIVEACTION_STATUS_PAUSED
+ ):
runners_utils.invoke_post_run(liveaction)
if publish:
@@ -227,14 +249,18 @@ def update_status(liveaction, new_status, result=None, publish=True):
def is_action_canceled_or_canceling(liveaction_id):
liveaction_db = action_utils.get_liveaction_by_id(liveaction_id)
- return liveaction_db.status in [action_constants.LIVEACTION_STATUS_CANCELED,
- action_constants.LIVEACTION_STATUS_CANCELING]
+ return liveaction_db.status in [
+ action_constants.LIVEACTION_STATUS_CANCELED,
+ action_constants.LIVEACTION_STATUS_CANCELING,
+ ]
def is_action_paused_or_pausing(liveaction_id):
liveaction_db = action_utils.get_liveaction_by_id(liveaction_id)
- return liveaction_db.status in [action_constants.LIVEACTION_STATUS_PAUSED,
- action_constants.LIVEACTION_STATUS_PAUSING]
+ return liveaction_db.status in [
+ action_constants.LIVEACTION_STATUS_PAUSED,
+ action_constants.LIVEACTION_STATUS_PAUSING,
+ ]
def request_cancellation(liveaction, requester):
@@ -250,18 +276,17 @@ def request_cancellation(liveaction, requester):
if liveaction.status not in action_constants.LIVEACTION_CANCELABLE_STATES:
raise Exception(
'Unable to cancel liveaction "%s" because it is already in a '
- 'completed state.' % liveaction.id
+ "completed state." % liveaction.id
)
- result = {
- 'message': 'Action canceled by user.',
- 'user': requester
- }
+ result = {"message": "Action canceled by user.", "user": requester}
# Run the cancellation sequence for a liveaction that is in a running state or
# if the liveaction is operating under a workflow.
- if ('parent' in liveaction.context or
- liveaction.status in action_constants.LIVEACTION_STATUS_RUNNING):
+ if (
+ "parent" in liveaction.context
+ or liveaction.status in action_constants.LIVEACTION_STATUS_RUNNING
+ ):
status = action_constants.LIVEACTION_STATUS_CANCELING
else:
status = action_constants.LIVEACTION_STATUS_CANCELED
@@ -286,17 +311,19 @@ def request_pause(liveaction, requester):
if not action_db:
raise ValueError(
'Unable to pause liveaction "%s" because the action "%s" '
- 'is not found.' % (liveaction.id, liveaction.action)
+ "is not found." % (liveaction.id, liveaction.action)
)
- if action_db.runner_type['name'] not in action_constants.WORKFLOW_RUNNER_TYPES:
+ if action_db.runner_type["name"] not in action_constants.WORKFLOW_RUNNER_TYPES:
raise runner_exc.InvalidActionRunnerOperationError(
'Unable to pause liveaction "%s" because it is not supported by the '
- '"%s" runner.' % (liveaction.id, action_db.runner_type['name'])
+ '"%s" runner.' % (liveaction.id, action_db.runner_type["name"])
)
- if (liveaction.status == action_constants.LIVEACTION_STATUS_PAUSING or
- liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED):
+ if (
+ liveaction.status == action_constants.LIVEACTION_STATUS_PAUSING
+ or liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED
+ ):
execution = ActionExecution.get(liveaction__id=str(liveaction.id))
return (liveaction, execution)
@@ -326,18 +353,18 @@ def request_resume(liveaction, requester):
if not action_db:
raise ValueError(
'Unable to resume liveaction "%s" because the action "%s" '
- 'is not found.' % (liveaction.id, liveaction.action)
+ "is not found." % (liveaction.id, liveaction.action)
)
- if action_db.runner_type['name'] not in action_constants.WORKFLOW_RUNNER_TYPES:
+ if action_db.runner_type["name"] not in action_constants.WORKFLOW_RUNNER_TYPES:
raise runner_exc.InvalidActionRunnerOperationError(
'Unable to resume liveaction "%s" because it is not supported by the '
- '"%s" runner.' % (liveaction.id, action_db.runner_type['name'])
+ '"%s" runner.' % (liveaction.id, action_db.runner_type["name"])
)
running_states = [
action_constants.LIVEACTION_STATUS_RUNNING,
- action_constants.LIVEACTION_STATUS_RESUMING
+ action_constants.LIVEACTION_STATUS_RESUMING,
]
if liveaction.status in running_states:
@@ -367,13 +394,13 @@ def get_parent_liveaction(liveaction_db):
:rtype: LiveActionDB
"""
- parent = liveaction_db.context.get('parent')
+ parent = liveaction_db.context.get("parent")
if not parent:
return None
- parent_execution_db = ActionExecution.get(id=parent['execution_id'])
- parent_liveaction_db = LiveAction.get(id=parent_execution_db.liveaction['id'])
+ parent_execution_db = ActionExecution.get(id=parent["execution_id"])
+ parent_liveaction_db = LiveAction.get(id=parent_execution_db.liveaction["id"])
return parent_liveaction_db
@@ -409,7 +436,11 @@ def get_root_liveaction(liveaction_db):
parent_liveaction_db = get_parent_liveaction(liveaction_db)
- return get_root_liveaction(parent_liveaction_db) if parent_liveaction_db else liveaction_db
+ return (
+ get_root_liveaction(parent_liveaction_db)
+ if parent_liveaction_db
+ else liveaction_db
+ )
def get_root_execution(execution_db):
@@ -425,36 +456,48 @@ def get_root_execution(execution_db):
parent_execution_db = get_parent_execution(execution_db)
- return get_root_execution(parent_execution_db) if parent_execution_db else execution_db
+ return (
+ get_root_execution(parent_execution_db) if parent_execution_db else execution_db
+ )
-def store_execution_output_data(execution_db, action_db, data, output_type='output',
- timestamp=None):
+def store_execution_output_data(
+ execution_db, action_db, data, output_type="output", timestamp=None
+):
"""
Store output from an execution as a new document in the collection.
"""
execution_id = str(execution_db.id)
if action_db is None:
- action_ref = execution_db.action.get('ref', 'unknown')
- runner_ref = execution_db.action.get('runner_type', 'unknown')
+ action_ref = execution_db.action.get("ref", "unknown")
+ runner_ref = execution_db.action.get("runner_type", "unknown")
else:
action_ref = action_db.ref
- runner_ref = getattr(action_db, 'runner_type', {}).get('name', 'unknown')
+ runner_ref = getattr(action_db, "runner_type", {}).get("name", "unknown")
return store_execution_output_data_ex(
- execution_id, action_ref, runner_ref, data,
- output_type=output_type, timestamp=timestamp
+ execution_id,
+ action_ref,
+ runner_ref,
+ data,
+ output_type=output_type,
+ timestamp=timestamp,
)
-def store_execution_output_data_ex(execution_id, action_ref, runner_ref, data, output_type='output',
- timestamp=None):
+def store_execution_output_data_ex(
+ execution_id, action_ref, runner_ref, data, output_type="output", timestamp=None
+):
timestamp = timestamp or date_utils.get_datetime_utc_now()
output_db = ActionExecutionOutputDB(
- execution_id=execution_id, action_ref=action_ref, runner_ref=runner_ref,
- timestamp=timestamp, output_type=output_type, data=data
+ execution_id=execution_id,
+ action_ref=action_ref,
+ runner_ref=runner_ref,
+ timestamp=timestamp,
+ output_type=output_type,
+ data=data,
)
output_db = ActionExecutionOutput.add_or_update(
@@ -467,29 +510,29 @@ def store_execution_output_data_ex(execution_id, action_ref, runner_ref, data, o
def is_children_active(liveaction_id):
execution_db = ActionExecution.get(liveaction__id=str(liveaction_id))
- if execution_db.runner['name'] not in action_constants.WORKFLOW_RUNNER_TYPES:
+ if execution_db.runner["name"] not in action_constants.WORKFLOW_RUNNER_TYPES:
return False
children_execution_dbs = ActionExecution.query(parent=str(execution_db.id))
- inactive_statuses = (
- action_constants.LIVEACTION_COMPLETED_STATES +
- [action_constants.LIVEACTION_STATUS_PAUSED, action_constants.LIVEACTION_STATUS_PENDING]
- )
+ inactive_statuses = action_constants.LIVEACTION_COMPLETED_STATES + [
+ action_constants.LIVEACTION_STATUS_PAUSED,
+ action_constants.LIVEACTION_STATUS_PENDING,
+ ]
completed = [
child_exec_db.status in inactive_statuses
for child_exec_db in children_execution_dbs
]
- return (not all(completed))
+ return not all(completed)
def _cleanup_liveaction(liveaction):
try:
LiveAction.delete(liveaction)
except:
- LOG.exception('Failed cleaning up LiveAction: %s.', liveaction)
+ LOG.exception("Failed cleaning up LiveAction: %s.", liveaction)
pass
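The `get_root_liveaction` / `get_root_execution` hunks above reflow a simple parent-chasing recursion. A toy sketch of the same idea, with plain dicts in place of the DB models:

```python
# Toy model: each record may name a parent execution in its context.
def get_root(record, records_by_id):
    parent = record.get("context", {}).get("parent")
    if not parent:
        return record
    return get_root(records_by_id[parent["execution_id"]], records_by_id)

records_by_id = {
    "root": {"id": "root", "context": {}},
    "child": {"id": "child", "context": {"parent": {"execution_id": "root"}}},
}
assert get_root(records_by_id["child"], records_by_id)["id"] == "root"
```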
diff --git a/st2common/st2common/services/config.py b/st2common/st2common/services/config.py
index f23d91ee9c..bef8f483dd 100644
--- a/st2common/st2common/services/config.py
+++ b/st2common/st2common/services/config.py
@@ -28,13 +28,15 @@
from st2common.exceptions.db import StackStormDBObjectNotFoundError
__all__ = [
- 'set_datastore_value_for_config_key',
+ "set_datastore_value_for_config_key",
]
LOG = logging.getLogger(__name__)
-def set_datastore_value_for_config_key(pack_name, key_name, value, secret=False, user=None):
+def set_datastore_value_for_config_key(
+ pack_name, key_name, value, secret=False, user=None
+):
"""
Set config value in the datastore.
diff --git a/st2common/st2common/services/coordination.py b/st2common/st2common/services/coordination.py
index 42556ea084..1a068632ab 100644
--- a/st2common/st2common/services/coordination.py
+++ b/st2common/st2common/services/coordination.py
@@ -31,19 +31,17 @@
COORDINATOR = None
__all__ = [
- 'configured',
-
- 'get_coordinator',
- 'get_coordinator_if_set',
- 'get_member_id',
-
- 'coordinator_setup',
- 'coordinator_teardown'
+ "configured",
+ "get_coordinator",
+ "get_coordinator_if_set",
+ "get_member_id",
+ "coordinator_setup",
+ "coordinator_teardown",
]
class NoOpLock(locking.Lock):
- def __init__(self, name='noop'):
+ def __init__(self, name="noop"):
super(NoOpLock, self).__init__(name=name)
def acquire(self, blocking=True):
@@ -61,6 +59,7 @@ class NoOpAsyncResult(object):
In most scenarios, the tooz library returns an async result (a future), and this
wrapper class is here to correctly mimic the tooz API and behavior.
"""
+
def __init__(self, result=None):
self._result = result
@@ -108,7 +107,7 @@ def stand_down_group_leader(group_id):
@classmethod
def create_group(cls, group_id):
- cls.groups[group_id] = {'members': {}}
+ cls.groups[group_id] = {"members": {}}
return NoOpAsyncResult()
@classmethod
@@ -116,17 +115,17 @@ def get_groups(cls):
return NoOpAsyncResult(result=cls.groups.keys())
@classmethod
- def join_group(cls, group_id, capabilities=''):
+ def join_group(cls, group_id, capabilities=""):
member_id = get_member_id()
- cls.groups[group_id]['members'][member_id] = {'capabilities': capabilities}
+ cls.groups[group_id]["members"][member_id] = {"capabilities": capabilities}
return NoOpAsyncResult()
@classmethod
def leave_group(cls, group_id):
member_id = get_member_id()
- del cls.groups[group_id]['members'][member_id]
+ del cls.groups[group_id]["members"][member_id]
return NoOpAsyncResult()
@classmethod
@@ -137,15 +136,15 @@ def delete_group(cls, group_id):
@classmethod
def get_members(cls, group_id):
try:
- member_ids = cls.groups[group_id]['members'].keys()
+ member_ids = cls.groups[group_id]["members"].keys()
except KeyError:
- raise GroupNotCreated('Group doesnt exist')
+ raise GroupNotCreated("Group doesnt exist")
return NoOpAsyncResult(result=member_ids)
@classmethod
def get_member_capabilities(cls, group_id, member_id):
- member_capabilities = cls.groups[group_id]['members'][member_id]['capabilities']
+ member_capabilities = cls.groups[group_id]["members"][member_id]["capabilities"]
return NoOpAsyncResult(result=member_capabilities)
@staticmethod
@@ -158,7 +157,7 @@ def get_leader(group_id):
@staticmethod
def get_lock(name):
- return NoOpLock(name='noop')
+ return NoOpLock(name="noop")
def configured():
@@ -168,8 +167,10 @@ def configured():
:rtype: ``bool``
"""
backend_configured = cfg.CONF.coordination.url is not None
- mock_backend = backend_configured and (cfg.CONF.coordination.url.startswith('zake') or
- cfg.CONF.coordination.url.startswith('file'))
+ mock_backend = backend_configured and (
+ cfg.CONF.coordination.url.startswith("zake")
+ or cfg.CONF.coordination.url.startswith("file")
+ )
return backend_configured and not mock_backend
@@ -189,7 +190,9 @@ def coordinator_setup(start_heart=True):
member_id = get_member_id()
if url:
- coordinator = coordination.get_coordinator(url, member_id, lock_timeout=lock_timeout)
+ coordinator = coordination.get_coordinator(
+ url, member_id, lock_timeout=lock_timeout
+ )
else:
# Use a no-op backend
# Note: We don't use tooz to obtain a reference since for this to work we would need to
@@ -217,17 +220,21 @@ def get_coordinator(start_heart=True, use_cache=True):
global COORDINATOR
if not configured():
- LOG.warn('Coordination backend is not configured. Code paths which use coordination '
- 'service will use best effort approach and race conditions are possible.')
+ LOG.warn(
+ "Coordination backend is not configured. Code paths which use coordination "
+ "service will use best effort approach and race conditions are possible."
+ )
if not use_cache:
return coordinator_setup(start_heart=start_heart)
if not COORDINATOR:
COORDINATOR = coordinator_setup(start_heart=start_heart)
- LOG.debug('Initializing and caching new coordinator instance: %s' % (str(COORDINATOR)))
+ LOG.debug(
+ "Initializing and caching new coordinator instance: %s" % (str(COORDINATOR))
+ )
else:
- LOG.debug('Using cached coordinator instance: %s' % (str(COORDINATOR)))
+ LOG.debug("Using cached coordinator instance: %s" % (str(COORDINATOR)))
return COORDINATOR
@@ -247,5 +254,5 @@ def get_member_id():
:rtype: ``bytes``
"""
proc_info = system_info.get_process_info()
- member_id = six.b('%s_%d' % (proc_info['hostname'], proc_info['pid']))
+ member_id = six.b("%s_%d" % (proc_info["hostname"], proc_info["pid"]))
return member_id
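The member id built by `get_member_id` above is just the hostname and pid joined with an underscore, as bytes. An approximate stdlib-only equivalent (the real code reads both values via `system_info.get_process_info()`):

```python
import os
import socket

# Approximation of get_member_id(): "<hostname>_<pid>" encoded as bytes.
member_id = ("%s_%d" % (socket.gethostname(), os.getpid())).encode("utf-8")
```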
diff --git a/st2common/st2common/services/datastore.py b/st2common/st2common/services/datastore.py
index 986ffd0d03..9655499e49 100644
--- a/st2common/st2common/services/datastore.py
+++ b/st2common/st2common/services/datastore.py
@@ -24,11 +24,7 @@
from st2common.util.date import get_datetime_utc_now
from st2common.constants.keyvalue import DATASTORE_KEY_SEPARATOR, SYSTEM_SCOPE
-__all__ = [
- 'BaseDatastoreService',
- 'ActionDatastoreService',
- 'SensorDatastoreService'
-]
+__all__ = ["BaseDatastoreService", "ActionDatastoreService", "SensorDatastoreService"]
class BaseDatastoreService(object):
@@ -63,7 +59,7 @@ def get_user_info(self):
"""
client = self.get_api_client()
- self._logger.debug('Retrieving user information')
+ self._logger.debug("Retrieving user information")
result = client.get_user_info()
return result
@@ -85,7 +81,7 @@ def list_values(self, local=True, prefix=None):
:rtype: ``list`` of :class:`KeyValuePair`
"""
client = self.get_api_client()
- self._logger.debug('Retrieving all the values from the datastore')
+ self._logger.debug("Retrieving all the values from the datastore")
key_prefix = self._get_full_key_prefix(local=local, prefix=prefix)
kvps = client.keys.get_all(prefix=key_prefix)
@@ -113,21 +109,19 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False):
:rtype: ``str`` or ``None``
"""
if scope != SYSTEM_SCOPE:
- raise ValueError('Scope %s is unsupported.' % scope)
+ raise ValueError("Scope %s is unsupported." % scope)
name = self._get_full_key_name(name=name, local=local)
client = self.get_api_client()
- self._logger.debug('Retrieving value from the datastore (name=%s)', name)
+ self._logger.debug("Retrieving value from the datastore (name=%s)", name)
try:
- params = {'decrypt': str(decrypt).lower(), 'scope': scope}
+ params = {"decrypt": str(decrypt).lower(), "scope": scope}
kvp = client.keys.get_by_id(id=name, params=params)
except Exception as e:
self._logger.exception(
- 'Exception retrieving value from datastore (name=%s): %s',
- name,
- e
+ "Exception retrieving value from datastore (name=%s): %s", name, e
)
return None
@@ -136,7 +130,9 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False):
return None
- def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False):
+ def set_value(
+ self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False
+ ):
"""
Set a value for the provided key.
@@ -165,14 +161,14 @@ def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encry
:rtype: ``bool``
"""
if scope != SYSTEM_SCOPE:
- raise ValueError('Scope %s is unsupported.' % scope)
+ raise ValueError("Scope %s is unsupported." % scope)
name = self._get_full_key_name(name=name, local=local)
value = str(value)
client = self.get_api_client()
- self._logger.debug('Setting value in the datastore (name=%s)', name)
+ self._logger.debug("Setting value in the datastore (name=%s)", name)
instance = KeyValuePair()
instance.id = name
@@ -208,7 +204,7 @@ def delete_value(self, name, local=True, scope=SYSTEM_SCOPE):
:rtype: ``bool``
"""
if scope != SYSTEM_SCOPE:
- raise ValueError('Scope %s is unsupported.' % scope)
+ raise ValueError("Scope %s is unsupported." % scope)
name = self._get_full_key_name(name=name, local=local)
@@ -218,16 +214,14 @@ def delete_value(self, name, local=True, scope=SYSTEM_SCOPE):
instance.id = name
instance.name = name
- self._logger.debug('Deleting value from the datastore (name=%s)', name)
+ self._logger.debug("Deleting value from the datastore (name=%s)", name)
try:
- params = {'scope': scope}
+ params = {"scope": scope}
client.keys.delete(instance=instance, params=params)
except Exception as e:
self._logger.exception(
- 'Exception deleting value from datastore (name=%s): %s',
- name,
- e
+ "Exception deleting value from datastore (name=%s): %s", name, e
)
return False
@@ -237,7 +231,7 @@ def get_api_client(self):
"""
Retrieve API client instance.
"""
- raise NotImplementedError('get_api_client() not implemented')
+ raise NotImplementedError("get_api_client() not implemented")
def _get_full_key_name(self, name, local):
"""
@@ -282,7 +276,7 @@ def _get_key_name_with_prefix(self, name):
return full_name
def _get_datastore_key_prefix(self):
- prefix = '%s.%s' % (self._pack_name, self._class_name)
+ prefix = "%s.%s" % (self._pack_name, self._class_name)
return prefix
@@ -299,8 +293,9 @@ def __init__(self, logger, pack_name, class_name, auth_token):
:param auth_token: Auth token used to authenticate with StackStorm API.
:type auth_token: ``str``
"""
- super(ActionDatastoreService, self).__init__(logger=logger, pack_name=pack_name,
- class_name=class_name)
+ super(ActionDatastoreService, self).__init__(
+ logger=logger, pack_name=pack_name, class_name=class_name
+ )
self._auth_token = auth_token
self._client = None
@@ -310,7 +305,7 @@ def get_api_client(self):
Retrieve API client instance.
"""
if not self._client:
- self._logger.debug('Creating new Client object.')
+ self._logger.debug("Creating new Client object.")
api_url = get_full_public_api_url()
client = Client(api_url=api_url, token=self._auth_token)
@@ -330,8 +325,9 @@ class SensorDatastoreService(BaseDatastoreService):
"""
def __init__(self, logger, pack_name, class_name, api_username):
- super(SensorDatastoreService, self).__init__(logger=logger, pack_name=pack_name,
- class_name=class_name)
+ super(SensorDatastoreService, self).__init__(
+ logger=logger, pack_name=pack_name, class_name=class_name
+ )
self._api_username = api_username
self._token_expire = get_datetime_utc_now()
@@ -344,12 +340,15 @@ def get_api_client(self):
if not self._client or token_expire:
# Note: Late import to avoid high import cost (time wise)
from st2common.services.access import create_token
- self._logger.debug('Creating new Client object.')
+
+ self._logger.debug("Creating new Client object.")
ttl = cfg.CONF.auth.service_token_ttl
api_url = get_full_public_api_url()
- temporary_token = create_token(username=self._api_username, ttl=ttl, service=True)
+ temporary_token = create_token(
+ username=self._api_username, ttl=ttl, service=True
+ )
self._client = Client(api_url=api_url, token=temporary_token.token)
self._token_expire = get_datetime_utc_now() + timedelta(seconds=ttl)
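Across the datastore hunks above, local key names are namespaced by pack and class. A short sketch of that naming scheme; the `:` separator is an assumption standing in for DATASTORE_KEY_SEPARATOR, and the pack and class names are made up:

```python
pack_name, class_name, key = "my_pack", "MySensor", "counter"

# Per _get_datastore_key_prefix() above, the prefix is "<pack>.<class>".
prefix = "%s.%s" % (pack_name, class_name)

# Joined with the datastore separator (":" assumed here) to form the
# full local key name, e.g. "my_pack.MySensor:counter".
full_name = "%s:%s" % (prefix, key)
```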
diff --git a/st2common/st2common/services/executions.py b/st2common/st2common/services/executions.py
index e259977bdc..51447796b0 100644
--- a/st2common/st2common/services/executions.py
+++ b/st2common/st2common/services/executions.py
@@ -51,13 +51,13 @@
__all__ = [
- 'create_execution_object',
- 'update_execution',
- 'abandon_execution_if_incomplete',
- 'is_execution_canceled',
- 'AscendingSortedDescendantView',
- 'DFSDescendantView',
- 'get_descendants'
+ "create_execution_object",
+ "update_execution",
+ "abandon_execution_if_incomplete",
+ "is_execution_canceled",
+ "AscendingSortedDescendantView",
+ "DFSDescendantView",
+ "get_descendants",
]
LOG = logging.getLogger(__name__)
@@ -66,13 +66,13 @@
# into a ActionExecution compatible dictionary.
# Those attributes are LiveAction specific and are therefore stored in a "liveaction" key
LIVEACTION_ATTRIBUTES = [
- 'id',
- 'callback',
- 'action',
- 'action_is_workflow',
- 'runner_info',
- 'parameters',
- 'notify'
+ "id",
+ "callback",
+ "action",
+ "action_is_workflow",
+ "runner_info",
+ "parameters",
+ "notify",
]
@@ -80,11 +80,11 @@ def _decompose_liveaction(liveaction_db):
"""
Splits the liveaction into an ActionExecution compatible dict.
"""
- decomposed = {'liveaction': {}}
+ decomposed = {"liveaction": {}}
liveaction_api = vars(LiveActionAPI.from_model(liveaction_db))
for k in liveaction_api.keys():
if k in LIVEACTION_ATTRIBUTES:
- decomposed['liveaction'][k] = liveaction_api[k]
+ decomposed["liveaction"][k] = liveaction_api[k]
else:
decomposed[k] = getattr(liveaction_db, k)
return decomposed
@@ -94,49 +94,53 @@ def _create_execution_log_entry(status):
"""
Create execution log entry object for the provided execution status.
"""
- return {
- 'timestamp': date_utils.get_datetime_utc_now(),
- 'status': status
- }
+ return {"timestamp": date_utils.get_datetime_utc_now(), "status": status}
-def create_execution_object(liveaction, action_db=None, runnertype_db=None, publish=True):
+def create_execution_object(
+ liveaction, action_db=None, runnertype_db=None, publish=True
+):
if not action_db:
action_db = action_utils.get_action_by_ref(liveaction.action)
if not runnertype_db:
- runnertype_db = RunnerType.get_by_name(action_db.runner_type['name'])
+ runnertype_db = RunnerType.get_by_name(action_db.runner_type["name"])
attrs = {
- 'action': vars(ActionAPI.from_model(action_db)),
- 'parameters': liveaction['parameters'],
- 'runner': vars(RunnerTypeAPI.from_model(runnertype_db))
+ "action": vars(ActionAPI.from_model(action_db)),
+ "parameters": liveaction["parameters"],
+ "runner": vars(RunnerTypeAPI.from_model(runnertype_db)),
}
attrs.update(_decompose_liveaction(liveaction))
- if 'rule' in liveaction.context:
- rule = reference.get_model_from_ref(Rule, liveaction.context.get('rule', {}))
- attrs['rule'] = vars(RuleAPI.from_model(rule))
+ if "rule" in liveaction.context:
+ rule = reference.get_model_from_ref(Rule, liveaction.context.get("rule", {}))
+ attrs["rule"] = vars(RuleAPI.from_model(rule))
- if 'trigger_instance' in liveaction.context:
- trigger_instance_id = liveaction.context.get('trigger_instance', {})
- trigger_instance_id = trigger_instance_id.get('id', None)
+ if "trigger_instance" in liveaction.context:
+ trigger_instance_id = liveaction.context.get("trigger_instance", {})
+ trigger_instance_id = trigger_instance_id.get("id", None)
trigger_instance = TriggerInstance.get_by_id(trigger_instance_id)
- trigger = reference.get_model_by_resource_ref(db_api=Trigger,
- ref=trigger_instance.trigger)
- trigger_type = reference.get_model_by_resource_ref(db_api=TriggerType,
- ref=trigger.type)
+ trigger = reference.get_model_by_resource_ref(
+ db_api=Trigger, ref=trigger_instance.trigger
+ )
+ trigger_type = reference.get_model_by_resource_ref(
+ db_api=TriggerType, ref=trigger.type
+ )
trigger_instance = reference.get_model_from_ref(
- TriggerInstance, liveaction.context.get('trigger_instance', {}))
- attrs['trigger_instance'] = vars(TriggerInstanceAPI.from_model(trigger_instance))
- attrs['trigger'] = vars(TriggerAPI.from_model(trigger))
- attrs['trigger_type'] = vars(TriggerTypeAPI.from_model(trigger_type))
+ TriggerInstance, liveaction.context.get("trigger_instance", {})
+ )
+ attrs["trigger_instance"] = vars(
+ TriggerInstanceAPI.from_model(trigger_instance)
+ )
+ attrs["trigger"] = vars(TriggerAPI.from_model(trigger))
+ attrs["trigger_type"] = vars(TriggerTypeAPI.from_model(trigger_type))
parent = _get_parent_execution(liveaction)
if parent:
- attrs['parent'] = str(parent.id)
+ attrs["parent"] = str(parent.id)
- attrs['log'] = [_create_execution_log_entry(liveaction['status'])]
+ attrs["log"] = [_create_execution_log_entry(liveaction["status"])]
# TODO: This object initialization takes 20-30 or so ms
execution = ActionExecutionDB(**attrs)
@@ -146,24 +150,30 @@ def create_execution_object(liveaction, action_db=None, runnertype_db=None, publ
# NOTE: User input data is already validate as part of the API request,
# other data is set by us. Skipping validation here makes operation 10%-30% faster
- execution = ActionExecution.add_or_update(execution, publish=publish, validate=False)
+ execution = ActionExecution.add_or_update(
+ execution, publish=publish, validate=False
+ )
if parent and str(execution.id) not in parent.children:
values = {}
- values['push__children'] = str(execution.id)
+ values["push__children"] = str(execution.id)
ActionExecution.update(parent, **values)
return execution
def _get_parent_execution(child_liveaction_db):
- parent_execution_id = child_liveaction_db.context.get('parent', {}).get('execution_id', None)
+ parent_execution_id = child_liveaction_db.context.get("parent", {}).get(
+ "execution_id", None
+ )
if parent_execution_id:
try:
return ActionExecution.get_by_id(parent_execution_id)
except:
- LOG.exception('No valid execution object found in db for id: %s' % parent_execution_id)
+ LOG.exception(
+ "No valid execution object found in db for id: %s" % parent_execution_id
+ )
return None
return None
@@ -180,12 +190,12 @@ def update_execution(liveaction_db, publish=True):
kw = {}
for k, v in six.iteritems(decomposed):
- kw['set__' + k] = v
+ kw["set__" + k] = v
if liveaction_db.status != execution.status:
# Note: If the status changes we store this transition in the "log" attribute of action
# execution
- kw['push__log'] = _create_execution_log_entry(liveaction_db.status)
+ kw["push__log"] = _create_execution_log_entry(liveaction_db.status)
execution = ActionExecution.update(execution, publish=publish, **kw)
return execution
@@ -201,19 +211,25 @@ def abandon_execution_if_incomplete(liveaction_id, publish=True):
# No need to abandon an already complete action
if liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES:
- raise ValueError('LiveAction %s already in a completed state %s.' %
- (liveaction_id, liveaction_db.status))
+ raise ValueError(
+ "LiveAction %s already in a completed state %s."
+ % (liveaction_id, liveaction_db.status)
+ )
# Update status to reflect execution being abandoned.
liveaction_db = action_utils.update_liveaction_status(
status=action_constants.LIVEACTION_STATUS_ABANDONED,
liveaction_db=liveaction_db,
- result={})
+ result={},
+ )
execution_db = update_execution(liveaction_db, publish=publish)
- LOG.info('Marked execution %s as %s.', execution_db.id,
- action_constants.LIVEACTION_STATUS_ABANDONED)
+ LOG.info(
+ "Marked execution %s as %s.",
+ execution_db.id,
+ action_constants.LIVEACTION_STATUS_ABANDONED,
+ )
# Invoke post run on the action to execute post run operations such as callback.
runners_utils.invoke_post_run(liveaction_db)
@@ -236,10 +252,10 @@ def get_parent_context(liveaction_db):
:return: The parent context if found, else None.
:rtype: dict
"""
- context = getattr(liveaction_db, 'context', None)
+ context = getattr(liveaction_db, "context", None)
if not context:
return None
- return context.get('parent', None)
+ return context.get("parent", None)
class AscendingSortedDescendantView(object):
@@ -267,8 +283,8 @@ def result(self):
DESCENDANT_VIEWS = {
- 'sorted': AscendingSortedDescendantView,
- 'default': DFSDescendantView
+ "sorted": AscendingSortedDescendantView,
+ "default": DFSDescendantView,
}
@@ -278,9 +294,10 @@ def get_descendants(actionexecution_id, descendant_depth=-1, result_fmt=None):
the supplied actionexecution_id.
"""
descendants = DESCENDANT_VIEWS.get(result_fmt, DFSDescendantView)()
- children = ActionExecution.query(parent=actionexecution_id,
- **{'order_by': ['start_timestamp']})
- LOG.debug('Found %s children for id %s.', len(children), actionexecution_id)
+ children = ActionExecution.query(
+ parent=actionexecution_id, **{"order_by": ["start_timestamp"]}
+ )
+ LOG.debug("Found %s children for id %s.", len(children), actionexecution_id)
current_level = [(child, 1) for child in children]
while current_level:
@@ -291,8 +308,10 @@ def get_descendants(actionexecution_id, descendant_depth=-1, result_fmt=None):
continue
if level != -1 and level == descendant_depth:
continue
- children = ActionExecution.query(parent=parent_id, **{'order_by': ['start_timestamp']})
- LOG.debug('Found %s children for id %s.', len(children), parent_id)
+ children = ActionExecution.query(
+ parent=parent_id, **{"order_by": ["start_timestamp"]}
+ )
+ LOG.debug("Found %s children for id %s.", len(children), parent_id)
# prepend for DFS
for idx in range(len(children)):
current_level.insert(idx, (children[idx], level + 1))
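
# A minimal standalone sketch of the "prepend for DFS" traversal used by
# get_descendants() above: children are inserted at the front of the work list
# so each subtree is exhausted before moving on to the next sibling, while the
# per-parent ordering (start_timestamp in the real query) is preserved. The
# in-memory CHILDREN map is hypothetical and stands in for the
# ActionExecution.query(parent=...) calls.
CHILDREN = {
    "root": ["a", "b"],
    "a": ["a1", "a2"],
    "a1": [],
    "a2": [],
    "b": [],
}


def descendants_dfs(root_id, descendant_depth=-1):
    result = []
    current_level = [(child, 1) for child in CHILDREN[root_id]]
    while current_level:
        node, level = current_level.pop(0)
        result.append(node)
        if descendant_depth != -1 and level == descendant_depth:
            continue
        children = CHILDREN[node]
        # prepend for DFS, mirroring the loop in get_descendants()
        for idx in range(len(children)):
            current_level.insert(idx, (children[idx], level + 1))
    return result


assert descendants_dfs("root") == ["a", "a1", "a2", "b"]
assert descendants_dfs("root", descendant_depth=1) == ["a", "b"]
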
diff --git a/st2common/st2common/services/inquiry.py b/st2common/st2common/services/inquiry.py
index 5b511b3a97..09be3cc8f1 100644
--- a/st2common/st2common/services/inquiry.py
+++ b/st2common/st2common/services/inquiry.py
@@ -40,9 +40,11 @@
def check_inquiry(inquiry):
- LOG.debug('Checking action execution "%s" to see if is an inquiry.' % str(inquiry.id))
+ LOG.debug(
+ 'Checking action execution "%s" to see if it is an inquiry.' % str(inquiry.id)
+ )
- if inquiry.runner.get('name') != 'inquirer':
+ if inquiry.runner.get("name") != "inquirer":
raise inquiry_exceptions.InvalidInquiryInstance(str(inquiry.id))
LOG.debug('Checking if the inquiry "%s" has timed out.' % str(inquiry.id))
@@ -69,7 +71,7 @@ def check_permission(inquiry, requester):
users_passed = False
# Determine role-level permissions
- roles = getattr(inquiry, 'roles', [])
+ roles = getattr(inquiry, "roles", [])
if not roles:
# No roles definition so we treat it as a pass
@@ -79,14 +81,16 @@ def check_permission(inquiry, requester):
rbac_utils = get_rbac_backend().get_utils_class()
user_has_role = rbac_utils.user_has_role(user_db, role)
- LOG.debug('Checking user %s is in role %s - %s' % (user_db, role, user_has_role))
+ LOG.debug(
+ "Checking user %s is in role %s - %s" % (user_db, role, user_has_role)
+ )
if user_has_role:
roles_passed = True
break
# Determine user-level permissions
- users = getattr(inquiry, 'users', [])
+ users = getattr(inquiry, "users", [])
if not users or user_db.name in users:
users_passed = True
@@ -98,7 +102,7 @@ def check_permission(inquiry, requester):
def validate_response(inquiry, response):
schema = inquiry.schema
- LOG.debug('Validating inquiry response: %s against schema: %s' % (response, schema))
+ LOG.debug("Validating inquiry response: %s against schema: %s" % (response, schema))
try:
schema_utils.validate(
@@ -106,12 +110,14 @@ def validate_response(inquiry, response):
schema=schema,
cls=schema_utils.CustomValidator,
use_default=True,
- allow_default_none=True
+ allow_default_none=True,
)
except Exception as e:
msg = 'Response for inquiry "%s" did not pass schema validation.'
LOG.exception(msg % str(inquiry.id))
- raise inquiry_exceptions.InvalidInquiryResponse(str(inquiry.id), six.text_type(e))
+ raise inquiry_exceptions.InvalidInquiryResponse(
+ str(inquiry.id), six.text_type(e)
+ )
def respond(inquiry, response, requester=None):
@@ -120,14 +126,14 @@ def respond(inquiry, response, requester=None):
requester = cfg.CONF.system_user.user
# Retrieve the liveaction from the database.
- liveaction_db = lv_db_access.LiveAction.get_by_id(inquiry.liveaction.get('id'))
+ liveaction_db = lv_db_access.LiveAction.get_by_id(inquiry.liveaction.get("id"))
# Resume the parent workflow first. If the action execution for the inquiry is updated first,
# it triggers handling of the action execution completion which will interact with the paused
# parent workflow. The resuming logic that is executed here will then race with the completion
# of the inquiry action execution, which will randomly result in the parent workflow being
# stuck in a paused state.
- if liveaction_db.context.get('parent'):
+ if liveaction_db.context.get("parent"):
LOG.debug('Resuming workflow parent(s) for inquiry "%s".' % str(inquiry.id))
# For action execution under Action Chain workflows, request the entire
@@ -136,7 +142,9 @@ def respond(inquiry, response, requester=None):
# there is no other paused branches, the conductor will resume the rest of the workflow.
resume_target = (
action_service.get_parent_liveaction(liveaction_db)
- if workflow_service.is_action_execution_under_workflow_context(liveaction_db)
+ if workflow_service.is_action_execution_under_workflow_context(
+ liveaction_db
+ )
else action_service.get_root_liveaction(liveaction_db)
)
@@ -147,14 +155,14 @@ def respond(inquiry, response, requester=None):
LOG.debug('Updating response for inquiry "%s".' % str(inquiry.id))
result = copy.deepcopy(inquiry.result)
- result['response'] = response
+ result["response"] = response
liveaction_db = action_utils.update_liveaction_status(
status=action_constants.LIVEACTION_STATUS_SUCCEEDED,
end_timestamp=date_utils.get_datetime_utc_now(),
runner_info=sys_info_utils.get_process_info(),
result=result,
- liveaction_id=str(liveaction_db.id)
+ liveaction_id=str(liveaction_db.id),
)
# Sync the liveaction with the corresponding action execution.
@@ -164,7 +172,7 @@ def respond(inquiry, response, requester=None):
LOG.debug('Invoking post run for inquiry "%s".' % str(inquiry.id))
runner_container = container.get_runner_container()
action_db = action_utils.get_action_by_ref(liveaction_db.action)
- runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name'])
+ runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"])
runner = runner_container._get_runner(runnertype_db, action_db, liveaction_db)
runner.post_run(status=action_constants.LIVEACTION_STATUS_SUCCEEDED, result=result)
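
# A hedged sketch of the validate_response() step above using plain jsonschema.
# st2's schema_utils wraps a custom validator with default handling
# (use_default=True, allow_default_none=True), which this sketch omits; the
# schema and response values below are illustrative only.
import jsonschema

INQUIRY_SCHEMA = {
    "type": "object",
    "properties": {
        "approved": {"type": "boolean"},
        "comment": {"type": "string"},
    },
    "required": ["approved"],
}


def validate_inquiry_response(response):
    try:
        jsonschema.validate(instance=response, schema=INQUIRY_SCHEMA)
    except jsonschema.ValidationError as e:
        # Corresponds to the InvalidInquiryResponse path in validate_response().
        raise ValueError("Response did not pass schema validation: %s" % e)


validate_inquiry_response({"approved": True, "comment": "looks good"})
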
diff --git a/st2common/st2common/services/keyvalues.py b/st2common/st2common/services/keyvalues.py
index 722603eee5..d38f28ca93 100644
--- a/st2common/st2common/services/keyvalues.py
+++ b/st2common/st2common/services/keyvalues.py
@@ -28,11 +28,10 @@
from st2common.persistence.keyvalue import KeyValuePair
__all__ = [
- 'get_kvp_for_name',
- 'get_values_for_names',
-
- 'KeyValueLookup',
- 'UserKeyValueLookup'
+ "get_kvp_for_name",
+ "get_values_for_names",
+ "KeyValueLookup",
+ "UserKeyValueLookup",
]
LOG = logging.getLogger(__name__)
@@ -81,17 +80,17 @@ def get_key_name(self):
:rtype: ``str``
"""
key_name_parts = [DATASTORE_PARENT_SCOPE, self.scope]
- key_name = self._key_prefix.split(':', 1)
+ key_name = self._key_prefix.split(":", 1)
if len(key_name) == 1:
key_name = key_name[0]
elif len(key_name) >= 2:
key_name = key_name[1]
else:
- key_name = ''
+ key_name = ""
key_name_parts.append(key_name)
- key_name = '.'.join(key_name_parts)
+ key_name = ".".join(key_name_parts)
return key_name
@@ -99,7 +98,9 @@ class KeyValueLookup(BaseKeyValueLookup):
scope = SYSTEM_SCOPE
- def __init__(self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_SCOPE):
+ def __init__(
+ self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_SCOPE
+ ):
if not scope:
scope = FULL_SYSTEM_SCOPE
@@ -107,7 +108,7 @@ def __init__(self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_S
scope = FULL_SYSTEM_SCOPE
self._prefix = prefix
- self._key_prefix = key_prefix or ''
+ self._key_prefix = key_prefix or ""
self._value_cache = cache or {}
self._scope = scope
@@ -129,7 +130,7 @@ def __getattr__(self, name):
def _get(self, name):
# get the value for this key and save in value_cache
if self._key_prefix:
- key = '%s.%s' % (self._key_prefix, name)
+ key = "%s.%s" % (self._key_prefix, name)
else:
key = name
@@ -144,12 +145,16 @@ def _get(self, name):
# the lookup is for 'key_base.key_value' it is likely that the calling code, e.g. Jinja,
# will expect to do a dictionary style lookup for key_base and key_value as subsequent
# calls. Saving the value in cache avoids extra DB calls.
- return KeyValueLookup(prefix=self._prefix, key_prefix=key, cache=self._value_cache,
- scope=self._scope)
+ return KeyValueLookup(
+ prefix=self._prefix,
+ key_prefix=key,
+ cache=self._value_cache,
+ scope=self._scope,
+ )
def _get_kv(self, key):
scope = self._scope
- LOG.debug('Lookup system kv: scope: %s and key: %s', scope, key)
+ LOG.debug("Lookup system kv: scope: %s and key: %s", scope, key)
try:
kvp = KeyValuePair.get_by_scope_and_name(scope=scope, name=key)
@@ -157,15 +162,17 @@ def _get_kv(self, key):
kvp = None
if kvp:
- LOG.debug('Got value %s from datastore.', kvp.value)
- return kvp.value if kvp else ''
+ LOG.debug("Got value %s from datastore.", kvp.value)
+ return kvp.value if kvp else ""
class UserKeyValueLookup(BaseKeyValueLookup):
scope = USER_SCOPE
- def __init__(self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_USER_SCOPE):
+ def __init__(
+ self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_USER_SCOPE
+ ):
if not scope:
scope = FULL_USER_SCOPE
@@ -173,7 +180,7 @@ def __init__(self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_US
scope = FULL_USER_SCOPE
self._prefix = prefix
- self._key_prefix = key_prefix or ''
+ self._key_prefix = key_prefix or ""
self._value_cache = cache or {}
self._user = user
self._scope = scope
@@ -190,7 +197,7 @@ def __getattr__(self, name):
def _get(self, name):
# get the value for this key and save in value_cache
if self._key_prefix:
- key = '%s.%s' % (self._key_prefix, name)
+ key = "%s.%s" % (self._key_prefix, name)
else:
key = UserKeyReference(name=name, user=self._user).ref
@@ -205,8 +212,13 @@ def _get(self, name):
# the lookup is for 'key_base.key_value' it is likely that the calling code, e.g. Jinja,
# will expect to do a dictionary style lookup for key_base and key_value as subsequent
# calls. Saving the value in cache avoids extra DB calls.
- return UserKeyValueLookup(prefix=self._prefix, user=self._user, key_prefix=key,
- cache=self._value_cache, scope=self._scope)
+ return UserKeyValueLookup(
+ prefix=self._prefix,
+ user=self._user,
+ key_prefix=key,
+ cache=self._value_cache,
+ scope=self._scope,
+ )
def _get_kv(self, key):
scope = self._scope
@@ -216,7 +228,7 @@ def _get_kv(self, key):
except StackStormDBObjectNotFoundError:
kvp = None
- return kvp.value if kvp else ''
+ return kvp.value if kvp else ""
def get_key_reference(scope, name, user=None):
@@ -232,12 +244,15 @@ def get_key_reference(scope, name, user=None):
:rtype: ``str``
"""
- if (scope == SYSTEM_SCOPE or scope == FULL_SYSTEM_SCOPE):
+ if scope == SYSTEM_SCOPE or scope == FULL_SYSTEM_SCOPE:
return name
- elif (scope == USER_SCOPE or scope == FULL_USER_SCOPE):
+ elif scope == USER_SCOPE or scope == FULL_USER_SCOPE:
if not user:
- raise InvalidUserException('A valid user must be specified for user key ref.')
+ raise InvalidUserException(
+ "A valid user must be specified for user key ref."
+ )
return UserKeyReference(name=name, user=user).ref
else:
- raise InvalidScopeException('Scope "%s" is not valid. Allowed scopes are %s.' %
- (scope, ALLOWED_SCOPES))
+ raise InvalidScopeException(
+ 'Scope "%s" is not valid. Allowed scopes are %s.' % (scope, ALLOWED_SCOPES)
+ )
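
# A minimal sketch of the chained-lookup pattern behind KeyValueLookup /
# UserKeyValueLookup above: each attribute access either resolves a dotted key
# from the store or returns a new lookup with a longer key prefix, which is
# what makes Jinja-style "st2kv.system.key_base.key_value" expressions work.
# The in-memory STORE is hypothetical and stands in for KeyValuePair.
STORE = {"mypack.api_host": "example.com", "mypack.api_port": "443"}


class ChainedLookup(object):
    def __init__(self, key_prefix="", cache=None):
        self._key_prefix = key_prefix
        self._value_cache = cache if cache is not None else {}

    def __getattr__(self, name):
        key = "%s.%s" % (self._key_prefix, name) if self._key_prefix else name
        value = STORE.get(key, "")
        # Cache the value so subsequent dict-style lookups avoid extra reads,
        # mirroring _value_cache in the real classes.
        self._value_cache[key] = value
        if value:
            return value
        # No direct hit: treat "key" as a prefix and keep chaining.
        return ChainedLookup(key_prefix=key, cache=self._value_cache)


lookup = ChainedLookup()
assert lookup.mypack.api_host == "example.com"
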
diff --git a/st2common/st2common/services/packs.py b/st2common/st2common/services/packs.py
index 7088b5f368..9f2794ed78 100644
--- a/st2common/st2common/services/packs.py
+++ b/st2common/st2common/services/packs.py
@@ -27,21 +27,15 @@
from six.moves import range
__all__ = [
- 'get_pack_by_ref',
- 'fetch_pack_index',
- 'get_pack_from_index',
- 'search_pack_index'
+ "get_pack_by_ref",
+ "fetch_pack_index",
+ "get_pack_from_index",
+ "search_pack_index",
]
-EXCLUDE_FIELDS = [
- "repo_url",
- "email"
-]
+EXCLUDE_FIELDS = ["repo_url", "email"]
-SEARCH_PRIORITY = [
- "name",
- "keywords"
-]
+SEARCH_PRIORITY = ["name", "keywords"]
LOG = logging.getLogger(__name__)
@@ -55,7 +49,7 @@ def _build_index_list(index_url):
index_urls = cfg.CONF.content.index_url[::-1]
elif isinstance(index_url, str):
index_urls = [index_url]
- elif hasattr(index_url, '__iter__'):
+ elif hasattr(index_url, "__iter__"):
index_urls = index_url
else:
raise TypeError('"index_url" should either be a string or an iterable object.')
@@ -73,23 +67,23 @@ def _fetch_and_compile_index(index_urls, logger=None, proxy_config=None):
verify = True
if proxy_config:
- https_proxy = proxy_config.get('https_proxy', None)
- http_proxy = proxy_config.get('http_proxy', None)
- ca_bundle_path = proxy_config.get('proxy_ca_bundle_path', None)
+ https_proxy = proxy_config.get("https_proxy", None)
+ http_proxy = proxy_config.get("http_proxy", None)
+ ca_bundle_path = proxy_config.get("proxy_ca_bundle_path", None)
if https_proxy:
- proxies_dict['https'] = https_proxy
+ proxies_dict["https"] = https_proxy
verify = ca_bundle_path or True
if http_proxy:
- proxies_dict['http'] = http_proxy
+ proxies_dict["http"] = http_proxy
for index_url in index_urls:
index_status = {
- 'url': index_url,
- 'packs': 0,
- 'message': None,
- 'error': None,
+ "url": index_url,
+ "packs": 0,
+ "message": None,
+ "error": None,
}
index_json = None
@@ -98,32 +92,32 @@ def _fetch_and_compile_index(index_urls, logger=None, proxy_config=None):
request.raise_for_status()
index_json = request.json()
except ValueError as e:
- index_status['error'] = 'malformed'
- index_status['message'] = repr(e)
+ index_status["error"] = "malformed"
+ index_status["message"] = repr(e)
except requests.exceptions.RequestException as e:
- index_status['error'] = 'unresponsive'
- index_status['message'] = repr(e)
+ index_status["error"] = "unresponsive"
+ index_status["message"] = repr(e)
except Exception as e:
- index_status['error'] = 'other errors'
- index_status['message'] = repr(e)
+ index_status["error"] = "other errors"
+ index_status["message"] = repr(e)
if index_json == {}:
- index_status['error'] = 'empty'
- index_status['message'] = 'The index URL returned an empty object.'
+ index_status["error"] = "empty"
+ index_status["message"] = "The index URL returned an empty object."
elif type(index_json) is list:
- index_status['error'] = 'malformed'
- index_status['message'] = 'Expected an index object, got a list instead.'
- elif index_json and 'packs' not in index_json:
- index_status['error'] = 'malformed'
- index_status['message'] = 'Index object is missing "packs" attribute.'
+ index_status["error"] = "malformed"
+ index_status["message"] = "Expected an index object, got a list instead."
+ elif index_json and "packs" not in index_json:
+ index_status["error"] = "malformed"
+ index_status["message"] = 'Index object is missing "packs" attribute.'
- if index_status['error']:
+ if index_status["error"]:
logger.error("Index parsing error: %s" % json.dumps(index_status, indent=4))
else:
# TODO: Notify on a duplicate pack aka pack being overwritten from a different index
- packs_data = index_json['packs']
- index_status['message'] = 'Success.'
- index_status['packs'] = len(packs_data)
+ packs_data = index_json["packs"]
+ index_status["message"] = "Success."
+ index_status["packs"] = len(packs_data)
index.update(packs_data)
status.append(index_status)
@@ -147,8 +141,9 @@ def fetch_pack_index(index_url=None, logger=None, allow_empty=False, proxy_confi
logger = logger or LOG
index_urls = _build_index_list(index_url)
- index, status = _fetch_and_compile_index(index_urls=index_urls, logger=logger,
- proxy_config=proxy_config)
+ index, status = _fetch_and_compile_index(
+ index_urls=index_urls, logger=logger, proxy_config=proxy_config
+ )
# If one of the indexes on the list is unresponsive, we do not throw
# immediately. The only case where an exception is raised is when no
@@ -156,11 +151,14 @@ def fetch_pack_index(index_url=None, logger=None, allow_empty=False, proxy_confi
# This behavior allows for mirrors / backups and handling connection
# or network issues in one of the indexes.
if not index and not allow_empty:
- raise ValueError("No results from the %s: tried %s.\nStatus: %s" % (
- ("index" if len(index_urls) == 1 else "indexes"),
- ", ".join(index_urls),
- json.dumps(status, indent=4)
- ))
+ raise ValueError(
+ "No results from the %s: tried %s.\nStatus: %s"
+ % (
+ ("index" if len(index_urls) == 1 else "indexes"),
+ ", ".join(index_urls),
+ json.dumps(status, indent=4),
+ )
+ )
return (index, status)
@@ -177,13 +175,15 @@ def get_pack_from_index(pack, proxy_config=None):
return index.get(pack)
-def search_pack_index(query, exclude=None, priority=None, case_sensitive=True, proxy_config=None):
+def search_pack_index(
+ query, exclude=None, priority=None, case_sensitive=True, proxy_config=None
+):
"""
Search the pack index by query.
Returns a list of matches for a query.
"""
if not query:
- raise ValueError('Query must be specified.')
+ raise ValueError("Query must be specified.")
if not exclude:
exclude = EXCLUDE_FIELDS
@@ -198,7 +198,7 @@ def search_pack_index(query, exclude=None, priority=None, case_sensitive=True, p
matches = [[] for i in range(len(priority) + 1)]
for pack in six.itervalues(index):
for key, value in six.iteritems(pack):
- if not hasattr(value, '__contains__'):
+ if not hasattr(value, "__contains__"):
value = str(value)
if not case_sensitive:
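
# A hedged sketch of the mirror-friendly fetch in _fetch_and_compile_index()
# above: every index URL is tried, a per-index status record is kept, and a
# ValueError is raised only when no index returned any packs, so one
# unresponsive mirror does not fail the whole lookup. Field names follow the
# code above, but the function itself is illustrative.
import json

import requests


def fetch_indexes(index_urls):
    index, status = {}, []
    for index_url in index_urls:
        index_status = {"url": index_url, "packs": 0, "error": None}
        try:
            response = requests.get(index_url, timeout=10)
            response.raise_for_status()
            packs_data = response.json().get("packs", {})
            index.update(packs_data)
            index_status["packs"] = len(packs_data)
        except requests.exceptions.RequestException:
            index_status["error"] = "unresponsive"
        except ValueError:
            index_status["error"] = "malformed"
        except Exception:
            index_status["error"] = "other errors"
        status.append(index_status)
    if not index:
        raise ValueError("No results from indexes:\n%s" % json.dumps(status, indent=4))
    return index, status
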
diff --git a/st2common/st2common/services/policies.py b/st2common/st2common/services/policies.py
index 50ba28f304..46e24ce290 100644
--- a/st2common/st2common/services/policies.py
+++ b/st2common/st2common/services/policies.py
@@ -25,13 +25,10 @@
def has_policies(lv_ac_db, policy_types=None):
- query_params = {
- 'resource_ref': lv_ac_db.action,
- 'enabled': True
- }
+ query_params = {"resource_ref": lv_ac_db.action, "enabled": True}
if policy_types:
- query_params['policy_type__in'] = policy_types
+ query_params["policy_type__in"] = policy_types
policy_dbs = pc_db_access.Policy.query(**query_params)
@@ -42,11 +39,19 @@ def apply_pre_run_policies(lv_ac_db):
LOG.debug('Applying pre-run policies for liveaction "%s".' % str(lv_ac_db.id))
policy_dbs = pc_db_access.Policy.query(resource_ref=lv_ac_db.action, enabled=True)
- LOG.debug('Identified %s policies for the action "%s".' % (len(policy_dbs), lv_ac_db.action))
+ LOG.debug(
+ 'Identified %s policies for the action "%s".'
+ % (len(policy_dbs), lv_ac_db.action)
+ )
for policy_db in policy_dbs:
- LOG.debug('Getting driver for policy "%s" (%s).' % (policy_db.ref, policy_db.policy_type))
- driver = engine.get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters)
+ LOG.debug(
+ 'Getting driver for policy "%s" (%s).'
+ % (policy_db.ref, policy_db.policy_type)
+ )
+ driver = engine.get_driver(
+ policy_db.ref, policy_db.policy_type, **policy_db.parameters
+ )
try:
message = 'Applying policy "%s" (%s) for liveaction "%s".'
@@ -54,7 +59,9 @@ def apply_pre_run_policies(lv_ac_db):
lv_ac_db = driver.apply_before(lv_ac_db)
except:
message = 'An exception occurred while applying policy "%s" (%s) for liveaction "%s".'
- LOG.exception(message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id)))
+ LOG.exception(
+ message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id))
+ )
if lv_ac_db.status == ac_const.LIVEACTION_STATUS_DELAYED:
break
@@ -66,11 +73,19 @@ def apply_post_run_policies(lv_ac_db):
LOG.debug('Applying post run policies for liveaction "%s".' % str(lv_ac_db.id))
policy_dbs = pc_db_access.Policy.query(resource_ref=lv_ac_db.action, enabled=True)
- LOG.debug('Identified %s policies for the action "%s".' % (len(policy_dbs), lv_ac_db.action))
+ LOG.debug(
+ 'Identified %s policies for the action "%s".'
+ % (len(policy_dbs), lv_ac_db.action)
+ )
for policy_db in policy_dbs:
- LOG.debug('Getting driver for policy "%s" (%s).' % (policy_db.ref, policy_db.policy_type))
- driver = engine.get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters)
+ LOG.debug(
+ 'Getting driver for policy "%s" (%s).'
+ % (policy_db.ref, policy_db.policy_type)
+ )
+ driver = engine.get_driver(
+ policy_db.ref, policy_db.policy_type, **policy_db.parameters
+ )
try:
message = 'Applying policy "%s" (%s) for liveaction "%s".'
@@ -78,6 +93,8 @@ def apply_post_run_policies(lv_ac_db):
lv_ac_db = driver.apply_after(lv_ac_db)
except:
message = 'An exception occurred while applying policy "%s" (%s) for liveaction "%s".'
- LOG.exception(message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id)))
+ LOG.exception(
+ message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id))
+ )
return lv_ac_db
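
# A minimal sketch of the pre-run loop in apply_pre_run_policies() above: each
# enabled policy's driver is applied in turn, exceptions are swallowed (the
# real code logs them), and the loop stops early once a policy delays the
# liveaction. The driver object here is a hypothetical stand-in for an
# engine.get_driver() result.
LIVEACTION_STATUS_DELAYED = "delayed"


class DelayingPolicy(object):
    def apply_before(self, lv_ac):
        lv_ac["status"] = LIVEACTION_STATUS_DELAYED  # e.g. a concurrency limit was hit
        return lv_ac


def apply_pre_run_policies(lv_ac, drivers):
    for driver in drivers:
        try:
            lv_ac = driver.apply_before(lv_ac)
        except Exception:
            pass  # the original code logs the exception and keeps going
        if lv_ac["status"] == LIVEACTION_STATUS_DELAYED:
            break
    return lv_ac


print(apply_pre_run_policies({"status": "requested"}, [DelayingPolicy()]))
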
diff --git a/st2common/st2common/services/queries.py b/st2common/st2common/services/queries.py
index e6d769e365..20c7a0c990 100644
--- a/st2common/st2common/services/queries.py
+++ b/st2common/st2common/services/queries.py
@@ -25,13 +25,15 @@
def setup_query(liveaction_id, runnertype_db, query_context):
- if not getattr(runnertype_db, 'query_module', None):
- raise Exception('The runner "%s" does not have a query module.' % runnertype_db.name)
+ if not getattr(runnertype_db, "query_module", None):
+ raise Exception(
+ 'The runner "%s" does not have a query module.' % runnertype_db.name
+ )
state_db = ActionExecutionStateDB(
execution_id=liveaction_id,
query_module=runnertype_db.query_module,
- query_context=query_context
+ query_context=query_context,
)
ActionExecutionState.add_or_update(state_db)
diff --git a/st2common/st2common/services/rules.py b/st2common/st2common/services/rules.py
index d9be718e27..ebb8083433 100644
--- a/st2common/st2common/services/rules.py
+++ b/st2common/st2common/services/rules.py
@@ -22,10 +22,7 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'get_rules_given_trigger',
- 'get_rules_with_trigger_ref'
-]
+__all__ = ["get_rules_given_trigger", "get_rules_with_trigger_ref"]
def get_rules_given_trigger(trigger):
@@ -34,13 +31,15 @@ def get_rules_given_trigger(trigger):
return get_rules_with_trigger_ref(trigger_ref=trigger)
if isinstance(trigger, dict):
- trigger_ref = trigger.get('ref', None)
+ trigger_ref = trigger.get("ref", None)
if trigger_ref:
return get_rules_with_trigger_ref(trigger_ref=trigger_ref)
else:
- raise ValueError('Trigger dict %s is missing ``ref``.' % trigger)
+ raise ValueError("Trigger dict %s is missing ``ref``." % trigger)
- raise ValueError('Unknown type %s for trigger. Cannot do rule lookups.' % type(trigger))
+ raise ValueError(
+ "Unknown type %s for trigger. Cannot do rule lookups." % type(trigger)
+ )
def get_rules_with_trigger_ref(trigger_ref=None, enabled=True):
@@ -56,5 +55,5 @@ def get_rules_with_trigger_ref(trigger_ref=None, enabled=True):
if not trigger_ref:
return None
- LOG.debug('Querying rules with trigger %s', trigger_ref)
+ LOG.debug("Querying rules with trigger %s", trigger_ref)
return Rule.query(trigger=trigger_ref, enabled=enabled)
diff --git a/st2common/st2common/services/sensor_watcher.py b/st2common/st2common/services/sensor_watcher.py
index 0105ba46d6..1c54881663 100644
--- a/st2common/st2common/services/sensor_watcher.py
+++ b/st2common/st2common/services/sensor_watcher.py
@@ -32,9 +32,9 @@
class SensorWatcher(ConsumerMixin):
-
- def __init__(self, create_handler, update_handler, delete_handler,
- queue_suffix=None):
+ def __init__(
+ self, create_handler, update_handler, delete_handler, queue_suffix=None
+ ):
"""
:param create_handler: Function which is called on SensorDB create event.
:type create_handler: ``callable``
@@ -57,34 +57,41 @@ def __init__(self, create_handler, update_handler, delete_handler,
self._handlers = {
publishers.CREATE_RK: create_handler,
publishers.UPDATE_RK: update_handler,
- publishers.DELETE_RK: delete_handler
+ publishers.DELETE_RK: delete_handler,
}
def get_consumers(self, Consumer, channel):
- consumers = [Consumer(queues=[self._sensor_watcher_q],
- accept=['pickle'],
- callbacks=[self.process_task])]
+ consumers = [
+ Consumer(
+ queues=[self._sensor_watcher_q],
+ accept=["pickle"],
+ callbacks=[self.process_task],
+ )
+ ]
return consumers
def process_task(self, body, message):
- LOG.debug('process_task')
- LOG.debug(' body: %s', body)
- LOG.debug(' message.properties: %s', message.properties)
- LOG.debug(' message.delivery_info: %s', message.delivery_info)
+ LOG.debug("process_task")
+ LOG.debug(" body: %s", body)
+ LOG.debug(" message.properties: %s", message.properties)
+ LOG.debug(" message.delivery_info: %s", message.delivery_info)
- routing_key = message.delivery_info.get('routing_key', '')
+ routing_key = message.delivery_info.get("routing_key", "")
handler = self._handlers.get(routing_key, None)
try:
if not handler:
- LOG.info('Skipping message %s as no handler was found.', message)
+ LOG.info("Skipping message %s as no handler was found.", message)
return
try:
handler(body)
except Exception as e:
- LOG.exception('Handling failed. Message body: %s. Exception: %s',
- body, six.text_type(e))
+ LOG.exception(
+ "Handling failed. Message body: %s. Exception: %s",
+ body,
+ six.text_type(e),
+ )
finally:
message.ack()
@@ -93,11 +100,11 @@ def start(self):
self.connection = transport_utils.get_connection()
self._updates_thread = concurrency.spawn(self.run)
except:
- LOG.exception('Failed to start sensor_watcher.')
+ LOG.exception("Failed to start sensor_watcher.")
self.connection.release()
def stop(self):
- LOG.debug('Shutting down sensor watcher.')
+ LOG.debug("Shutting down sensor watcher.")
try:
if self._updates_thread:
self._updates_thread = concurrency.kill(self._updates_thread)
@@ -108,15 +115,19 @@ def stop(self):
try:
bound_sensor_watch_q.delete()
except:
- LOG.error('Unable to delete sensor watcher queue: %s', self._sensor_watcher_q)
+ LOG.error(
+ "Unable to delete sensor watcher queue: %s",
+ self._sensor_watcher_q,
+ )
finally:
if self.connection:
self.connection.release()
@staticmethod
def _get_queue(queue_suffix):
- queue_name = queue_utils.get_queue_name(queue_name_base='st2.sensor.watch',
- queue_name_suffix=queue_suffix,
- add_random_uuid_to_suffix=True
- )
- return reactor.get_sensor_cud_queue(queue_name, routing_key='#')
+ queue_name = queue_utils.get_queue_name(
+ queue_name_base="st2.sensor.watch",
+ queue_name_suffix=queue_suffix,
+ add_random_uuid_to_suffix=True,
+ )
+ return reactor.get_sensor_cud_queue(queue_name, routing_key="#")
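
# A standalone sketch of the routing-key dispatch in
# SensorWatcher.process_task() above: the message's routing key selects a CRUD
# handler, unknown keys are skipped, and the message is always acked.
# FakeMessage is a stand-in for a kombu Message object.
class FakeMessage(object):
    def __init__(self, routing_key):
        self.delivery_info = {"routing_key": routing_key}

    def ack(self):
        print("acked")


def process_task(body, message, handlers):
    routing_key = message.delivery_info.get("routing_key", "")
    handler = handlers.get(routing_key, None)
    try:
        if not handler:
            return  # no handler registered for this routing key - skip
        handler(body)
    finally:
        message.ack()


handlers = {"create": lambda body: print("created sensor: %s" % body)}
process_task({"name": "my_sensor"}, FakeMessage("create"), handlers)
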
diff --git a/st2common/st2common/services/trace.py b/st2common/st2common/services/trace.py
index 3eb92bd2f1..4dadef0964 100644
--- a/st2common/st2common/services/trace.py
+++ b/st2common/st2common/services/trace.py
@@ -32,22 +32,24 @@
LOG = logging.getLogger(__name__)
__all__ = [
- 'get_trace_db_by_action_execution',
- 'get_trace_db_by_rule',
- 'get_trace_db_by_trigger_instance',
- 'get_trace',
- 'add_or_update_given_trace_context',
- 'add_or_update_given_trace_db',
- 'get_trace_component_for_action_execution',
- 'get_trace_component_for_rule',
- 'get_trace_component_for_trigger_instance'
+ "get_trace_db_by_action_execution",
+ "get_trace_db_by_rule",
+ "get_trace_db_by_trigger_instance",
+ "get_trace",
+ "add_or_update_given_trace_context",
+ "add_or_update_given_trace_db",
+ "get_trace_component_for_action_execution",
+ "get_trace_component_for_rule",
+ "get_trace_component_for_trigger_instance",
]
ACTION_SENSOR_TRIGGER_REF = ResourceReference.to_string_reference(
- pack=ACTION_SENSOR_TRIGGER['pack'], name=ACTION_SENSOR_TRIGGER['name'])
+ pack=ACTION_SENSOR_TRIGGER["pack"], name=ACTION_SENSOR_TRIGGER["name"]
+)
NOTIFY_TRIGGER_REF = ResourceReference.to_string_reference(
- pack=NOTIFY_TRIGGER['pack'], name=NOTIFY_TRIGGER['name'])
+ pack=NOTIFY_TRIGGER["pack"], name=NOTIFY_TRIGGER["name"]
+)
def _get_valid_trace_context(trace_context):
@@ -74,14 +76,17 @@ def _get_single_trace_by_component(**component_filter):
return None
elif len(traces) > 1:
raise UniqueTraceNotFoundException(
- 'More than 1 trace matching %s found.' % component_filter)
+ "More than 1 trace matching %s found." % component_filter
+ )
return traces[0]
def get_trace_db_by_action_execution(action_execution=None, action_execution_id=None):
if action_execution:
action_execution_id = str(action_execution.id)
- return _get_single_trace_by_component(action_executions__object_id=action_execution_id)
+ return _get_single_trace_by_component(
+ action_executions__object_id=action_execution_id
+ )
def get_trace_db_by_rule(rule=None, rule_id=None):
@@ -94,7 +99,9 @@ def get_trace_db_by_rule(rule=None, rule_id=None):
def get_trace_db_by_trigger_instance(trigger_instance=None, trigger_instance_id=None):
if trigger_instance:
trigger_instance_id = str(trigger_instance.id)
- return _get_single_trace_by_component(trigger_instances__object_id=trigger_instance_id)
+ return _get_single_trace_by_component(
+ trigger_instances__object_id=trigger_instance_id
+ )
def get_trace(trace_context, ignore_trace_tag=False):
@@ -111,16 +118,20 @@ def get_trace(trace_context, ignore_trace_tag=False):
trace_context = _get_valid_trace_context(trace_context)
if not trace_context.id_ and not trace_context.trace_tag:
- raise ValueError('Atleast one of id_ or trace_tag should be specified.')
+ raise ValueError("At least one of id_ or trace_tag should be specified.")
if trace_context.id_:
try:
return Trace.get_by_id(trace_context.id_)
except (ValidationError, ValueError):
- LOG.warning('Database lookup for Trace with id="%s" failed.',
- trace_context.id_, exc_info=True)
+ LOG.warning(
+ 'Database lookup for Trace with id="%s" failed.',
+ trace_context.id_,
+ exc_info=True,
+ )
raise StackStormDBObjectNotFoundError(
- 'Unable to find Trace with id="%s"' % trace_context.id_)
+ 'Unable to find Trace with id="%s"' % trace_context.id_
+ )
if ignore_trace_tag:
return None
@@ -130,7 +141,8 @@ def get_trace(trace_context, ignore_trace_tag=False):
# Assume this method only handles 1 trace.
if len(traces) > 1:
raise UniqueTraceNotFoundException(
- 'More than 1 Trace matching %s found.' % trace_context.trace_tag)
+ "More than 1 Trace matching %s found." % trace_context.trace_tag
+ )
return traces[0]
@@ -168,14 +180,17 @@ def get_trace_db_by_live_action(liveaction):
# This cover case for child execution of a workflow.
parent_context = executions.get_parent_context(liveaction_db=liveaction)
if not trace_context and parent_context:
- parent_execution_id = parent_context.get('execution_id', None)
+ parent_execution_id = parent_context.get("execution_id", None)
if parent_execution_id:
# go straight to a trace_db. If there is a parent execution then that must
# be associated with a Trace.
- trace_db = get_trace_db_by_action_execution(action_execution_id=parent_execution_id)
+ trace_db = get_trace_db_by_action_execution(
+ action_execution_id=parent_execution_id
+ )
if not trace_db:
- raise StackStormDBObjectNotFoundError('No trace found for execution %s' %
- parent_execution_id)
+ raise StackStormDBObjectNotFoundError(
+ "No trace found for execution %s" % parent_execution_id
+ )
return (created, trace_db)
# 3. Check if the action_execution associated with liveaction leads to a trace_db
execution = ActionExecution.get(liveaction__id=str(liveaction.id))
@@ -184,13 +199,14 @@ def get_trace_db_by_live_action(liveaction):
# 4. No trace_db found, therefore create one. This typically happens
# when execution is run by hand.
if not trace_db:
- trace_db = TraceDB(trace_tag='execution-%s' % str(liveaction.id))
+ trace_db = TraceDB(trace_tag="execution-%s" % str(liveaction.id))
created = True
return (created, trace_db)
-def add_or_update_given_trace_context(trace_context, action_executions=None, rules=None,
- trigger_instances=None):
+def add_or_update_given_trace_context(
+ trace_context, action_executions=None, rules=None, trigger_instances=None
+):
"""
Will update an existing Trace or add a new Trace. This method will only look for the exact
Trace as identified by the trace_context. Even if the trace_context contains a trace_tag
@@ -222,14 +238,17 @@ def add_or_update_given_trace_context(trace_context, action_executions=None, rul
# since trace_db is None need to end up with a valid trace_context
trace_context = _get_valid_trace_context(trace_context)
trace_db = TraceDB(trace_tag=trace_context.trace_tag)
- return add_or_update_given_trace_db(trace_db=trace_db,
- action_executions=action_executions,
- rules=rules,
- trigger_instances=trigger_instances)
+ return add_or_update_given_trace_db(
+ trace_db=trace_db,
+ action_executions=action_executions,
+ rules=rules,
+ trigger_instances=trigger_instances,
+ )
-def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None,
- trigger_instances=None):
+def add_or_update_given_trace_db(
+ trace_db, action_executions=None, rules=None, trigger_instances=None
+):
"""
Will update an existing Trace.
@@ -251,12 +270,14 @@ def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None,
:rtype: ``TraceDB``
"""
if trace_db is None:
- raise ValueError('trace_db should be non-None.')
+ raise ValueError("trace_db should be non-None.")
if not action_executions:
action_executions = []
- action_executions = [_to_trace_component_db(component=action_execution)
- for action_execution in action_executions]
+ action_executions = [
+ _to_trace_component_db(component=action_execution)
+ for action_execution in action_executions
+ ]
if not rules:
rules = []
@@ -264,16 +285,20 @@ def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None,
if not trigger_instances:
trigger_instances = []
- trigger_instances = [_to_trace_component_db(component=trigger_instance)
- for trigger_instance in trigger_instances]
+ trigger_instances = [
+ _to_trace_component_db(component=trigger_instance)
+ for trigger_instance in trigger_instances
+ ]
# If an id exists then this is an update and we do not want to perform
# an upsert so use push_components which will use the push operator.
if trace_db.id:
- return Trace.push_components(trace_db,
- action_executions=action_executions,
- rules=rules,
- trigger_instances=trigger_instances)
+ return Trace.push_components(
+ trace_db,
+ action_executions=action_executions,
+ rules=rules,
+ trigger_instances=trigger_instances,
+ )
trace_db.action_executions = action_executions
trace_db.rules = rules
@@ -295,23 +320,25 @@ def get_trace_component_for_action_execution(action_execution_db, liveaction_db)
:rtype: ``dict``
"""
if not action_execution_db:
- raise ValueError('action_execution_db expected.')
+ raise ValueError("action_execution_db expected.")
trace_component = {
- 'id': str(action_execution_db.id),
- 'ref': str(action_execution_db.action.get('ref', ''))
+ "id": str(action_execution_db.id),
+ "ref": str(action_execution_db.action.get("ref", "")),
}
caused_by = {}
parent_context = executions.get_parent_context(liveaction_db=liveaction_db)
if liveaction_db and parent_context:
- caused_by['type'] = 'action_execution'
- caused_by['id'] = liveaction_db.context['parent'].get('execution_id', None)
+ caused_by["type"] = "action_execution"
+ caused_by["id"] = liveaction_db.context["parent"].get("execution_id", None)
elif action_execution_db.rule and action_execution_db.trigger_instance:
# Once RuleEnforcement is available that can be used instead.
- caused_by['type'] = 'rule'
- caused_by['id'] = '%s:%s' % (action_execution_db.rule['id'],
- action_execution_db.trigger_instance['id'])
+ caused_by["type"] = "rule"
+ caused_by["id"] = "%s:%s" % (
+ action_execution_db.rule["id"],
+ action_execution_db.trigger_instance["id"],
+ )
- trace_component['caused_by'] = caused_by
+ trace_component["caused_by"] = caused_by
return trace_component
@@ -328,13 +355,13 @@ def get_trace_component_for_rule(rule_db, trigger_instance_db):
:rtype: ``dict``
"""
trace_component = {}
- trace_component = {'id': str(rule_db.id), 'ref': rule_db.ref}
+ trace_component = {"id": str(rule_db.id), "ref": rule_db.ref}
caused_by = {}
if trigger_instance_db:
# Once RuleEnforcement is available that can be used instead.
- caused_by['type'] = 'trigger_instance'
- caused_by['id'] = str(trigger_instance_db.id)
- trace_component['caused_by'] = caused_by
+ caused_by["type"] = "trigger_instance"
+ caused_by["id"] = str(trigger_instance_db.id)
+ trace_component["caused_by"] = caused_by
return trace_component
@@ -349,18 +376,20 @@ def get_trace_component_for_trigger_instance(trigger_instance_db):
"""
trace_component = {}
trace_component = {
- 'id': str(trigger_instance_db.id),
- 'ref': trigger_instance_db.trigger
+ "id": str(trigger_instance_db.id),
+ "ref": trigger_instance_db.trigger,
}
caused_by = {}
# Special handling for ACTION_SENSOR_TRIGGER and NOTIFY_TRIGGER where we
# know how to maintain the links.
- if trigger_instance_db.trigger == ACTION_SENSOR_TRIGGER_REF or \
- trigger_instance_db.trigger == NOTIFY_TRIGGER_REF:
- caused_by['type'] = 'action_execution'
+ if (
+ trigger_instance_db.trigger == ACTION_SENSOR_TRIGGER_REF
+ or trigger_instance_db.trigger == NOTIFY_TRIGGER_REF
+ ):
+ caused_by["type"] = "action_execution"
# For both the action trigger and the notify trigger, execution_id is stored in the payload.
- caused_by['id'] = trigger_instance_db.payload['execution_id']
- trace_component['caused_by'] = caused_by
+ caused_by["id"] = trigger_instance_db.payload["execution_id"]
+ trace_component["caused_by"] = caused_by
return trace_component
@@ -376,10 +405,12 @@ def _to_trace_component_db(component):
"""
if not isinstance(component, (six.string_types, dict)):
print(type(component))
- raise ValueError('Expected component to be str or dict')
+ raise ValueError("Expected component to be str or dict")
- object_id = component if isinstance(component, six.string_types) else component['id']
- ref = component.get('ref', '') if isinstance(component, dict) else ''
- caused_by = component.get('caused_by', {}) if isinstance(component, dict) else {}
+ object_id = (
+ component if isinstance(component, six.string_types) else component["id"]
+ )
+ ref = component.get("ref", "") if isinstance(component, dict) else ""
+ caused_by = component.get("caused_by", {}) if isinstance(component, dict) else {}
return TraceComponentDB(object_id=object_id, ref=ref, caused_by=caused_by)
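
# A standalone sketch of the str-or-dict normalization in
# _to_trace_component_db() above, returning a plain dict instead of a
# TraceComponentDB document.
def to_trace_component(component):
    if not isinstance(component, (str, dict)):
        raise ValueError("Expected component to be str or dict")

    object_id = component if isinstance(component, str) else component["id"]
    ref = component.get("ref", "") if isinstance(component, dict) else ""
    caused_by = component.get("caused_by", {}) if isinstance(component, dict) else {}
    return {"object_id": object_id, "ref": ref, "caused_by": caused_by}


assert to_trace_component("abc123") == {"object_id": "abc123", "ref": "", "caused_by": {}}
assert to_trace_component({"id": "1", "ref": "core.local"})["ref"] == "core.local"
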
diff --git a/st2common/st2common/services/trigger_dispatcher.py b/st2common/st2common/services/trigger_dispatcher.py
index 6843a1eb74..6343a555b9 100644
--- a/st2common/st2common/services/trigger_dispatcher.py
+++ b/st2common/st2common/services/trigger_dispatcher.py
@@ -23,9 +23,7 @@
from st2common.transport.reactor import TriggerDispatcher
from st2common.validators.api.reactor import validate_trigger_payload
-__all__ = [
- 'TriggerDispatcherService'
-]
+__all__ = ["TriggerDispatcherService"]
class TriggerDispatcherService(object):
@@ -37,7 +35,9 @@ def __init__(self, logger):
self._logger = logger
self._dispatcher = TriggerDispatcher(self._logger)
- def dispatch(self, trigger, payload=None, trace_tag=None, throw_on_validation_error=False):
+ def dispatch(
+ self, trigger, payload=None, trace_tag=None, throw_on_validation_error=False
+ ):
"""
Method which dispatches the trigger.
@@ -56,12 +56,19 @@ def dispatch(self, trigger, payload=None, trace_tag=None, throw_on_validation_er
"""
# empty strings
trace_context = TraceContext(trace_tag=trace_tag) if trace_tag else None
- self._logger.debug('Added trace_context %s to trigger %s.', trace_context, trigger)
- return self.dispatch_with_context(trigger, payload=payload, trace_context=trace_context,
- throw_on_validation_error=throw_on_validation_error)
-
- def dispatch_with_context(self, trigger, payload=None, trace_context=None,
- throw_on_validation_error=False):
+ self._logger.debug(
+ "Added trace_context %s to trigger %s.", trace_context, trigger
+ )
+ return self.dispatch_with_context(
+ trigger,
+ payload=payload,
+ trace_context=trace_context,
+ throw_on_validation_error=throw_on_validation_error,
+ )
+
+ def dispatch_with_context(
+ self, trigger, payload=None, trace_context=None, throw_on_validation_error=False
+ ):
"""
Method which dispatches the trigger.
@@ -81,18 +88,25 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None,
# Note: We perform validation even if it's disabled in the config so we can at least warn
# the user if validation fails (but not throw if it's disabled)
try:
- validate_trigger_payload(trigger_type_ref=trigger, payload=payload,
- throw_on_inexistent_trigger=True)
+ validate_trigger_payload(
+ trigger_type_ref=trigger,
+ payload=payload,
+ throw_on_inexistent_trigger=True,
+ )
except (ValidationError, ValueError, Exception) as e:
- self._logger.warn('Failed to validate payload (%s) for trigger "%s": %s' %
- (str(payload), trigger, six.text_type(e)))
+ self._logger.warn(
+ 'Failed to validate payload (%s) for trigger "%s": %s'
+ % (str(payload), trigger, six.text_type(e))
+ )
# If validation is disabled, still dispatch the trigger even if it failed validation.
# This condition prevents unexpected restrictions.
if cfg.CONF.system.validate_trigger_payload:
- msg = ('Trigger payload validation failed and validation is enabled, not '
- 'dispatching a trigger "%s" (%s): %s' % (trigger, str(payload),
- six.text_type(e)))
+ msg = (
+ "Trigger payload validation failed and validation is enabled, not "
+ 'dispatching a trigger "%s" (%s): %s'
+ % (trigger, str(payload), six.text_type(e))
+ )
if throw_on_validation_error:
raise ValueError(msg)
@@ -100,5 +114,7 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None,
self._logger.warn(msg)
return None
- self._logger.debug('Dispatching trigger %s with payload %s.', trigger, payload)
- return self._dispatcher.dispatch(trigger, payload=payload, trace_context=trace_context)
+ self._logger.debug("Dispatching trigger %s with payload %s.", trigger, payload)
+ return self._dispatcher.dispatch(
+ trigger, payload=payload, trace_context=trace_context
+ )
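
# A hedged sketch of the validation gate in dispatch_with_context() above:
# payload validation always runs so a warning can be emitted, but a failure
# only blocks dispatch when validation is enabled in the config. validate()
# and send() below are hypothetical stand-ins for validate_trigger_payload()
# and TriggerDispatcher.dispatch().
def validate(trigger, payload):
    if "executed_at" not in payload:
        raise ValueError("'executed_at' is a required property")


def send(trigger, payload):
    print("dispatched %s with %s" % (trigger, payload))
    return payload


def dispatch(trigger, payload, validate_enabled, throw_on_validation_error=False):
    try:
        validate(trigger, payload)
    except ValueError as e:
        print('Failed to validate payload for trigger "%s": %s' % (trigger, e))
        if validate_enabled:
            msg = "Validation failed and validation is enabled, not dispatching"
            if throw_on_validation_error:
                raise ValueError(msg)
            print(msg)
            return None
    return send(trigger, payload)


dispatch("core.st2.generic.actiontrigger", {"executed_at": "now"}, True)
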
diff --git a/st2common/st2common/services/triggers.py b/st2common/st2common/services/triggers.py
index 6448aa2533..bbdce26b81 100644
--- a/st2common/st2common/services/triggers.py
+++ b/st2common/st2common/services/triggers.py
@@ -23,25 +23,22 @@
from st2common.exceptions.triggers import TriggerDoesNotExistException
from st2common.exceptions.db import StackStormDBObjectNotFoundError
from st2common.exceptions.db import StackStormDBObjectConflictError
-from st2common.models.api.trigger import (TriggerAPI, TriggerTypeAPI)
+from st2common.models.api.trigger import TriggerAPI, TriggerTypeAPI
from st2common.models.system.common import ResourceReference
-from st2common.persistence.trigger import (Trigger, TriggerType)
+from st2common.persistence.trigger import Trigger, TriggerType
__all__ = [
- 'add_trigger_models',
-
- 'get_trigger_db_by_ref',
- 'get_trigger_db_by_id',
- 'get_trigger_db_by_uid',
- 'get_trigger_db_by_ref_or_dict',
- 'get_trigger_db_given_type_and_params',
- 'get_trigger_type_db',
-
- 'create_trigger_db',
- 'create_trigger_type_db',
-
- 'create_or_update_trigger_db',
- 'create_or_update_trigger_type_db'
+ "add_trigger_models",
+ "get_trigger_db_by_ref",
+ "get_trigger_db_by_id",
+ "get_trigger_db_by_uid",
+ "get_trigger_db_by_ref_or_dict",
+ "get_trigger_db_given_type_and_params",
+ "get_trigger_type_db",
+ "create_trigger_db",
+ "create_trigger_type_db",
+ "create_or_update_trigger_db",
+ "create_or_update_trigger_type_db",
]
LOG = logging.getLogger(__name__)
@@ -50,8 +47,7 @@
def get_trigger_db_given_type_and_params(type=None, parameters=None):
try:
parameters = parameters or {}
- trigger_dbs = Trigger.query(type=type,
- parameters=parameters)
+ trigger_dbs = Trigger.query(type=type, parameters=parameters)
trigger_db = trigger_dbs[0] if len(trigger_dbs) > 0 else None
@@ -59,23 +55,24 @@ def get_trigger_db_given_type_and_params(type=None, parameters=None):
# pymongo and mongoengine
# Workaround for cron-timer: in some scenarios, finding an object fails when Python
# value types are unicode :/
- is_cron_trigger = (type == CRON_TIMER_TRIGGER_REF)
+ is_cron_trigger = type == CRON_TIMER_TRIGGER_REF
has_parameters = bool(parameters)
if not trigger_db and six.PY2 and is_cron_trigger and has_parameters:
non_unicode_literal_parameters = {}
for key, value in six.iteritems(parameters):
- key = key.encode('utf-8')
+ key = key.encode("utf-8")
if isinstance(value, six.text_type):
# We only encode unicode to str
- value = value.encode('utf-8')
+ value = value.encode("utf-8")
non_unicode_literal_parameters[key] = value
parameters = non_unicode_literal_parameters
- trigger_dbs = Trigger.query(type=type,
- parameters=non_unicode_literal_parameters).no_cache()
+ trigger_dbs = Trigger.query(
+ type=type, parameters=non_unicode_literal_parameters
+ ).no_cache()
# Note: We need to directly access the object, using len or accessing the query set
# twice won't work - there seems to be a bug with the cursor where accessing it twice
@@ -93,8 +90,14 @@ def get_trigger_db_given_type_and_params(type=None, parameters=None):
return trigger_db
except StackStormDBObjectNotFoundError as e:
- LOG.debug('Database lookup for type="%s" parameters="%s" resulted ' +
- 'in exception : %s.', type, parameters, e, exc_info=True)
+ LOG.debug(
+ 'Database lookup for type="%s" parameters="%s" resulted '
+ + "in exception : %s.",
+ type,
+ parameters,
+ e,
+ exc_info=True,
+ )
return None
@@ -109,26 +112,30 @@ def get_trigger_db_by_ref_or_dict(trigger):
else:
# If id / uid is available we try to look up the Trigger by id. This way we can avoid a bug in
# pymongo / mongoengine related to "parameters" dictionary lookups
- trigger_id = trigger.get('id', None)
- trigger_uid = trigger.get('uid', None)
+ trigger_id = trigger.get("id", None)
+ trigger_uid = trigger.get("uid", None)
# TODO: Remove parameters dictionary look up when we can confirm each trigger dictionary
# passed to this method always contains id or uid
if trigger_id:
- LOG.debug('Looking up TriggerDB by id: %s', trigger_id)
+ LOG.debug("Looking up TriggerDB by id: %s", trigger_id)
trigger_db = get_trigger_db_by_id(id=trigger_id)
elif trigger_uid:
- LOG.debug('Looking up TriggerDB by uid: %s', trigger_uid)
+ LOG.debug("Looking up TriggerDB by uid: %s", trigger_uid)
trigger_db = get_trigger_db_by_uid(uid=trigger_uid)
else:
# Last resort - look it up by parameters
- trigger_type = trigger.get('type', None)
- parameters = trigger.get('parameters', {})
-
- LOG.debug('Looking up TriggerDB by type and parameters: type=%s, parameters=%s',
- trigger_type, parameters)
- trigger_db = get_trigger_db_given_type_and_params(type=trigger_type,
- parameters=parameters)
+ trigger_type = trigger.get("type", None)
+ parameters = trigger.get("parameters", {})
+
+ LOG.debug(
+ "Looking up TriggerDB by type and parameters: type=%s, parameters=%s",
+ trigger_type,
+ parameters,
+ )
+ trigger_db = get_trigger_db_given_type_and_params(
+ type=trigger_type, parameters=parameters
+ )
return trigger_db
@@ -145,8 +152,12 @@ def get_trigger_db_by_id(id):
try:
return Trigger.get_by_id(id)
except StackStormDBObjectNotFoundError as e:
- LOG.debug('Database lookup for id="%s" resulted in exception : %s.',
- id, e, exc_info=True)
+ LOG.debug(
+ 'Database lookup for id="%s" resulted in exception : %s.',
+ id,
+ e,
+ exc_info=True,
+ )
return None
@@ -163,8 +174,12 @@ def get_trigger_db_by_uid(uid):
try:
return Trigger.get_by_uid(uid)
except StackStormDBObjectNotFoundError as e:
- LOG.debug('Database lookup for uid="%s" resulted in exception : %s.',
- uid, e, exc_info=True)
+ LOG.debug(
+ 'Database lookup for uid="%s" resulted in exception : %s.',
+ uid,
+ e,
+ exc_info=True,
+ )
return None
@@ -181,8 +196,12 @@ def get_trigger_db_by_ref(ref):
try:
return Trigger.get_by_ref(ref)
except StackStormDBObjectNotFoundError as e:
- LOG.debug('Database lookup for ref="%s" resulted ' +
- 'in exception : %s.', ref, e, exc_info=True)
+ LOG.debug(
+ 'Database lookup for ref="%s" resulted ' + "in exception : %s.",
+ ref,
+ e,
+ exc_info=True,
+ )
return None
@@ -192,16 +211,17 @@ def _get_trigger_db(trigger):
# XXX: Do not make this method public.
if isinstance(trigger, dict):
- name = trigger.get('name', None)
- pack = trigger.get('pack', None)
+ name = trigger.get("name", None)
+ pack = trigger.get("pack", None)
if name and pack:
ref = ResourceReference.to_string_reference(name=name, pack=pack)
return get_trigger_db_by_ref(ref)
- return get_trigger_db_given_type_and_params(type=trigger['type'],
- parameters=trigger.get('parameters', {}))
+ return get_trigger_db_given_type_and_params(
+ type=trigger["type"], parameters=trigger.get("parameters", {})
+ )
else:
- raise Exception('Unrecognized object')
+ raise Exception("Unrecognized object")
def get_trigger_type_db(ref):
@@ -216,8 +236,12 @@ def get_trigger_type_db(ref):
try:
return TriggerType.get_by_ref(ref)
except StackStormDBObjectNotFoundError as e:
- LOG.debug('Database lookup for ref="%s" resulted ' +
- 'in exception : %s.', ref, e, exc_info=True)
+ LOG.debug(
+ 'Database lookup for ref="%s" resulted ' + "in exception : %s.",
+ ref,
+ e,
+ exc_info=True,
+ )
return None
@@ -225,22 +249,23 @@ def get_trigger_type_db(ref):
def _get_trigger_dict_given_rule(rule):
trigger = rule.trigger
trigger_dict = {}
- triggertype_ref = ResourceReference.from_string_reference(trigger.get('type'))
- trigger_dict['pack'] = trigger_dict.get('pack', triggertype_ref.pack)
- trigger_dict['type'] = triggertype_ref.ref
- trigger_dict['parameters'] = rule.trigger.get('parameters', {})
+ triggertype_ref = ResourceReference.from_string_reference(trigger.get("type"))
+ trigger_dict["pack"] = trigger_dict.get("pack", triggertype_ref.pack)
+ trigger_dict["type"] = triggertype_ref.ref
+ trigger_dict["parameters"] = rule.trigger.get("parameters", {})
return trigger_dict
def create_trigger_db(trigger_api):
# TODO: This is used only in trigger API controller. We should get rid of this.
- trigger_ref = ResourceReference.to_string_reference(name=trigger_api.name,
- pack=trigger_api.pack)
+ trigger_ref = ResourceReference.to_string_reference(
+ name=trigger_api.name, pack=trigger_api.pack
+ )
trigger_db = get_trigger_db_by_ref(trigger_ref)
if not trigger_db:
trigger_db = TriggerAPI.to_model(trigger_api)
- LOG.debug('Verified trigger and formulated TriggerDB=%s', trigger_db)
+ LOG.debug("Verified trigger and formulated TriggerDB=%s", trigger_db)
trigger_db = Trigger.add_or_update(trigger_db)
return trigger_db
@@ -269,15 +294,16 @@ def create_or_update_trigger_db(trigger, log_not_unique_error_as_debug=False):
if is_update:
trigger_db.id = existing_trigger_db.id
- trigger_db = Trigger.add_or_update(trigger_db,
- log_not_unique_error_as_debug=log_not_unique_error_as_debug)
+ trigger_db = Trigger.add_or_update(
+ trigger_db, log_not_unique_error_as_debug=log_not_unique_error_as_debug
+ )
- extra = {'trigger_db': trigger_db}
+ extra = {"trigger_db": trigger_db}
if is_update:
- LOG.audit('Trigger updated. Trigger.id=%s' % (trigger_db.id), extra=extra)
+ LOG.audit("Trigger updated. Trigger.id=%s" % (trigger_db.id), extra=extra)
else:
- LOG.audit('Trigger created. Trigger.id=%s' % (trigger_db.id), extra=extra)
+ LOG.audit("Trigger created. Trigger.id=%s" % (trigger_db.id), extra=extra)
return trigger_db
@@ -288,10 +314,11 @@ def create_trigger_db_from_rule(rule):
# For simple triggertypes (triggertype with no parameters), we create a trigger when
# registering the triggertype. So hitting the case where there is no trigger in the db even
# though parameters is empty is a runtime error.
- if not trigger_dict.get('parameters', {}) and not existing_trigger_db:
+ if not trigger_dict.get("parameters", {}) and not existing_trigger_db:
raise TriggerDoesNotExistException(
- 'A simple trigger should have been created when registering '
- 'triggertype. Cannot create trigger: %s.' % (trigger_dict))
+ "A simple trigger should have been created when registering "
+ "triggertype. Cannot create trigger: %s." % (trigger_dict)
+ )
if not existing_trigger_db:
trigger_db = create_or_update_trigger_db(trigger_dict)
@@ -316,7 +343,7 @@ def increment_trigger_ref_count(rule_api):
trigger_dict = _get_trigger_dict_given_rule(rule_api)
# Special reference counting for trigger with parameters.
- if trigger_dict.get('parameters', None):
+ if trigger_dict.get("parameters", None):
trigger_db = _get_trigger_db(trigger_dict)
Trigger.update(trigger_db, inc__ref_count=1)
@@ -326,7 +353,7 @@ def cleanup_trigger_db_for_rule(rule_db):
existing_trigger_db = get_trigger_db_by_ref(rule_db.trigger)
if not existing_trigger_db or not existing_trigger_db.parameters:
# nothing to be done here so moving on.
- LOG.debug('ref_count decrement for %s not required.', existing_trigger_db)
+ LOG.debug("ref_count decrement for %s not required.", existing_trigger_db)
return
Trigger.update(existing_trigger_db, dec__ref_count=1)
Trigger.delete_if_unreferenced(existing_trigger_db)
@@ -350,15 +377,17 @@ def create_trigger_type_db(trigger_type, log_not_unique_error_as_debug=False):
"""
trigger_type_api = TriggerTypeAPI(**trigger_type)
trigger_type_api.validate()
- ref = ResourceReference.to_string_reference(name=trigger_type_api.name,
- pack=trigger_type_api.pack)
+ ref = ResourceReference.to_string_reference(
+ name=trigger_type_api.name, pack=trigger_type_api.pack
+ )
trigger_type_db = get_trigger_type_db(ref)
if not trigger_type_db:
trigger_type_db = TriggerTypeAPI.to_model(trigger_type_api)
- LOG.debug('verified trigger and formulated TriggerDB=%s', trigger_type_db)
- trigger_type_db = TriggerType.add_or_update(trigger_type_db,
- log_not_unique_error_as_debug=log_not_unique_error_as_debug)
+ LOG.debug("verified trigger and formulated TriggerDB=%s", trigger_type_db)
+ trigger_type_db = TriggerType.add_or_update(
+ trigger_type_db, log_not_unique_error_as_debug=log_not_unique_error_as_debug
+ )
return trigger_type_db
@@ -378,16 +407,21 @@ def create_shadow_trigger(trigger_type_db, log_not_unique_error_as_debug=False):
trigger_type_ref = trigger_type_db.get_reference().ref
if trigger_type_db.parameters_schema:
- LOG.debug('Skip shadow trigger for TriggerType with parameters %s.', trigger_type_ref)
+ LOG.debug(
+ "Skip shadow trigger for TriggerType with parameters %s.", trigger_type_ref
+ )
return None
- trigger = {'name': trigger_type_db.name,
- 'pack': trigger_type_db.pack,
- 'type': trigger_type_ref,
- 'parameters': {}}
+ trigger = {
+ "name": trigger_type_db.name,
+ "pack": trigger_type_db.pack,
+ "type": trigger_type_ref,
+ "parameters": {},
+ }
- return create_or_update_trigger_db(trigger,
- log_not_unique_error_as_debug=log_not_unique_error_as_debug)
+ return create_or_update_trigger_db(
+ trigger, log_not_unique_error_as_debug=log_not_unique_error_as_debug
+ )
def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug=False):
@@ -412,8 +446,9 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug
trigger_type_api.validate()
trigger_type_api = TriggerTypeAPI.to_model(trigger_type_api)
- ref = ResourceReference.to_string_reference(name=trigger_type_api.name,
- pack=trigger_type_api.pack)
+ ref = ResourceReference.to_string_reference(
+ name=trigger_type_api.name, pack=trigger_type_api.pack
+ )
existing_trigger_type_db = get_trigger_type_db(ref)
if existing_trigger_type_db:
@@ -425,8 +460,10 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug
trigger_type_api.id = existing_trigger_type_db.id
try:
- trigger_type_db = TriggerType.add_or_update(trigger_type_api,
- log_not_unique_error_as_debug=log_not_unique_error_as_debug)
+ trigger_type_db = TriggerType.add_or_update(
+ trigger_type_api,
+ log_not_unique_error_as_debug=log_not_unique_error_as_debug,
+ )
except StackStormDBObjectConflictError:
# Operation is idempotent and trigger could have already been created by
# another process. Ignore object already exists because it simply means
@@ -434,26 +471,37 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug
trigger_type_db = get_trigger_type_db(ref)
is_update = True
- extra = {'trigger_type_db': trigger_type_db}
+ extra = {"trigger_type_db": trigger_type_db}
if is_update:
- LOG.audit('TriggerType updated. TriggerType.id=%s' % (trigger_type_db.id), extra=extra)
+ LOG.audit(
+ "TriggerType updated. TriggerType.id=%s" % (trigger_type_db.id), extra=extra
+ )
else:
- LOG.audit('TriggerType created. TriggerType.id=%s' % (trigger_type_db.id), extra=extra)
+ LOG.audit(
+ "TriggerType created. TriggerType.id=%s" % (trigger_type_db.id), extra=extra
+ )
return trigger_type_db
-def _create_trigger_type(pack, name, description=None, payload_schema=None,
- parameters_schema=None, tags=None, metadata_file=None):
+def _create_trigger_type(
+ pack,
+ name,
+ description=None,
+ payload_schema=None,
+ parameters_schema=None,
+ tags=None,
+ metadata_file=None,
+):
trigger_type = {
- 'name': name,
- 'pack': pack,
- 'description': description,
- 'payload_schema': payload_schema,
- 'parameters_schema': parameters_schema,
- 'tags': tags,
- 'metadata_file': metadata_file
+ "name": name,
+ "pack": pack,
+ "description": description,
+ "payload_schema": payload_schema,
+ "parameters_schema": parameters_schema,
+ "tags": tags,
+ "metadata_file": metadata_file,
}
return create_or_update_trigger_type_db(trigger_type=trigger_type)
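The conflict handling in create_or_update_trigger_type_db above encodes an idempotent upsert: if `TriggerType.add_or_update` raises `StackStormDBObjectConflictError`, another process won the race, so the record is simply re-read instead of failing. A minimal sketch of the pattern, using a hypothetical in-memory store and `ConflictError` in place of MongoDB and the StackStorm exception:

```python
# Hypothetical in-memory stand-ins for the database layer.
class ConflictError(Exception):
    pass

store = {}

def insert(ref, doc):
    if ref in store:
        raise ConflictError(ref)
    store[ref] = doc
    return doc

def upsert(ref, doc):
    try:
        return insert(ref, doc)
    except ConflictError:
        # Idempotent: the object already exists, so return the stored copy.
        return store[ref]

print(upsert("examples.file_watch", {"pack": "examples", "name": "file_watch"}))
print(upsert("examples.file_watch", {"pack": "examples", "name": "file_watch"}))
```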
@@ -464,11 +512,12 @@ def _validate_trigger_type(trigger_type):
XXX: We need validator objects that define the required and optional fields.
For now, manually check them.
"""
- required_fields = ['name']
+ required_fields = ["name"]
for field in required_fields:
if field not in trigger_type:
- raise TriggerTypeRegistrationException('Invalid trigger type. Missing field "%s"' %
- (field))
+ raise TriggerTypeRegistrationException(
+ 'Invalid trigger type. Missing field "%s"' % (field)
+ )
def _create_trigger(trigger_type):
@@ -476,37 +525,46 @@ def _create_trigger(trigger_type):
:param trigger_type: TriggerType db object.
:type trigger_type: :class:`TriggerTypeDB`
"""
- if hasattr(trigger_type, 'parameters_schema') and not trigger_type['parameters_schema']:
+ if (
+ hasattr(trigger_type, "parameters_schema")
+ and not trigger_type["parameters_schema"]
+ ):
trigger_dict = {
- 'name': trigger_type.name,
- 'pack': trigger_type.pack,
- 'type': trigger_type.get_reference().ref
+ "name": trigger_type.name,
+ "pack": trigger_type.pack,
+ "type": trigger_type.get_reference().ref,
}
try:
return create_or_update_trigger_db(trigger=trigger_dict)
except:
- LOG.exception('Validation failed for Trigger=%s.', trigger_dict)
+ LOG.exception("Validation failed for Trigger=%s.", trigger_dict)
raise TriggerTypeRegistrationException(
- 'Unable to create Trigger for TriggerType=%s.' % trigger_type.name)
+ "Unable to create Trigger for TriggerType=%s." % trigger_type.name
+ )
else:
- LOG.debug('Won\'t create Trigger object as TriggerType %s expects ' +
- 'parameters.', trigger_type)
+ LOG.debug(
+ "Won't create Trigger object as TriggerType %s expects " + "parameters.",
+ trigger_type,
+ )
return None
def _add_trigger_models(trigger_type):
- pack = trigger_type['pack']
- description = trigger_type['description'] if 'description' in trigger_type else ''
- payload_schema = trigger_type['payload_schema'] if 'payload_schema' in trigger_type else {}
- parameters_schema = trigger_type['parameters_schema'] \
- if 'parameters_schema' in trigger_type else {}
- tags = trigger_type.get('tags', [])
- metadata_file = trigger_type.get('metadata_file', None)
+ pack = trigger_type["pack"]
+ description = trigger_type["description"] if "description" in trigger_type else ""
+ payload_schema = (
+ trigger_type["payload_schema"] if "payload_schema" in trigger_type else {}
+ )
+ parameters_schema = (
+ trigger_type["parameters_schema"] if "parameters_schema" in trigger_type else {}
+ )
+ tags = trigger_type.get("tags", [])
+ metadata_file = trigger_type.get("metadata_file", None)
trigger_type = _create_trigger_type(
pack=pack,
- name=trigger_type['name'],
+ name=trigger_type["name"],
description=description,
payload_schema=payload_schema,
parameters_schema=parameters_schema,
@@ -526,8 +584,13 @@ def add_trigger_models(trigger_types):
:rtype: ``list`` of ``tuple`` (trigger_type, trigger)
"""
- [r for r in (_validate_trigger_type(trigger_type)
- for trigger_type in trigger_types) if r is not None]
+ [
+ r
+ for r in (
+ _validate_trigger_type(trigger_type) for trigger_type in trigger_types
+ )
+ if r is not None
+ ]
result = []
for trigger_type in trigger_types:
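create_shadow_trigger above only materializes a Trigger for TriggerTypes without a parameters_schema. A minimal sketch of that rule, assuming plain dicts instead of the real TriggerTypeDB model; `make_shadow_trigger` is a hypothetical stand-in:

```python
# Minimal sketch of the shadow-trigger rule, assuming plain dicts.
def make_shadow_trigger(trigger_type):
    # TriggerTypes that declare a parameters_schema cannot have a generic
    # (parameterless) trigger, so no shadow trigger is created for them.
    if trigger_type.get("parameters_schema"):
        return None
    return {
        "name": trigger_type["name"],
        "pack": trigger_type["pack"],
        "type": "%s.%s" % (trigger_type["pack"], trigger_type["name"]),
        "parameters": {},
    }

print(make_shadow_trigger({"name": "file_watch", "pack": "examples"}))
print(make_shadow_trigger({"name": "webhook", "pack": "core", "parameters_schema": {"url": {}}}))
```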
diff --git a/st2common/st2common/services/triggerwatcher.py b/st2common/st2common/services/triggerwatcher.py
index 4830c349f4..b82a46043a 100644
--- a/st2common/st2common/services/triggerwatcher.py
+++ b/st2common/st2common/services/triggerwatcher.py
@@ -33,8 +33,15 @@ class TriggerWatcher(ConsumerMixin):
sleep_interval = 0 # sleep to co-operatively yield after processing each message
- def __init__(self, create_handler, update_handler, delete_handler,
- trigger_types=None, queue_suffix=None, exclusive=False):
+ def __init__(
+ self,
+ create_handler,
+ update_handler,
+ delete_handler,
+ trigger_types=None,
+ queue_suffix=None,
+ exclusive=False,
+ ):
"""
:param create_handler: Function which is called on TriggerDB create event.
:type create_handler: ``callable``
@@ -69,39 +76,49 @@ def __init__(self, create_handler, update_handler, delete_handler,
self._handlers = {
publishers.CREATE_RK: create_handler,
publishers.UPDATE_RK: update_handler,
- publishers.DELETE_RK: delete_handler
+ publishers.DELETE_RK: delete_handler,
}
def get_consumers(self, Consumer, channel):
- return [Consumer(queues=[self._trigger_watch_q],
- accept=['pickle'],
- callbacks=[self.process_task])]
+ return [
+ Consumer(
+ queues=[self._trigger_watch_q],
+ accept=["pickle"],
+ callbacks=[self.process_task],
+ )
+ ]
def process_task(self, body, message):
- LOG.debug('process_task')
- LOG.debug(' body: %s', body)
- LOG.debug(' message.properties: %s', message.properties)
- LOG.debug(' message.delivery_info: %s', message.delivery_info)
+ LOG.debug("process_task")
+ LOG.debug(" body: %s", body)
+ LOG.debug(" message.properties: %s", message.properties)
+ LOG.debug(" message.delivery_info: %s", message.delivery_info)
- routing_key = message.delivery_info.get('routing_key', '')
+ routing_key = message.delivery_info.get("routing_key", "")
handler = self._handlers.get(routing_key, None)
try:
if not handler:
- LOG.debug('Skipping message %s as no handler was found.', message)
+ LOG.debug("Skipping message %s as no handler was found.", message)
return
- trigger_type = getattr(body, 'type', None)
+ trigger_type = getattr(body, "type", None)
if self._trigger_types and trigger_type not in self._trigger_types:
- LOG.debug('Skipping message %s since trigger_type doesn\'t match (type=%s)',
- message, trigger_type)
+ LOG.debug(
+ "Skipping message %s since trigger_type doesn't match (type=%s)",
+ message,
+ trigger_type,
+ )
return
try:
handler(body)
except Exception as e:
- LOG.exception('Handling failed. Message body: %s. Exception: %s',
- body, six.text_type(e))
+ LOG.exception(
+ "Handling failed. Message body: %s. Exception: %s",
+ body,
+ six.text_type(e),
+ )
finally:
message.ack()
@@ -113,7 +130,7 @@ def start(self):
self._updates_thread = concurrency.spawn(self.run)
self._load_thread = concurrency.spawn(self._load_triggers_from_db)
except:
- LOG.exception('Failed to start watcher.')
+ LOG.exception("Failed to start watcher.")
self.connection.release()
def stop(self):
@@ -128,8 +145,9 @@ def stop(self):
# waiting for a message on the queue.
def on_consume_end(self, connection, channel):
- super(TriggerWatcher, self).on_consume_end(connection=connection,
- channel=channel)
+ super(TriggerWatcher, self).on_consume_end(
+ connection=connection, channel=channel
+ )
concurrency.sleep(seconds=self.sleep_interval)
def on_iteration(self):
@@ -139,13 +157,16 @@ def on_iteration(self):
def _load_triggers_from_db(self):
for trigger_type in self._trigger_types:
for trigger in Trigger.query(type=trigger_type):
- LOG.debug('Found existing trigger: %s in db.' % trigger)
+ LOG.debug("Found existing trigger: %s in db." % trigger)
self._handlers[publishers.CREATE_RK](trigger)
@staticmethod
def _get_queue(queue_suffix, exclusive):
- queue_name = queue_utils.get_queue_name(queue_name_base='st2.trigger.watch',
- queue_name_suffix=queue_suffix,
- add_random_uuid_to_suffix=True
- )
- return reactor.get_trigger_cud_queue(queue_name, routing_key='#', exclusive=exclusive)
+ queue_name = queue_utils.get_queue_name(
+ queue_name_base="st2.trigger.watch",
+ queue_name_suffix=queue_suffix,
+ add_random_uuid_to_suffix=True,
+ )
+ return reactor.get_trigger_cud_queue(
+ queue_name, routing_key="#", exclusive=exclusive
+ )
diff --git a/st2common/st2common/services/workflows.py b/st2common/st2common/services/workflows.py
index db64c681b1..16ddcb5a02 100644
--- a/st2common/st2common/services/workflows.py
+++ b/st2common/st2common/services/workflows.py
@@ -54,59 +54,61 @@
LOG = logging.getLogger(__name__)
LOG_FUNCTIONS = {
- 'audit': LOG.audit,
- 'debug': LOG.debug,
- 'info': LOG.info,
- 'warning': LOG.warning,
- 'error': LOG.error,
- 'critical': LOG.critical,
+ "audit": LOG.audit,
+ "debug": LOG.debug,
+ "info": LOG.info,
+ "warning": LOG.warning,
+ "error": LOG.error,
+ "critical": LOG.critical,
}
-def update_progress(wf_ex_db, message, severity='info', log=True, stream=True):
+def update_progress(wf_ex_db, message, severity="info", log=True, stream=True):
if not wf_ex_db:
return
if log and severity in LOG_FUNCTIONS:
- LOG_FUNCTIONS[severity]('[%s] %s', wf_ex_db.context['st2']['action_execution_id'], message)
+ LOG_FUNCTIONS[severity](
+ "[%s] %s", wf_ex_db.context["st2"]["action_execution_id"], message
+ )
if stream:
ac_svc.store_execution_output_data_ex(
- wf_ex_db.context['st2']['action_execution_id'],
- wf_ex_db.context['st2']['action'],
- wf_ex_db.context['st2']['runner'],
- '%s\n' % message,
+ wf_ex_db.context["st2"]["action_execution_id"],
+ wf_ex_db.context["st2"]["action"],
+ wf_ex_db.context["st2"]["runner"],
+ "%s\n" % message,
)
def is_action_execution_under_workflow_context(ac_ex_db):
# The action execution is executed under the context of a workflow
# if it contains the orquesta key in its context dictionary.
- return ac_ex_db.context and 'orquesta' in ac_ex_db.context
+ return ac_ex_db.context and "orquesta" in ac_ex_db.context
def format_inspection_result(result):
errors = []
categories = {
- 'contents': 'content',
- 'context': 'context',
- 'expressions': 'expression',
- 'semantics': 'semantic',
- 'syntax': 'syntax'
+ "contents": "content",
+ "context": "context",
+ "expressions": "expression",
+ "semantics": "semantic",
+ "syntax": "syntax",
}
# For context and expression errors, rename the attribute from type to language.
- for category in ['context', 'expressions']:
+ for category in ["context", "expressions"]:
for entry in result.get(category, []):
- if 'language' not in entry:
- entry['language'] = entry['type']
- del entry['type']
+ if "language" not in entry:
+ entry["language"] = entry["type"]
+ del entry["type"]
# For all categories, put the category value in the type attribute.
for category, entries in six.iteritems(result):
for entry in entries:
- entry['type'] = categories[category]
+ entry["type"] = categories[category]
errors.append(entry)
return errors
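format_inspection_result above flattens orquesta's categorized inspection output: for context and expression errors the `type` attribute is renamed to `language`, and every entry then gets its category recorded as `type`. A worked sketch on hypothetical sample data, using plain `.items()` where the code above uses `six.iteritems`:

```python
# Hypothetical inspection result, shaped like the input to format_inspection_result.
result = {
    "expressions": [{"type": "yaql", "message": "bad expression"}],
    "syntax": [{"message": "unexpected property"}],
}

categories = {
    "contents": "content",
    "context": "context",
    "expressions": "expression",
    "semantics": "semantic",
    "syntax": "syntax",
}

errors = []

# Rename type -> language for context and expression errors.
for category in ["context", "expressions"]:
    for entry in result.get(category, []):
        if "language" not in entry:
            entry["language"] = entry["type"]
            del entry["type"]

# Record each entry's category in its type attribute.
for category, entries in result.items():
    for entry in entries:
        entry["type"] = categories[category]
        errors.append(entry)

print(errors)
# [{'message': 'bad expression', 'language': 'yaql', 'type': 'expression'},
#  {'message': 'unexpected property', 'type': 'syntax'}]
```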
@@ -121,7 +123,7 @@ def inspect(wf_spec, st2_ctx, raise_exception=True):
errors += inspect_task_contents(wf_spec)
# Sort the list of errors by type and path.
- errors = sorted(errors, key=lambda e: (e['type'], e['schema_path']))
+ errors = sorted(errors, key=lambda e: (e["type"], e["schema_path"]))
if errors and raise_exception:
raise orquesta_exc.WorkflowInspectionError(errors)
@@ -131,10 +133,10 @@ def inspect(wf_spec, st2_ctx, raise_exception=True):
def inspect_task_contents(wf_spec):
result = []
- spec_path = 'tasks'
- schema_path = 'properties.tasks.patternProperties.^\\w+$'
- action_schema_path = schema_path + '.properties.action'
- action_input_schema_path = schema_path + '.properties.input'
+ spec_path = "tasks"
+ schema_path = "properties.tasks.patternProperties.^\\w+$"
+ action_schema_path = schema_path + ".properties.action"
+ action_input_schema_path = schema_path + ".properties.input"
def is_action_an_expression(action):
if isinstance(action, six.string_types):
@@ -143,9 +145,9 @@ def is_action_an_expression(action):
return True
for task_name, task_spec in six.iteritems(wf_spec.tasks):
- action_ref = getattr(task_spec, 'action', None)
- action_spec_path = spec_path + '.' + task_name + '.action'
- action_input_spec_path = spec_path + '.' + task_name + '.input'
+ action_ref = getattr(task_spec, "action", None)
+ action_spec_path = spec_path + "." + task_name + ".action"
+ action_input_spec_path = spec_path + "." + task_name + ".input"
# Move on if action is empty or an expression.
if not action_ref or is_action_an_expression(action_ref):
@@ -154,10 +156,11 @@ def is_action_an_expression(action):
# Check that the format of the action is a valid resource reference.
if not sys_models.ResourceReference.is_resource_reference(action_ref):
entry = {
- 'type': 'content',
- 'message': 'The action reference "%s" is not formatted correctly.' % action_ref,
- 'spec_path': action_spec_path,
- 'schema_path': action_schema_path
+ "type": "content",
+ "message": 'The action reference "%s" is not formatted correctly.'
+ % action_ref,
+ "spec_path": action_spec_path,
+ "schema_path": action_schema_path,
}
result.append(entry)
@@ -166,31 +169,37 @@ def is_action_an_expression(action):
# Check that the action is registered in the database.
if not action_utils.get_action_by_ref(ref=action_ref):
entry = {
- 'type': 'content',
- 'message': 'The action "%s" is not registered in the database.' % action_ref,
- 'spec_path': action_spec_path,
- 'schema_path': action_schema_path
+ "type": "content",
+ "message": 'The action "%s" is not registered in the database.'
+ % action_ref,
+ "spec_path": action_spec_path,
+ "schema_path": action_schema_path,
}
result.append(entry)
continue
# Check the action parameters.
- params = getattr(task_spec, 'input', None) or {}
+ params = getattr(task_spec, "input", None) or {}
if params and not isinstance(params, dict):
continue
- requires, unexpected = action_param_utils.validate_action_parameters(action_ref, params)
+ requires, unexpected = action_param_utils.validate_action_parameters(
+ action_ref, params
+ )
for param in requires:
- message = 'Action "%s" is missing required input "%s".' % (action_ref, param)
+ message = 'Action "%s" is missing required input "%s".' % (
+ action_ref,
+ param,
+ )
entry = {
- 'type': 'content',
- 'message': message,
- 'spec_path': action_input_spec_path,
- 'schema_path': action_input_schema_path
+ "type": "content",
+ "message": message,
+ "spec_path": action_input_spec_path,
+ "schema_path": action_input_schema_path,
}
result.append(entry)
@@ -199,10 +208,10 @@ def is_action_an_expression(action):
message = 'Action "%s" has unexpected input "%s".' % (action_ref, param)
entry = {
- 'type': 'content',
- 'message': message,
- 'spec_path': action_input_spec_path + '.' + param,
- 'schema_path': action_input_schema_path + '.patternProperties.^\\w+$'
+ "type": "content",
+ "message": message,
+ "spec_path": action_input_spec_path + "." + param,
+ "schema_path": action_input_schema_path + ".patternProperties.^\\w+$",
}
result.append(entry)
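The first content check above relies on the "pack.name" resource reference format. A minimal stand-in for `sys_models.ResourceReference.is_resource_reference`, with the semantics assumed here to be two non-empty dot-separated parts:

```python
# Hypothetical stand-in for ResourceReference.is_resource_reference; the assumed
# rule is a "pack.name" string with both parts non-empty.
def is_resource_reference(ref):
    parts = ref.split(".", 1) if isinstance(ref, str) else []
    return len(parts) == 2 and all(parts)

print(is_resource_reference("core.local"))  # True
print(is_resource_reference("local"))       # False -> flagged as a content error
```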
@@ -211,35 +220,35 @@ def is_action_an_expression(action):
def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None):
- LOG.info('[%s] Processing action execution request for workflow.', str(ac_ex_db.id))
+ LOG.info("[%s] Processing action execution request for workflow.", str(ac_ex_db.id))
# Load workflow definition into workflow spec model.
- spec_module = specs_loader.get_spec_module('native')
+ spec_module = specs_loader.get_spec_module("native")
wf_spec = spec_module.instantiate(wf_def)
# Inspect the workflow spec.
inspect(wf_spec, st2_ctx, raise_exception=True)
# Identify the action to execute.
- action_db = action_utils.get_action_by_ref(ref=ac_ex_db.action['ref'])
+ action_db = action_utils.get_action_by_ref(ref=ac_ex_db.action["ref"])
if not action_db:
- error = 'Unable to find action "%s".' % ac_ex_db.action['ref']
+ error = 'Unable to find action "%s".' % ac_ex_db.action["ref"]
raise ac_exc.InvalidActionReferencedException(error)
# Identify the runner for the action.
- runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type['name'])
+ runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"])
# Render action execution parameters.
runner_params, action_params = param_utils.render_final_params(
runner_type_db.runner_parameters,
action_db.parameters,
ac_ex_db.parameters,
- ac_ex_db.context
+ ac_ex_db.context,
)
# Instantiate the workflow conductor.
- conductor_params = {'inputs': action_params, 'context': st2_ctx}
+ conductor_params = {"inputs": action_params, "context": st2_ctx}
conductor = conducting.WorkflowConductor(wf_spec, **conductor_params)
# Serialize the conductor which initializes some internal values.
@@ -248,33 +257,32 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None):
# Create a record for workflow execution.
wf_ex_db = wf_db_models.WorkflowExecutionDB(
action_execution=str(ac_ex_db.id),
- spec=data['spec'],
- graph=data['graph'],
- input=data['input'],
- context=data['context'],
- state=data['state'],
- status=data['state']['status'],
- output=data['output'],
- errors=data['errors']
+ spec=data["spec"],
+ graph=data["graph"],
+ input=data["input"],
+ context=data["context"],
+ state=data["state"],
+ status=data["state"]["status"],
+ output=data["output"],
+ errors=data["errors"],
)
# Check that the tasks listed in the notify parameter exist in the workflow spec.
- if runner_params.get('notify'):
- invalid_tasks = list(set(runner_params.get('notify')) - set(wf_spec.tasks.keys()))
+ if runner_params.get("notify"):
+ invalid_tasks = list(
+ set(runner_params.get("notify")) - set(wf_spec.tasks.keys())
+ )
if invalid_tasks:
raise wf_exc.WorkflowExecutionException(
- 'The following tasks in the notify parameter do not exist '
- 'in the workflow definition: %s.' % ', '.join(invalid_tasks)
+ "The following tasks in the notify parameter do not exist "
+ "in the workflow definition: %s." % ", ".join(invalid_tasks)
)
# Write notify instruction to record.
if notify_cfg:
# Set up the notify instruction in the workflow execution record.
- wf_ex_db.notify = {
- 'config': notify_cfg,
- 'tasks': runner_params.get('notify')
- }
+ wf_ex_db.notify = {"config": notify_cfg, "tasks": runner_params.get("notify")}
# Insert new record into the database and do not publish to the message bus yet.
wf_ex_db = wf_db_access.WorkflowExecution.insert(wf_ex_db, publish=False)
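The notify check above is a set difference between the task names passed in the notify runner parameter and the tasks defined in the workflow spec. A minimal sketch on hypothetical data (this one deliberately fails):

```python
# Hypothetical inputs: the notify runner parameter and the spec's task names.
notify = ["task1", "task9"]
spec_tasks = {"task1", "task2", "task3"}

invalid_tasks = list(set(notify) - spec_tasks)
if invalid_tasks:
    # Mirrors the WorkflowExecutionException raised above; ValueError is used
    # here only to keep the sketch self-contained.
    raise ValueError(
        "The following tasks in the notify parameter do not exist "
        "in the workflow definition: %s." % ", ".join(invalid_tasks)
    )
```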
@@ -286,12 +294,12 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None):
# Set the initial workflow status to requested.
conductor.request_workflow_status(statuses.REQUESTED)
data = conductor.serialize()
- wf_ex_db.state = data['state']
- wf_ex_db.status = data['state']['status']
+ wf_ex_db.state = data["state"]
+ wf_ex_db.status = data["state"]["status"]
# Put the ID of the workflow execution record in the context.
- wf_ex_db.context['st2']['workflow_execution_id'] = str(wf_ex_db.id)
- wf_ex_db.state['contexts'][0]['st2']['workflow_execution_id'] = str(wf_ex_db.id)
+ wf_ex_db.context["st2"]["workflow_execution_id"] = str(wf_ex_db.id)
+ wf_ex_db.state["contexts"][0]["st2"]["workflow_execution_id"] = str(wf_ex_db.id)
# Update the workflow execution record.
wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False)
@@ -308,15 +316,17 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def request_pause(ac_ex_db):
wf_ac_ex_id = str(ac_ex_db.id)
- LOG.info('[%s] Processing pause request for workflow.', wf_ac_ex_id)
+ LOG.info("[%s] Processing pause request for workflow.", wf_ac_ex_id)
wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
@@ -343,7 +353,7 @@ def request_pause(ac_ex_db):
wf_ex_db.state = conductor.workflow_state.serialize()
wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False)
- LOG.info('[%s] Completed processing pause request for workflow.', wf_ac_ex_id)
+ LOG.info("[%s] Completed processing pause request for workflow.", wf_ac_ex_id)
return wf_ex_db
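Each request_* entry point above is wrapped in two stacked @retrying.retry decorators: the inner one retries connection errors for a bounded stop_max_delay, the outer one retries transient database errors with a fixed wait plus jitter. A minimal sketch of the stacking, with hypothetical predicates standing in for wf_exc.retry_on_transient_db_errors and wf_exc.retry_on_connection_errors:

```python
import random

import retrying  # the same library used by the decorators above


def is_transient_db_error(exc):
    # Hypothetical predicate; stands in for wf_exc.retry_on_transient_db_errors.
    return isinstance(exc, TimeoutError)


def is_connection_error(exc):
    # Hypothetical predicate; stands in for wf_exc.retry_on_connection_errors.
    return isinstance(exc, ConnectionError)


@retrying.retry(
    retry_on_exception=is_transient_db_error, wait_fixed=100, wait_jitter_max=50
)
@retrying.retry(
    retry_on_exception=is_connection_error,
    stop_max_delay=2000,
    wait_fixed=100,
    wait_jitter_max=50,
)
def flaky_request():
    # Fails transiently about half the time; the outer decorator retries it
    # because the inner one re-raises exceptions its predicate rejects.
    if random.random() < 0.5:
        raise TimeoutError("transient db hiccup")
    return "ok"


print(flaky_request())
```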
@@ -351,15 +361,17 @@ def request_pause(ac_ex_db):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def request_resume(ac_ex_db):
wf_ac_ex_id = str(ac_ex_db.id)
- LOG.info('[%s] Processing resume request for workflow.', wf_ac_ex_id)
+ LOG.info("[%s] Processing resume request for workflow.", wf_ac_ex_id)
wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
@@ -375,7 +387,9 @@ def request_resume(ac_ex_db):
raise wf_exc.WorkflowExecutionIsCompletedException(str(wf_ex_db.id))
if wf_ex_db.status in statuses.RUNNING_STATUSES:
- msg = '[%s] Workflow execution "%s" is not resumed because it is already active.'
+ msg = (
+ '[%s] Workflow execution "%s" is not resumed because it is already active.'
+ )
LOG.info(msg, wf_ac_ex_id, str(wf_ex_db.id))
return
@@ -385,7 +399,9 @@ def request_resume(ac_ex_db):
raise wf_exc.WorkflowExecutionIsCompletedException(str(wf_ex_db.id))
if conductor.get_workflow_status() in statuses.RUNNING_STATUSES:
- msg = '[%s] Workflow execution "%s" is not resumed because it is already active.'
+ msg = (
+ '[%s] Workflow execution "%s" is not resumed because it is already active.'
+ )
LOG.info(msg, wf_ac_ex_id, str(wf_ex_db.id))
return
@@ -400,7 +416,7 @@ def request_resume(ac_ex_db):
# Publish status change.
wf_db_access.WorkflowExecution.publish_status(wf_ex_db)
- LOG.info('[%s] Completed processing resume request for workflow.', wf_ac_ex_id)
+ LOG.info("[%s] Completed processing resume request for workflow.", wf_ac_ex_id)
return wf_ex_db
@@ -408,15 +424,17 @@ def request_resume(ac_ex_db):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def request_cancellation(ac_ex_db):
wf_ac_ex_id = str(ac_ex_db.id)
- LOG.info('[%s] Processing cancelation request for workflow.', wf_ac_ex_id)
+ LOG.info("[%s] Processing cancelation request for workflow.", wf_ac_ex_id)
wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
@@ -446,13 +464,16 @@ def request_cancellation(ac_ex_db):
# Cascade the cancellation up to the root of the workflow.
root_ac_ex_db = ac_svc.get_root_execution(ac_ex_db)
- if root_ac_ex_db != ac_ex_db and root_ac_ex_db.status not in ac_const.LIVEACTION_CANCEL_STATES:
- LOG.info('[%s] Cascading cancelation request to parent workflow.', wf_ac_ex_id)
- root_lv_ac_db = lv_db_access.LiveAction.get(id=root_ac_ex_db.liveaction['id'])
+ if (
+ root_ac_ex_db != ac_ex_db
+ and root_ac_ex_db.status not in ac_const.LIVEACTION_CANCEL_STATES
+ ):
+ LOG.info("[%s] Cascading cancelation request to parent workflow.", wf_ac_ex_id)
+ root_lv_ac_db = lv_db_access.LiveAction.get(id=root_ac_ex_db.liveaction["id"])
ac_svc.request_cancellation(root_lv_ac_db, None)
- LOG.debug('[%s] %s', wf_ac_ex_id, conductor.serialize())
- LOG.info('[%s] Completed processing cancelation request for workflow.', wf_ac_ex_id)
+ LOG.debug("[%s] %s", wf_ac_ex_id, conductor.serialize())
+ LOG.info("[%s] Completed processing cancelation request for workflow.", wf_ac_ex_id)
return wf_ex_db
@@ -460,20 +481,22 @@ def request_cancellation(ac_ex_db):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def request_rerun(ac_ex_db, st2_ctx, options=None):
wf_ac_ex_id = str(ac_ex_db.id)
- LOG.info('[%s] Processing rerun request for workflow.', wf_ac_ex_id)
+ LOG.info("[%s] Processing rerun request for workflow.", wf_ac_ex_id)
- wf_ex_id = st2_ctx.get('workflow_execution_id')
+ wf_ex_id = st2_ctx.get("workflow_execution_id")
if not wf_ex_id:
- msg = 'Unable to rerun workflow execution because workflow_execution_id is not provided.'
+ msg = "Unable to rerun workflow execution because workflow_execution_id is not provided."
raise wf_exc.WorkflowExecutionRerunException(msg)
try:
@@ -487,8 +510,8 @@ def request_rerun(ac_ex_db, st2_ctx, options=None):
raise wf_exc.WorkflowExecutionRerunException(msg % wf_ex_id)
wf_ex_db.action_execution = wf_ac_ex_id
- wf_ex_db.context['st2'] = st2_ctx['st2']
- wf_ex_db.context['parent'] = st2_ctx['parent']
+ wf_ex_db.context["st2"] = st2_ctx["st2"]
+ wf_ex_db.context["parent"] = st2_ctx["parent"]
conductor = deserialize_conductor(wf_ex_db)
try:
@@ -497,26 +520,29 @@ def request_rerun(ac_ex_db, st2_ctx, options=None):
if options:
task_requests = []
- task_names = options.get('tasks', [])
- task_resets = options.get('reset', [])
+ task_names = options.get("tasks", [])
+ task_resets = options.get("reset", [])
for task_name in task_names:
reset_items = task_name in task_resets
- task_state_entries = conductor.workflow_state.get_tasks(task_id=task_name)
+ task_state_entries = conductor.workflow_state.get_tasks(
+ task_id=task_name
+ )
if not task_state_entries:
problems.append(task_name)
continue
for _, task_state_entry in task_state_entries:
- route = task_state_entry['route']
+ route = task_state_entry["route"]
req = orquesta_reqs.TaskRerunRequest.new(
- task_name, route, reset_items=reset_items)
+ task_name, route, reset_items=reset_items
+ )
task_requests.append(req)
if problems:
- msg = 'Unable to rerun workflow because one or more tasks is not found: %s'
- raise Exception(msg % ','.join(problems))
+ msg = "Unable to rerun workflow because one or more tasks is not found: %s"
+ raise Exception(msg % ",".join(problems))
conductor.request_workflow_rerun(task_requests=task_requests)
except Exception as e:
@@ -527,10 +553,10 @@ def request_rerun(ac_ex_db, st2_ctx, options=None):
raise wf_exc.WorkflowExecutionRerunException(msg % wf_ex_id)
data = conductor.serialize()
- wf_ex_db.status = data['state']['status']
- wf_ex_db.spec = data['spec']
- wf_ex_db.graph = data['graph']
- wf_ex_db.state = data['state']
+ wf_ex_db.status = data["state"]["status"]
+ wf_ex_db.spec = data["spec"]
+ wf_ex_db.graph = data["graph"]
+ wf_ex_db.state = data["state"]
wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id))
@@ -542,12 +568,12 @@ def request_rerun(ac_ex_db, st2_ctx, options=None):
def request_task_execution(wf_ex_db, st2_ctx, task_ex_req):
- task_id = task_ex_req['id']
- task_route = task_ex_req['route']
- task_spec = task_ex_req['spec']
- task_ctx = task_ex_req['ctx']
- task_actions = task_ex_req['actions']
- task_delay = task_ex_req.get('delay')
+ task_id = task_ex_req["id"]
+ task_route = task_ex_req["route"]
+ task_spec = task_ex_req["spec"]
+ task_ctx = task_ex_req["ctx"]
+ task_actions = task_ex_req["actions"]
+ task_delay = task_ex_req.get("delay")
msg = 'Processing task execution request for task "%s", route "%s".'
update_progress(wf_ex_db, msg % (task_id, str(task_route)), stream=False)
@@ -557,11 +583,14 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req):
workflow_execution=str(wf_ex_db.id),
task_id=task_id,
task_route=task_route,
- order_by=['-start_timestamp']
+ order_by=["-start_timestamp"],
)
- if (len(task_ex_dbs) > 0 and task_ex_dbs[0].itemized and
- task_ex_dbs[0].status == ac_const.LIVEACTION_STATUS_RUNNING):
+ if (
+ len(task_ex_dbs) > 0
+ and task_ex_dbs[0].itemized
+ and task_ex_dbs[0].status == ac_const.LIVEACTION_STATUS_RUNNING
+ ):
task_ex_db = task_ex_dbs[0]
task_ex_id = str(task_ex_db.id)
msg = 'Task execution "%s" retrieved for task "%s", route "%s".'
@@ -576,15 +605,15 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req):
task_spec=task_spec.serialize(),
delay=task_delay,
itemized=task_spec.has_items(),
- items_count=task_ex_req.get('items_count'),
- items_concurrency=task_ex_req.get('concurrency'),
+ items_count=task_ex_req.get("items_count"),
+ items_concurrency=task_ex_req.get("concurrency"),
context=task_ctx,
- status=statuses.REQUESTED
+ status=statuses.REQUESTED,
)
# Prepare the result format for itemized task execution.
if task_ex_db.itemized:
- task_ex_db.result = {'items': [None] * task_ex_db.items_count}
+ task_ex_db.result = {"items": [None] * task_ex_db.items_count}
# Insert new record into the database.
task_ex_db = wf_db_access.TaskExecution.insert(task_ex_db, publish=False)
@@ -627,26 +656,35 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req):
# Request action execution for each action in the task request.
for ac_ex_req in task_actions:
- ac_ex_delay = eval_action_execution_delay(task_ex_req, ac_ex_req, task_ex_db.itemized)
- request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=ac_ex_delay)
+ ac_ex_delay = eval_action_execution_delay(
+ task_ex_req, ac_ex_req, task_ex_db.itemized
+ )
+ request_action_execution(
+ wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=ac_ex_delay
+ )
task_ex_db = wf_db_access.TaskExecution.get_by_id(str(task_ex_db.id))
except Exception as e:
msg = 'Failed action execution(s) for task "%s", route "%s".'
msg = msg % (task_id, str(task_route))
LOG.exception(msg)
- msg = '%s %s: %s' % (msg, type(e).__name__, six.text_type(e))
- update_progress(wf_ex_db, msg, severity='error', log=False)
- msg = '%s: %s' % (type(e).__name__, six.text_type(e))
- error = {'type': 'error', 'message': msg, 'task_id': task_id, 'route': task_route}
- update_task_execution(str(task_ex_db.id), statuses.FAILED, {'errors': [error]})
+ msg = "%s %s: %s" % (msg, type(e).__name__, six.text_type(e))
+ update_progress(wf_ex_db, msg, severity="error", log=False)
+ msg = "%s: %s" % (type(e).__name__, six.text_type(e))
+ error = {
+ "type": "error",
+ "message": msg,
+ "task_id": task_id,
+ "route": task_route,
+ }
+ update_task_execution(str(task_ex_db.id), statuses.FAILED, {"errors": [error]})
raise e
return task_ex_db
def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False):
- task_ex_delay = task_ex_req.get('delay')
- items_concurrency = task_ex_req.get('concurrency')
+ task_ex_delay = task_ex_req.get("delay")
+ items_concurrency = task_ex_req.get("concurrency")
# If there is a task delay and the task has no items, return the delay value.
if task_ex_delay and not itemized:
@@ -658,7 +696,7 @@ def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False):
# If there is a task delay and the task has items with concurrency, return the
# delay value only if the item id is less than the concurrency value.
- if task_ex_delay and itemized and ac_ex_req['item_id'] < items_concurrency:
+ if task_ex_delay and itemized and ac_ex_req["item_id"] < items_concurrency:
return task_ex_delay
return None
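The delay rules in eval_action_execution_delay above are easy to see with concrete values. A worked sketch using plain dicts; note the middle branch (items but no concurrency) is elided by the hunk above and is assumed here to return the delay for every item:

```python
# Worked sketch of the delay rules in eval_action_execution_delay.
def eval_delay(task_req, ac_req, itemized=False):
    delay = task_req.get("delay")
    concurrency = task_req.get("concurrency")
    if delay and not itemized:
        return delay
    if delay and itemized and not concurrency:
        return delay  # assumption: without concurrency, all items inherit the delay
    if delay and itemized and ac_req["item_id"] < concurrency:
        return delay
    return None

print(eval_delay({"delay": 5}, {}))                                      # 5
print(eval_delay({"delay": 5, "concurrency": 2}, {"item_id": 1}, True))  # 5
print(eval_delay({"delay": 5, "concurrency": 2}, {"item_id": 3}, True))  # None
```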
@@ -667,20 +705,22 @@ def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=None):
- action_ref = ac_ex_req['action']
- action_input = ac_ex_req['input']
- item_id = ac_ex_req.get('item_id')
+ action_ref = ac_ex_req["action"]
+ action_input = ac_ex_req["input"]
+ item_id = ac_ex_req.get("item_id")
# If the task is with items and item_id is not provided, raise exception.
if task_ex_db.itemized and item_id is None:
- msg = 'Unable to request action execution. Identifier for the item is not provided.'
+ msg = "Unable to request action execution. Identifier for the item is not provided."
raise Exception(msg)
# Identify the action to execute.
@@ -691,40 +731,40 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non
raise ac_exc.InvalidActionReferencedException(error)
# Identify the runner for the action.
- runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type['name'])
+ runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"])
# Identify action pack name
- pack_name = action_ref.split('.')[0] if action_ref else st2_ctx.get('pack')
+ pack_name = action_ref.split(".")[0] if action_ref else st2_ctx.get("pack")
# Set context for the action execution.
ac_ex_ctx = {
- 'pack': pack_name,
- 'user': st2_ctx.get('user'),
- 'parent': st2_ctx,
- 'orquesta': {
- 'workflow_execution_id': str(wf_ex_db.id),
- 'task_execution_id': str(task_ex_db.id),
- 'task_name': task_ex_db.task_name,
- 'task_id': task_ex_db.task_id,
- 'task_route': task_ex_db.task_route
- }
+ "pack": pack_name,
+ "user": st2_ctx.get("user"),
+ "parent": st2_ctx,
+ "orquesta": {
+ "workflow_execution_id": str(wf_ex_db.id),
+ "task_execution_id": str(task_ex_db.id),
+ "task_name": task_ex_db.task_name,
+ "task_id": task_ex_db.task_id,
+ "task_route": task_ex_db.task_route,
+ },
}
- if st2_ctx.get('api_user'):
- ac_ex_ctx['api_user'] = st2_ctx.get('api_user')
+ if st2_ctx.get("api_user"):
+ ac_ex_ctx["api_user"] = st2_ctx.get("api_user")
- if st2_ctx.get('source_channel'):
- ac_ex_ctx['source_channel'] = st2_ctx.get('source_channel')
+ if st2_ctx.get("source_channel"):
+ ac_ex_ctx["source_channel"] = st2_ctx.get("source_channel")
if item_id is not None:
- ac_ex_ctx['orquesta']['item_id'] = item_id
+ ac_ex_ctx["orquesta"]["item_id"] = item_id
# Render action execution parameters and setup action execution object.
ac_ex_params = param_utils.render_live_params(
runner_type_db.runner_parameters or {},
action_db.parameters or {},
action_input or {},
- ac_ex_ctx
+ ac_ex_ctx,
)
# The delay spec is in seconds and scheduler expects milliseconds.
@@ -738,13 +778,19 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non
task_execution=str(task_ex_db.id),
delay=delay,
context=ac_ex_ctx,
- parameters=ac_ex_params
+ parameters=ac_ex_params,
)
# Set notification if instructed.
- if (wf_ex_db.notify and wf_ex_db.notify.get('config') and
- wf_ex_db.notify.get('tasks') and task_ex_db.task_name in wf_ex_db.notify['tasks']):
- lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(wf_ex_db.notify['config'])
+ if (
+ wf_ex_db.notify
+ and wf_ex_db.notify.get("config")
+ and wf_ex_db.notify.get("tasks")
+ and task_ex_db.task_name in wf_ex_db.notify["tasks"]
+ ):
+ lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(
+ wf_ex_db.notify["config"]
+ )
# Set the task execution to running first otherwise a race can occur
# where the action execution finishes first and the completion handler
@@ -765,13 +811,13 @@ def handle_action_execution_pending(ac_ex_db):
# Check that the action execution is paused.
if ac_ex_db.status != ac_const.LIVEACTION_STATUS_PENDING:
raise Exception(
- 'Unable to handle pending of action execution. The action execution '
+ "Unable to handle pending of action execution. The action execution "
'"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status)
)
# Get related record identifiers.
- wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id']
- task_ex_id = ac_ex_db.context['orquesta']['task_execution_id']
+ wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"]
+ task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"]
# Get execution records for logging purposes.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id)
@@ -780,14 +826,14 @@ def handle_action_execution_pending(ac_ex_db):
msg = 'Handling pending of action execution "%s" for task "%s", route "%s".'
update_progress(
wf_ex_db,
- msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route))
+ msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)),
)
# Update task execution
update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_db.context)
# Update task flow in the workflow execution.
- ac_ex_ctx = ac_ex_db.context.get('orquesta')
+ ac_ex_ctx = ac_ex_db.context.get("orquesta")
update_task_state(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_ctx, publish=True)
@@ -795,13 +841,13 @@ def handle_action_execution_pause(ac_ex_db):
# Check that the action execution is paused.
if ac_ex_db.status != ac_const.LIVEACTION_STATUS_PAUSED:
raise Exception(
- 'Unable to handle pause of action execution. The action execution '
+ "Unable to handle pause of action execution. The action execution "
'"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status)
)
# Get related record identifiers.
- wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id']
- task_ex_id = ac_ex_db.context['orquesta']['task_execution_id']
+ wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"]
+ task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"]
# Get execution records for logging purposes.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id)
@@ -814,27 +860,27 @@ def handle_action_execution_pause(ac_ex_db):
msg = 'Handling pause of action execution "%s" for task "%s", route "%s".'
update_progress(
wf_ex_db,
- msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route))
+ msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)),
)
# Update task execution
update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_db.context)
# Update task flow in the workflow execution.
- ac_ex_ctx = ac_ex_db.context.get('orquesta')
+ ac_ex_ctx = ac_ex_db.context.get("orquesta")
update_task_state(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_ctx, publish=True)
def handle_action_execution_resume(ac_ex_db):
- if 'orquesta' not in ac_ex_db.context:
+ if "orquesta" not in ac_ex_db.context:
raise Exception(
- 'Unable to handle resume of action execution. The action execution '
- '%s is not an orquesta workflow task.' % str(ac_ex_db.id)
+ "Unable to handle resume of action execution. The action execution "
+ "%s is not an orquesta workflow task." % str(ac_ex_db.id)
)
# Get related record identifiers.
- wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id']
- task_ex_id = ac_ex_db.context['orquesta']['task_execution_id']
+ wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"]
+ task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"]
# Get execution records for logging purposes.
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id)
@@ -843,7 +889,7 @@ def handle_action_execution_resume(ac_ex_db):
msg = 'Handling resume of action execution "%s" for task "%s", route "%s".'
update_progress(
wf_ex_db,
- msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route))
+ msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)),
)
# Update task execution to running.
@@ -854,18 +900,22 @@ def handle_action_execution_resume(ac_ex_db):
# If action execution has a parent, cascade status change upstream and do not publish
# the status change because we do not want to trigger resume of other peer subworkflows.
- if 'parent' in ac_ex_db.context:
- parent_ac_ex_id = ac_ex_db.context['parent']['execution_id']
+ if "parent" in ac_ex_db.context:
+ parent_ac_ex_id = ac_ex_db.context["parent"]["execution_id"]
parent_ac_ex_db = ex_db_access.ActionExecution.get_by_id(parent_ac_ex_id)
if parent_ac_ex_db.status == ac_const.LIVEACTION_STATUS_PAUSED:
action_utils.update_liveaction_status(
- liveaction_id=parent_ac_ex_db.liveaction['id'],
+ liveaction_id=parent_ac_ex_db.liveaction["id"],
status=ac_const.LIVEACTION_STATUS_RUNNING,
- publish=False)
+ publish=False,
+ )
# If there are grandparents, handle the resume of the parent action execution.
- if 'orquesta' in parent_ac_ex_db.context and 'parent' in parent_ac_ex_db.context:
+ if (
+ "orquesta" in parent_ac_ex_db.context
+ and "parent" in parent_ac_ex_db.context
+ ):
handle_action_execution_resume(parent_ac_ex_db)
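handle_action_execution_resume above walks the parent chain: a paused parent is set back to running without publishing the change, and the cascade recurses while the parent is itself an orquesta task with its own parent. A minimal sketch using an in-memory dict of executions in place of the database accessors (hypothetical data):

```python
# Hypothetical execution records keyed by id.
executions = {
    "child": {"status": "resuming", "context": {"orquesta": {}, "parent": {"execution_id": "parent"}}},
    "parent": {"status": "paused", "context": {"orquesta": {}, "parent": {"execution_id": "root"}}},
    "root": {"status": "paused", "context": {}},
}

def cascade_resume(ex_id):
    ex = executions[ex_id]
    parent_id = ex["context"].get("parent", {}).get("execution_id")
    if not parent_id:
        return
    parent = executions[parent_id]
    if parent["status"] == "paused":
        parent["status"] = "running"  # cascaded without publishing the change
        # Recurse only if the parent is itself an orquesta task with a parent.
        if "orquesta" in parent["context"] and "parent" in parent["context"]:
            cascade_resume(parent_id)

cascade_resume("child")
print(executions["parent"]["status"], executions["root"]["status"])  # running running
```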
@@ -873,18 +923,19 @@ def handle_action_execution_resume(ac_ex_db):
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def handle_action_execution_completion(ac_ex_db):
# Check that the action execution is completed.
if ac_ex_db.status not in ac_const.LIVEACTION_COMPLETED_STATES:
raise Exception(
- 'Unable to handle completion of action execution. The action execution '
+ "Unable to handle completion of action execution. The action execution "
'"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status)
)
# Get related record identifiers.
- wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id']
- task_ex_id = ac_ex_db.context['orquesta']['task_execution_id']
+ wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"]
+ task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"]
# Acquire lock before write operations.
with coord_svc.get_coordinator(start_heart=True).get_lock(wf_ex_id):
@@ -894,9 +945,12 @@ def handle_action_execution_completion(ac_ex_db):
msg = (
'Handling completion of action execution "%s" '
- 'in status "%s" for task "%s", route "%s".' % (
- str(ac_ex_db.id), ac_ex_db.status, task_ex_db.task_id,
- str(task_ex_db.task_route)
+ 'in status "%s" for task "%s", route "%s".'
+ % (
+ str(ac_ex_db.id),
+ ac_ex_db.status,
+ task_ex_db.task_id,
+ str(task_ex_db.task_route),
)
)
update_progress(wf_ex_db, msg)
@@ -907,14 +961,16 @@ def handle_action_execution_completion(ac_ex_db):
resume_task_execution(task_ex_id)
# Update task execution if completed.
- update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_db.result, ac_ex_db.context)
+ update_task_execution(
+ task_ex_id, ac_ex_db.status, ac_ex_db.result, ac_ex_db.context
+ )
# Update task flow in the workflow execution.
update_task_state(
task_ex_id,
ac_ex_db.status,
ac_ex_result=ac_ex_db.result,
- ac_ex_ctx=ac_ex_db.context.get('orquesta')
+ ac_ex_ctx=ac_ex_db.context.get("orquesta"),
)
# Request the next set of tasks if workflow execution is not complete.
@@ -926,13 +982,13 @@ def handle_action_execution_completion(ac_ex_db):
def deserialize_conductor(wf_ex_db):
data = {
- 'spec': wf_ex_db.spec,
- 'graph': wf_ex_db.graph,
- 'input': wf_ex_db.input,
- 'context': wf_ex_db.context,
- 'state': wf_ex_db.state,
- 'output': wf_ex_db.output,
- 'errors': wf_ex_db.errors
+ "spec": wf_ex_db.spec,
+ "graph": wf_ex_db.graph,
+ "input": wf_ex_db.input,
+ "context": wf_ex_db.context,
+ "state": wf_ex_db.state,
+ "output": wf_ex_db.output,
+ "errors": wf_ex_db.errors,
}
return conducting.WorkflowConductor.deserialize(data)
@@ -948,18 +1004,22 @@ def refresh_conductor(wf_ex_id):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
-def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None, publish=True):
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
+def update_task_state(
+ task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None, publish=True
+):
# Return if action execution status is not in the list of statuses to process.
- statuses_to_process = (
- copy.copy(ac_const.LIVEACTION_COMPLETED_STATES) +
- [ac_const.LIVEACTION_STATUS_PAUSED, ac_const.LIVEACTION_STATUS_PENDING]
- )
+ statuses_to_process = copy.copy(ac_const.LIVEACTION_COMPLETED_STATES) + [
+ ac_const.LIVEACTION_STATUS_PAUSED,
+ ac_const.LIVEACTION_STATUS_PENDING,
+ ]
if ac_ex_status not in statuses_to_process:
return
@@ -973,22 +1033,21 @@ def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=Non
msg = msg % (task_ex_db.task_id, str(task_ex_db.task_route), task_ex_db.status)
update_progress(wf_ex_db, msg, stream=False)
- if not ac_ex_ctx or 'item_id' not in ac_ex_ctx or ac_ex_ctx['item_id'] < 0:
+ if not ac_ex_ctx or "item_id" not in ac_ex_ctx or ac_ex_ctx["item_id"] < 0:
ac_ex_event = events.ActionExecutionEvent(ac_ex_status, result=ac_ex_result)
else:
accumulated_result = [
- item.get('result') if item else None
- for item in task_ex_db.result['items']
+ item.get("result") if item else None for item in task_ex_db.result["items"]
]
ac_ex_event = events.TaskItemActionExecutionEvent(
- ac_ex_ctx['item_id'],
+ ac_ex_ctx["item_id"],
ac_ex_status,
result=ac_ex_result,
- accumulated_result=accumulated_result
+ accumulated_result=accumulated_result,
)
- update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False)
+ update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False)
conductor.update_task_state(task_ex_db.task_id, task_ex_db.task_route, ac_ex_event)
# Update workflow execution and related liveaction and action execution.
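For itemized tasks, update_task_state above rebuilds the accumulated result from whatever items have reported back so far, with None placeholders for pending items. A minimal sketch with a hypothetical partially filled items list as stored on the task execution record:

```python
# Hypothetical items list as stored in task_ex_db.result["items"].
items = [
    {"status": "succeeded", "result": {"stdout": "a"}},
    None,  # item 1 has not reported back yet
    {"status": "running", "result": None},
]

accumulated_result = [item.get("result") if item else None for item in items]
print(accumulated_result)  # [{'stdout': 'a'}, None, None]
```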
@@ -997,19 +1056,21 @@ def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=Non
conductor,
update_lv_ac_on_statuses=statuses_to_process,
pub_lv_ac=publish,
- pub_ac_ex=publish
+ pub_ac_ex=publish,
)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def request_next_tasks(wf_ex_db, task_ex_id=None):
iteration = 0
@@ -1018,7 +1079,9 @@ def request_next_tasks(wf_ex_db, task_ex_id=None):
# If workflow is in requested status, set it to running.
if conductor.get_workflow_status() in [statuses.REQUESTED, statuses.SCHEDULED]:
- update_progress(wf_ex_db, 'Requesting conductor to start running workflow execution.')
+ update_progress(
+ wf_ex_db, "Requesting conductor to start running workflow execution."
+ )
conductor.request_workflow_status(statuses.RUNNING)
# Identify the next set of tasks. Don't pass the task id to the conductor
@@ -1028,93 +1091,104 @@ def request_next_tasks(wf_ex_db, task_ex_id=None):
msg = 'Identifying next set (iter %s) of tasks after completion of task "%s", route "%s".'
msg = msg % (str(iteration), task_ex_db.task_id, str(task_ex_db.task_route))
update_progress(wf_ex_db, msg)
- update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False)
+ update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False)
next_tasks = conductor.get_next_tasks()
else:
msg = 'Identifying next set (iter %s) of tasks for workflow execution in status "%s".'
msg = msg % (str(iteration), conductor.get_workflow_status())
update_progress(wf_ex_db, msg)
- update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False)
+ update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False)
next_tasks = conductor.get_next_tasks()
# If there is no new tasks, update execution records to handle possible completion.
if not next_tasks:
# Update workflow execution and related liveaction and action execution.
- update_progress(wf_ex_db, 'No tasks identified to execute next.')
- update_progress(wf_ex_db, '\n', log=False)
+ update_progress(wf_ex_db, "No tasks identified to execute next.")
+ update_progress(wf_ex_db, "\n", log=False)
update_execution_records(wf_ex_db, conductor)
if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES:
msg = 'The workflow execution is completed with status "%s".'
update_progress(wf_ex_db, msg % conductor.get_workflow_status())
- update_progress(wf_ex_db, '\n', log=False)
+ update_progress(wf_ex_db, "\n", log=False)
# Iterate while there are next tasks identified for processing. For a task with
# no action execution defined, the task execution will complete immediately
# with a new set of tasks available.
while next_tasks:
- msg = 'Identified the following set of tasks to execute next: %s'
- tasks_list = ', '.join(["%s (route %s)" % (t['id'], str(t['route'])) for t in next_tasks])
+ msg = "Identified the following set of tasks to execute next: %s"
+ tasks_list = ", ".join(
+ ["%s (route %s)" % (t["id"], str(t["route"])) for t in next_tasks]
+ )
update_progress(wf_ex_db, msg % tasks_list)
# Mark the tasks as running in the task flow before actual task execution.
for task in next_tasks:
msg = 'Mark task "%s", route "%s", in conductor as running.'
- update_progress(wf_ex_db, msg % (task['id'], str(task['route'])), stream=False)
+ update_progress(
+ wf_ex_db, msg % (task["id"], str(task["route"])), stream=False
+ )
# If the task has items and the items list is empty, the actions under the next task are empty
# and will not be processed in the for loop below. Handle this use case separately and
# mark it as running in the conductor. The task will be completed automatically when
# it is requested for task execution.
- if task['spec'].has_items() and 'items_count' in task and task['items_count'] == 0:
+ if (
+ task["spec"].has_items()
+ and "items_count" in task
+ and task["items_count"] == 0
+ ):
ac_ex_event = events.ActionExecutionEvent(statuses.RUNNING)
- conductor.update_task_state(task['id'], task['route'], ac_ex_event)
+ conductor.update_task_state(task["id"], task["route"], ac_ex_event)
# If the task contains multiple action executions (i.e. with items),
# then mark each item individually.
- for action in task['actions']:
- if 'item_id' not in action or action['item_id'] is None:
+ for action in task["actions"]:
+ if "item_id" not in action or action["item_id"] is None:
ac_ex_event = events.ActionExecutionEvent(statuses.RUNNING)
else:
- msg = 'Mark task "%s", route "%s", item "%s" in conductor as running.'
- msg = msg % (task['id'], str(task['route']), action['item_id'])
+ msg = (
+ 'Mark task "%s", route "%s", item "%s" in conductor as running.'
+ )
+ msg = msg % (task["id"], str(task["route"]), action["item_id"])
update_progress(wf_ex_db, msg)
ac_ex_event = events.TaskItemActionExecutionEvent(
- action['item_id'],
- statuses.RUNNING
+ action["item_id"], statuses.RUNNING
)
- conductor.update_task_state(task['id'], task['route'], ac_ex_event)
+ conductor.update_task_state(task["id"], task["route"], ac_ex_event)
# Update workflow execution and related liveaction and action execution.
- update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False)
+ update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False)
update_execution_records(wf_ex_db, conductor)
# Request task execution for the tasks.
for task in next_tasks:
try:
msg = 'Requesting execution for task "%s", route "%s".'
- update_progress(wf_ex_db, msg % (task['id'], str(task['route'])))
+ update_progress(wf_ex_db, msg % (task["id"], str(task["route"])))
# Pass down appropriate st2 context to the task and action execution(s).
- root_st2_ctx = wf_ex_db.context.get('st2', {})
+ root_st2_ctx = wf_ex_db.context.get("st2", {})
st2_ctx = {
- 'execution_id': wf_ex_db.action_execution,
- 'user': root_st2_ctx.get('user'),
- 'pack': root_st2_ctx.get('pack')
+ "execution_id": wf_ex_db.action_execution,
+ "user": root_st2_ctx.get("user"),
+ "pack": root_st2_ctx.get("pack"),
}
- if root_st2_ctx.get('api_user'):
- st2_ctx['api_user'] = root_st2_ctx.get('api_user')
+ if root_st2_ctx.get("api_user"):
+ st2_ctx["api_user"] = root_st2_ctx.get("api_user")
- if root_st2_ctx.get('source_channel'):
- st2_ctx['source_channel'] = root_st2_ctx.get('source_channel')
+ if root_st2_ctx.get("source_channel"):
+ st2_ctx["source_channel"] = root_st2_ctx.get("source_channel")
# Request the task execution.
request_task_execution(wf_ex_db, st2_ctx, task)
except Exception as e:
msg = 'Failed task execution for task "%s", route "%s".'
- msg = msg % (task['id'], str(task['route']))
- update_progress(wf_ex_db, '%s %s' % (msg, str(e)), severity='error', log=False)
+ msg = msg % (task["id"], str(task["route"]))
+ update_progress(
+ wf_ex_db, "%s %s" % (msg, str(e)), severity="error", log=False
+ )
LOG.exception(msg)
fail_workflow_execution(str(wf_ex_db.id), e, task=task)
return
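The overall shape of request_next_tasks above is a drain loop: ask the conductor for ready tasks, mark them running, dispatch them, and re-evaluate until nothing is left. A minimal sketch with an in-memory fake standing in for orquesta's WorkflowConductor (hypothetical class and data):

```python
# Hypothetical conductor that yields pre-baked batches of ready tasks.
class FakeConductor:
    def __init__(self, batches):
        self._batches = list(batches)

    def get_next_tasks(self):
        return self._batches.pop(0) if self._batches else []

conductor = FakeConductor([[{"id": "task1", "route": 0}], [{"id": "task2", "route": 0}]])

iteration = 0
next_tasks = conductor.get_next_tasks()
while next_tasks:
    for task in next_tasks:
        print('iter %s: requesting execution for task "%s", route "%s"'
              % (iteration, task["id"], task["route"]))
    iteration += 1
    next_tasks = conductor.get_next_tasks()
```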
@@ -1125,25 +1199,30 @@ def request_next_tasks(wf_ex_db, task_ex_id=None):
msg = 'Identifying next set (iter %s) of tasks for workflow execution in status "%s".'
msg = msg % (str(iteration), conductor.get_workflow_status())
update_progress(wf_ex_db, msg)
- update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False)
+ update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False)
next_tasks = conductor.get_next_tasks()
if not next_tasks:
- update_progress(wf_ex_db, 'No tasks identified to execute next.')
- update_progress(wf_ex_db, '\n', log=False)
+ update_progress(wf_ex_db, "No tasks identified to execute next.")
+ update_progress(wf_ex_db, "\n", log=False)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None):
- if ac_ex_status not in statuses.COMPLETED_STATUSES + [statuses.PAUSED, statuses.PENDING]:
+ if ac_ex_status not in statuses.COMPLETED_STATUSES + [
+ statuses.PAUSED,
+ statuses.PENDING,
+ ]:
return
task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id)
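The paired @retrying.retry decorators above recur on every DB-facing function in this module: the inner decorator retries transient database errors, the outer one retries connection errors up to a bounded total delay. A minimal runnable sketch of the same pattern follows; the predicates and timings are illustrative stand-ins for st2's wf_exc helpers and oslo.config values, not the real ones.

import retrying  # the same third-party "retrying" package used above

def retry_on_transient_db_errors(exc):
    # Illustrative stand-in for wf_exc.retry_on_transient_db_errors.
    return isinstance(exc, TimeoutError)

def retry_on_connection_errors(exc):
    # Illustrative stand-in for wf_exc.retry_on_connection_errors.
    return isinstance(exc, ConnectionError)

@retrying.retry(
    retry_on_exception=retry_on_transient_db_errors,
    wait_fixed=500,        # ms between attempts
    wait_jitter_max=1000,  # up to 1s of random jitter per wait
)
@retrying.retry(
    retry_on_exception=retry_on_connection_errors,
    stop_max_delay=60000,  # give up after 60s of connection errors
    wait_fixed=500,
    wait_jitter_max=1000,
)
def write_record():
    pass  # a DB write that may raise transient or connection errors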
@@ -1153,31 +1232,43 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx
if not task_ex_db.itemized or (task_ex_db.itemized and task_ex_db.items_count == 0):
if ac_ex_status != task_ex_db.status:
msg = 'Updating task execution "%s" for task "%s" from status "%s" to "%s".'
- msg = msg % (task_ex_id, task_ex_db.task_id, task_ex_db.status, ac_ex_status)
+ msg = msg % (
+ task_ex_id,
+ task_ex_db.task_id,
+ task_ex_db.status,
+ ac_ex_status,
+ )
update_progress(wf_ex_db, msg)
task_ex_db.status = ac_ex_status
task_ex_db.result = ac_ex_result if ac_ex_result else task_ex_db.result
elif task_ex_db.itemized and ac_ex_ctx:
- if 'orquesta' not in ac_ex_ctx or 'item_id' not in ac_ex_ctx['orquesta']:
- msg = 'Context information for the item is not provided. %s' % str(ac_ex_ctx)
- update_progress(wf_ex_db, msg, severity='error', log=False)
+ if "orquesta" not in ac_ex_ctx or "item_id" not in ac_ex_ctx["orquesta"]:
+ msg = "Context information for the item is not provided. %s" % str(
+ ac_ex_ctx
+ )
+ update_progress(wf_ex_db, msg, severity="error", log=False)
raise Exception(msg)
- item_id = ac_ex_ctx['orquesta']['item_id']
+ item_id = ac_ex_ctx["orquesta"]["item_id"]
msg = 'Processing action execution for task "%s", route "%s", item "%s".'
msg = msg % (task_ex_db.task_id, str(task_ex_db.task_route), item_id)
- update_progress(wf_ex_db, msg, severity='debug')
+ update_progress(wf_ex_db, msg, severity="debug")
- task_ex_db.result['items'][item_id] = {'status': ac_ex_status, 'result': ac_ex_result}
+ task_ex_db.result["items"][item_id] = {
+ "status": ac_ex_status,
+ "result": ac_ex_result,
+ }
item_statuses = [
- item.get('status', statuses.UNSET) if item else statuses.UNSET
- for item in task_ex_db.result['items']
+ item.get("status", statuses.UNSET) if item else statuses.UNSET
+ for item in task_ex_db.result["items"]
]
- task_completed = all([status in statuses.COMPLETED_STATUSES for status in item_statuses])
+ task_completed = all(
+ [status in statuses.COMPLETED_STATUSES for status in item_statuses]
+ )
if task_completed:
new_task_status = (
@@ -1187,11 +1278,15 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx
)
msg = 'Updating task execution from status "%s" to "%s".'
- update_progress(wf_ex_db, msg % (task_ex_db.status, new_task_status), severity='debug')
+ update_progress(
+ wf_ex_db, msg % (task_ex_db.status, new_task_status), severity="debug"
+ )
task_ex_db.status = new_task_status
else:
- msg = 'Task execution is not complete because not all items are complete: %s'
- update_progress(wf_ex_db, msg % ', '.join(item_statuses), severity='debug')
+ msg = (
+ "Task execution is not complete because not all items are complete: %s"
+ )
+ update_progress(wf_ex_db, msg % ", ".join(item_statuses), severity="debug")
if task_ex_db.status in statuses.COMPLETED_STATUSES:
task_ex_db.end_timestamp = date_utils.get_datetime_utc_now()
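To make the itemized branch above concrete, here is the completion check in isolation. The status values are hypothetical placeholders for the constants in st2's statuses module; result["items"] is a list indexed by item_id, where unfinished items may still be None.

COMPLETED_STATUSES = ["succeeded", "failed", "canceled"]  # hypothetical values
UNSET = "unset"

items = [{"status": "succeeded"}, None, {"status": "failed"}]

item_statuses = [
    item.get("status", UNSET) if item else UNSET for item in items
]
task_completed = all(status in COMPLETED_STATUSES for status in item_statuses)
# -> False: item 1 has not reported a status yet, so the task stays running.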
@@ -1202,19 +1297,23 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def resume_task_execution(task_ex_id):
# Update task execution to running.
task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id)
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(task_ex_db.workflow_execution)
msg = 'Updating task execution from status "%s" to "%s".'
- update_progress(wf_ex_db, msg % (task_ex_db.status, statuses.RUNNING), severity='debug')
+ update_progress(
+ wf_ex_db, msg % (task_ex_db.status, statuses.RUNNING), severity="debug"
+ )
task_ex_db.status = statuses.RUNNING
# Write update to the database.
@@ -1224,17 +1323,21 @@ def resume_task_execution(task_ex_id):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def update_workflow_execution(wf_ex_id):
conductor, wf_ex_db = refresh_conductor(wf_ex_id)
# There is nothing to update if workflow execution is not completed or paused.
- if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES + [statuses.PAUSED]:
+ if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES + [
+ statuses.PAUSED
+ ]:
# Update workflow execution and related liveaction and action execution.
update_execution_records(wf_ex_db, conductor)
@@ -1242,12 +1345,14 @@ def update_workflow_execution(wf_ex_id):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def resume_workflow_execution(wf_ex_id, task_ex_id):
# Update workflow execution to running.
conductor, wf_ex_db = refresh_conductor(wf_ex_id)
@@ -1265,12 +1370,14 @@ def resume_workflow_execution(wf_ex_id, task_ex_id):
@retrying.retry(
retry_on_exception=wf_exc.retry_on_transient_db_errors,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
@retrying.retry(
retry_on_exception=wf_exc.retry_on_connection_errors,
stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec,
wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec,
- wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec)
+ wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec,
+)
def fail_workflow_execution(wf_ex_id, exception, task=None):
conductor, wf_ex_db = refresh_conductor(wf_ex_id)
@@ -1278,7 +1385,7 @@ def fail_workflow_execution(wf_ex_id, exception, task=None):
conductor.request_workflow_status(statuses.FAILED)
if task is not None and isinstance(task, dict):
- conductor.log_error(exception, task_id=task.get('id'), route=task.get('route'))
+ conductor.log_error(exception, task_id=task.get("id"), route=task.get("route"))
else:
conductor.log_error(exception)
@@ -1286,8 +1393,14 @@ def fail_workflow_execution(wf_ex_id, exception, task=None):
update_execution_records(wf_ex_db, conductor)
-def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None,
- pub_wf_ex=False, pub_lv_ac=True, pub_ac_ex=True):
+def update_execution_records(
+ wf_ex_db,
+ conductor,
+ update_lv_ac_on_statuses=None,
+ pub_wf_ex=False,
+ pub_lv_ac=True,
+ pub_ac_ex=True,
+):
# If the workflow execution is completed, then render the workflow output.
if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES:
conductor.render_workflow_output()
@@ -1295,7 +1408,7 @@ def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None,
# Determine if workflow status has changed.
wf_old_status = wf_ex_db.status
wf_ex_db.status = conductor.get_workflow_status()
- status_changed = (wf_old_status != wf_ex_db.status)
+ status_changed = wf_old_status != wf_ex_db.status
if status_changed:
msg = 'Updating workflow execution from status "%s" to "%s".'
@@ -1314,53 +1427,58 @@ def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None,
wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=pub_wf_ex)
# Return if workflow execution status is not specified in update_lv_ac_on_statuses.
- if (isinstance(update_lv_ac_on_statuses, list) and
- wf_ex_db.status not in update_lv_ac_on_statuses):
+ if (
+ isinstance(update_lv_ac_on_statuses, list)
+ and wf_ex_db.status not in update_lv_ac_on_statuses
+ ):
return
# Update the corresponding liveaction and action execution for the workflow.
wf_ac_ex_db = ex_db_access.ActionExecution.get_by_id(wf_ex_db.action_execution)
- wf_lv_ac_db = action_utils.get_liveaction_by_id(wf_ac_ex_db.liveaction['id'])
+ wf_lv_ac_db = action_utils.get_liveaction_by_id(wf_ac_ex_db.liveaction["id"])
# Gather result for liveaction and action execution.
- result = {'output': wf_ex_db.output or None}
+ result = {"output": wf_ex_db.output or None}
if wf_ex_db.status in statuses.ABENDED_STATUSES:
- result['errors'] = wf_ex_db.errors
+ result["errors"] = wf_ex_db.errors
if wf_ex_db.errors:
- msg = 'Workflow execution completed with errors.'
- update_progress(wf_ex_db, msg, severity='error')
+ msg = "Workflow execution completed with errors."
+ update_progress(wf_ex_db, msg, severity="error")
for wf_ex_error in wf_ex_db.errors:
- update_progress(wf_ex_db, wf_ex_error, severity='error')
+ update_progress(wf_ex_db, wf_ex_error, severity="error")
# Sync update with corresponding liveaction and action execution.
if pub_lv_ac or pub_ac_ex:
- pub_lv_ac = (wf_lv_ac_db.status != wf_ex_db.status)
+ pub_lv_ac = wf_lv_ac_db.status != wf_ex_db.status
pub_ac_ex = pub_lv_ac
if wf_lv_ac_db.status != wf_ex_db.status:
- kwargs = {'severity': 'debug', 'stream': False}
+ kwargs = {"severity": "debug", "stream": False}
msg = 'Updating workflow liveaction from status "%s" to "%s".'
update_progress(wf_ex_db, msg % (wf_lv_ac_db.status, wf_ex_db.status), **kwargs)
- msg = 'Workflow liveaction status change %s be published.'
- update_progress(wf_ex_db, msg % 'will' if pub_lv_ac else 'will not', **kwargs)
- msg = 'Workflow action execution status change %s be published.'
- update_progress(wf_ex_db, msg % 'will' if pub_ac_ex else 'will not', **kwargs)
+ msg = "Workflow liveaction status change %s be published."
+ update_progress(wf_ex_db, msg % "will" if pub_lv_ac else "will not", **kwargs)
+ msg = "Workflow action execution status change %s be published."
+ update_progress(wf_ex_db, msg % "will" if pub_ac_ex else "will not", **kwargs)
wf_lv_ac_db = action_utils.update_liveaction_status(
status=wf_ex_db.status,
result=result,
end_timestamp=wf_ex_db.end_timestamp,
liveaction_db=wf_lv_ac_db,
- publish=pub_lv_ac)
+ publish=pub_lv_ac,
+ )
ex_svc.update_execution(wf_lv_ac_db, publish=pub_ac_ex)
# Invoke post run on the liveaction for the workflow execution.
if status_changed and wf_lv_ac_db.status in ac_const.LIVEACTION_COMPLETED_STATES:
- update_progress(wf_ex_db, 'Workflow action execution is completed and invoking post run.')
+ update_progress(
+ wf_ex_db, "Workflow action execution is completed and invoking post run."
+ )
runners_utils.invoke_post_run(wf_lv_ac_db)
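One note on the status-change logging in the hunk above: % binds tighter than Python's conditional expression, so the original unparenthesized msg % 'will' if pub_lv_ac else 'will not' evaluated as (msg % 'will') if pub_lv_ac else 'will not' and logged the bare string "will not" in the false branch. The parenthesized form shown fixes that; a quick demonstration:

msg = "Workflow liveaction status change %s be published."
pub_lv_ac = False

# Unparenthesized: parsed as (msg % "will") if pub_lv_ac else "will not"
print(msg % "will" if pub_lv_ac else "will not")
# -> will not

# Parenthesized: the conditional selects the substitution value instead
print(msg % ("will" if pub_lv_ac else "will not"))
# -> Workflow liveaction status change will not be published.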
@@ -1376,36 +1494,40 @@ def identify_orphaned_workflows():
# does not necessarily mean it is the max idle time. The use of workflow_executions_idled_ttl
# to filter is to reduce the number of action executions that need to be evaluated.
query_filters = {
- 'runner__name': 'orquesta',
- 'status': ac_const.LIVEACTION_STATUS_RUNNING,
- 'start_timestamp__lte': expiry_dt
+ "runner__name": "orquesta",
+ "status": ac_const.LIVEACTION_STATUS_RUNNING,
+ "start_timestamp__lte": expiry_dt,
}
ac_ex_dbs = ex_db_access.ActionExecution.query(**query_filters)
for ac_ex_db in ac_ex_dbs:
# Figure out the runtime for the action execution.
status_change_logs = sorted(
- [log for log in ac_ex_db.log if log['status'] == ac_const.LIVEACTION_STATUS_RUNNING],
- key=lambda x: x['timestamp'],
- reverse=True
+ [
+ log
+ for log in ac_ex_db.log
+ if log["status"] == ac_const.LIVEACTION_STATUS_RUNNING
+ ],
+ key=lambda x: x["timestamp"],
+ reverse=True,
)
if len(status_change_logs) <= 0:
continue
- runtime = (utc_now_dt - status_change_logs[0]['timestamp']).total_seconds()
+ runtime = (utc_now_dt - status_change_logs[0]["timestamp"]).total_seconds()
# Fetch the task executions for the workflow execution.
# Ensure that the root action execution is not being selected.
- wf_ex_id = ac_ex_db.context['workflow_execution']
+ wf_ex_id = ac_ex_db.context["workflow_execution"]
wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id)
- query_filters = {'workflow_execution': wf_ex_id, 'id__ne': ac_ex_db.id}
+ query_filters = {"workflow_execution": wf_ex_id, "id__ne": ac_ex_db.id}
tk_ac_ex_dbs = ex_db_access.ActionExecution.query(**query_filters)
# The workflow execution is orphaned if there are
# no task executions and runtime passed expiry.
if len(tk_ac_ex_dbs) <= 0 and runtime > gc_max_idle:
- msg = 'The action execution is orphaned and will be canceled by the garbage collector.'
+ msg = "The action execution is orphaned and will be canceled by the garbage collector."
update_progress(wf_ex_db, msg)
orphaned.append(ac_ex_db)
continue
@@ -1415,7 +1537,8 @@ def identify_orphaned_workflows():
has_active_tasks = len([t for t in tk_ac_ex_dbs if t.end_timestamp is None]) > 0
completed_tasks = [
- t for t in tk_ac_ex_dbs
+ t
+ for t in tk_ac_ex_dbs
if t.end_timestamp is not None and t.end_timestamp <= expiry_dt
]
@@ -1423,11 +1546,16 @@ def identify_orphaned_workflows():
most_recent_completed_task_expired = (
completed_tasks[-1].end_timestamp <= expiry_dt
- if len(completed_tasks) > 0 else False
+ if len(completed_tasks) > 0
+ else False
)
- if len(tk_ac_ex_dbs) > 0 and not has_active_tasks and most_recent_completed_task_expired:
- msg = 'The action execution is orphaned and will be canceled by the garbage collector.'
+ if (
+ len(tk_ac_ex_dbs) > 0
+ and not has_active_tasks
+ and most_recent_completed_task_expired
+ ):
+ msg = "The action execution is orphaned and will be canceled by the garbage collector."
update_progress(wf_ex_db, msg)
orphaned.append(ac_ex_db)
continue
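Both orphan conditions in identify_orphaned_workflows() reduce to timestamp arithmetic plus a count of child task executions. A self-contained sketch with illustrative values (the idle TTL comes from config in the real code):

from datetime import datetime, timedelta, timezone

gc_max_idle = 3600  # seconds; illustrative stand-in for the configured TTL
utc_now_dt = datetime.now(timezone.utc)
expiry_dt = utc_now_dt - timedelta(seconds=gc_max_idle)

# Case 1: running for two hours with no task executions ever created.
last_running_ts = utc_now_dt - timedelta(hours=2)
runtime = (utc_now_dt - last_running_ts).total_seconds()
tk_ac_ex_dbs = []  # no child task executions
orphaned_case_1 = len(tk_ac_ex_dbs) <= 0 and runtime > gc_max_idle  # -> True

# Case 2: all tasks finished, and the most recent one finished past expiry.
completed_task_end_timestamps = [utc_now_dt - timedelta(hours=3)]
has_active_tasks = False
most_recent_completed_task_expired = completed_task_end_timestamps[-1] <= expiry_dt
orphaned_case_2 = (
    len(completed_task_end_timestamps) > 0
    and not has_active_tasks
    and most_recent_completed_task_expired
)  # -> True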
diff --git a/st2common/st2common/signal_handlers.py b/st2common/st2common/signal_handlers.py
index 0fc2766175..bd785403f4 100644
--- a/st2common/st2common/signal_handlers.py
+++ b/st2common/st2common/signal_handlers.py
@@ -26,7 +26,7 @@
from st2common.logging.misc import reopen_log_files
__all__ = [
- 'register_common_signal_handlers',
+ "register_common_signal_handlers",
]
diff --git a/st2common/st2common/stream/listener.py b/st2common/st2common/stream/listener.py
index 6edbef1750..347c4cfc75 100644
--- a/st2common/st2common/stream/listener.py
+++ b/st2common/st2common/stream/listener.py
@@ -33,11 +33,10 @@
from st2common import log as logging
__all__ = [
- 'StreamListener',
- 'ExecutionOutputListener',
-
- 'get_listener',
- 'get_listener_if_set'
+ "StreamListener",
+ "ExecutionOutputListener",
+ "get_listener",
+ "get_listener_if_set",
]
LOG = logging.getLogger(__name__)
@@ -49,23 +48,24 @@
class BaseListener(ConsumerMixin):
-
def __init__(self, connection):
self.connection = connection
self.queues = []
self._stopped = False
def get_consumers(self, consumer, channel):
- raise NotImplementedError('get_consumers() is not implemented')
+ raise NotImplementedError("get_consumers() is not implemented")
def processor(self, model=None):
def process(body, message):
meta = message.delivery_info
- event_name = '%s__%s' % (meta.get('exchange'), meta.get('routing_key'))
+ event_name = "%s__%s" % (meta.get("exchange"), meta.get("routing_key"))
try:
if model:
- body = model.from_model(body, mask_secrets=cfg.CONF.api.mask_secrets)
+ body = model.from_model(
+ body, mask_secrets=cfg.CONF.api.mask_secrets
+ )
self.emit(event_name, body)
finally:
@@ -78,10 +78,17 @@ def emit(self, event, body):
for queue in self.queues:
queue.put(pack)
- def generator(self, events=None, action_refs=None, execution_ids=None,
- end_event=None, end_statuses=None, end_execution_id=None):
+ def generator(
+ self,
+ events=None,
+ action_refs=None,
+ execution_ids=None,
+ end_event=None,
+ end_statuses=None,
+ end_execution_id=None,
+ ):
queue = eventlet.Queue()
- queue.put('')
+ queue.put("")
self.queues.append(queue)
try:
stop = False
@@ -95,16 +102,19 @@ def generator(self, events=None, action_refs=None, execution_ids=None,
event_name, body = message
# check to see if this is the last message to send.
if event_name == end_event:
- if body is not None and \
- body.status in end_statuses and \
- end_execution_id is not None and \
- body.id == end_execution_id:
+ if (
+ body is not None
+ and body.status in end_statuses
+ and end_execution_id is not None
+ and body.id == end_execution_id
+ ):
stop = True
# TODO: We now do late filtering, but this could also be performed on the
# message bus level if we modified our exchange layout and utilized routing keys
# Filter on event name
- include_event = self._should_include_event(event_names_whitelist=events,
- event_name=event_name)
+ include_event = self._should_include_event(
+ event_names_whitelist=events, event_name=event_name
+ )
if not include_event:
LOG.debug('Skipping event "%s"' % (event_name))
continue
@@ -112,14 +122,18 @@ def generator(self, events=None, action_refs=None, execution_ids=None,
# Filter on action ref
action_ref = self._get_action_ref_for_body(body=body)
if action_refs and action_ref not in action_refs:
- LOG.debug('Skipping event "%s" with action_ref "%s"' % (event_name,
- action_ref))
+ LOG.debug(
+ 'Skipping event "%s" with action_ref "%s"'
+ % (event_name, action_ref)
+ )
continue
# Filter on execution id
execution_id = self._get_execution_id_for_body(body=body)
if execution_ids and execution_id not in execution_ids:
- LOG.debug('Skipping event "%s" with execution_id "%s"' % (event_name,
- execution_id))
+ LOG.debug(
+ 'Skipping event "%s" with execution_id "%s"'
+ % (event_name, execution_id)
+ )
continue
yield message
@@ -154,7 +168,7 @@ def _get_action_ref_for_body(self, body):
action_ref = None
if isinstance(body, ActionExecutionAPI):
- action_ref = body.action.get('ref', None) if body.action else None
+ action_ref = body.action.get("ref", None) if body.action else None
elif isinstance(body, LiveActionAPI):
action_ref = body.action
elif isinstance(body, (ActionExecutionOutputAPI)):
@@ -187,21 +201,26 @@ class StreamListener(BaseListener):
def get_consumers(self, consumer, channel):
return [
- consumer(queues=[STREAM_ANNOUNCEMENT_WORK_QUEUE],
- accept=['pickle'],
- callbacks=[self.processor()]),
-
- consumer(queues=[STREAM_EXECUTION_ALL_WORK_QUEUE],
- accept=['pickle'],
- callbacks=[self.processor(ActionExecutionAPI)]),
-
- consumer(queues=[STREAM_LIVEACTION_WORK_QUEUE],
- accept=['pickle'],
- callbacks=[self.processor(LiveActionAPI)]),
-
- consumer(queues=[STREAM_EXECUTION_OUTPUT_QUEUE],
- accept=['pickle'],
- callbacks=[self.processor(ActionExecutionOutputAPI)])
+ consumer(
+ queues=[STREAM_ANNOUNCEMENT_WORK_QUEUE],
+ accept=["pickle"],
+ callbacks=[self.processor()],
+ ),
+ consumer(
+ queues=[STREAM_EXECUTION_ALL_WORK_QUEUE],
+ accept=["pickle"],
+ callbacks=[self.processor(ActionExecutionAPI)],
+ ),
+ consumer(
+ queues=[STREAM_LIVEACTION_WORK_QUEUE],
+ accept=["pickle"],
+ callbacks=[self.processor(LiveActionAPI)],
+ ),
+ consumer(
+ queues=[STREAM_EXECUTION_OUTPUT_QUEUE],
+ accept=["pickle"],
+ callbacks=[self.processor(ActionExecutionOutputAPI)],
+ ),
]
@@ -214,13 +233,16 @@ class ExecutionOutputListener(BaseListener):
def get_consumers(self, consumer, channel):
return [
- consumer(queues=[STREAM_EXECUTION_UPDATE_WORK_QUEUE],
- accept=['pickle'],
- callbacks=[self.processor(ActionExecutionAPI)]),
-
- consumer(queues=[STREAM_EXECUTION_OUTPUT_QUEUE],
- accept=['pickle'],
- callbacks=[self.processor(ActionExecutionOutputAPI)])
+ consumer(
+ queues=[STREAM_EXECUTION_UPDATE_WORK_QUEUE],
+ accept=["pickle"],
+ callbacks=[self.processor(ActionExecutionAPI)],
+ ),
+ consumer(
+ queues=[STREAM_EXECUTION_OUTPUT_QUEUE],
+ accept=["pickle"],
+ callbacks=[self.processor(ActionExecutionOutputAPI)],
+ ),
]
@@ -235,29 +257,29 @@ def get_listener(name):
global _stream_listener
global _execution_output_listener
- if name == 'stream':
+ if name == "stream":
if not _stream_listener:
with transport_utils.get_connection() as conn:
_stream_listener = StreamListener(conn)
eventlet.spawn_n(listen, _stream_listener)
return _stream_listener
- elif name == 'execution_output':
+ elif name == "execution_output":
if not _execution_output_listener:
with transport_utils.get_connection() as conn:
_execution_output_listener = ExecutionOutputListener(conn)
eventlet.spawn_n(listen, _execution_output_listener)
return _execution_output_listener
else:
- raise ValueError('Invalid listener name: %s' % (name))
+ raise ValueError("Invalid listener name: %s" % (name))
def get_listener_if_set(name):
global _stream_listener
global _execution_output_listener
- if name == 'stream':
+ if name == "stream":
return _stream_listener
- elif name == 'execution_output':
+ elif name == "execution_output":
return _execution_output_listener
else:
- raise ValueError('Invalid listener name: %s' % (name))
+ raise ValueError("Invalid listener name: %s" % (name))
diff --git a/st2common/st2common/transport/__init__.py b/st2common/st2common/transport/__init__.py
index cc384c878e..632c08dc0e 100644
--- a/st2common/st2common/transport/__init__.py
+++ b/st2common/st2common/transport/__init__.py
@@ -21,12 +21,12 @@
# TODO(manas): Exchanges, Queues and RoutingKey design discussion pending.
__all__ = [
- 'liveaction',
- 'actionexecutionstate',
- 'execution',
- 'workflow',
- 'publishers',
- 'reactor',
- 'utils',
- 'connection_retry_wrapper'
+ "liveaction",
+ "actionexecutionstate",
+ "execution",
+ "workflow",
+ "publishers",
+ "reactor",
+ "utils",
+ "connection_retry_wrapper",
]
diff --git a/st2common/st2common/transport/actionexecutionstate.py b/st2common/st2common/transport/actionexecutionstate.py
index 268bffe0fc..46fe095fbf 100644
--- a/st2common/st2common/transport/actionexecutionstate.py
+++ b/st2common/st2common/transport/actionexecutionstate.py
@@ -21,18 +21,16 @@
from st2common.transport import publishers
-__all__ = [
- 'ActionExecutionStatePublisher'
-]
+__all__ = ["ActionExecutionStatePublisher"]
-ACTIONEXECUTIONSTATE_XCHG = Exchange('st2.actionexecutionstate',
- type='topic')
+ACTIONEXECUTIONSTATE_XCHG = Exchange("st2.actionexecutionstate", type="topic")
class ActionExecutionStatePublisher(publishers.CUDPublisher):
-
def __init__(self):
- super(ActionExecutionStatePublisher, self).__init__(exchange=ACTIONEXECUTIONSTATE_XCHG)
+ super(ActionExecutionStatePublisher, self).__init__(
+ exchange=ACTIONEXECUTIONSTATE_XCHG
+ )
def get_queue(name, routing_key):
diff --git a/st2common/st2common/transport/announcement.py b/st2common/st2common/transport/announcement.py
index 84c8bf27a7..e79506c608 100644
--- a/st2common/st2common/transport/announcement.py
+++ b/st2common/st2common/transport/announcement.py
@@ -22,17 +22,12 @@
from st2common.models.api.trace import TraceContext
from st2common.transport import publishers
-__all__ = [
- 'AnnouncementPublisher',
- 'AnnouncementDispatcher',
-
- 'get_queue'
-]
+__all__ = ["AnnouncementPublisher", "AnnouncementDispatcher", "get_queue"]
LOG = logging.getLogger(__name__)
# Exchange for Announcements
-ANNOUNCEMENT_XCHG = Exchange('st2.announcement', type='topic')
+ANNOUNCEMENT_XCHG = Exchange("st2.announcement", type="topic")
class AnnouncementPublisher(object):
@@ -68,16 +63,19 @@ def dispatch(self, routing_key, payload, trace_context=None):
assert isinstance(payload, (type(None), dict))
assert isinstance(trace_context, (type(None), dict, TraceContext))
- payload = {
- 'payload': payload,
- TRACE_CONTEXT: trace_context
- }
+ payload = {"payload": payload, TRACE_CONTEXT: trace_context}
- self._logger.debug('Dispatching announcement (routing_key=%s,payload=%s)',
- routing_key, payload)
+ self._logger.debug(
+ "Dispatching announcement (routing_key=%s,payload=%s)", routing_key, payload
+ )
self._publisher.publish(payload=payload, routing_key=routing_key)
-def get_queue(name=None, routing_key='#', exclusive=False, auto_delete=False):
- return Queue(name, ANNOUNCEMENT_XCHG, routing_key=routing_key, exclusive=exclusive,
- auto_delete=auto_delete)
+def get_queue(name=None, routing_key="#", exclusive=False, auto_delete=False):
+ return Queue(
+ name,
+ ANNOUNCEMENT_XCHG,
+ routing_key=routing_key,
+ exclusive=exclusive,
+ auto_delete=auto_delete,
+ )
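get_queue() above binds an unnamed, exclusive, auto-deleted queue to the announcement topic exchange with the "#" wildcard, so a stream subscriber sees every announcement regardless of routing key. The equivalent standalone kombu declaration:

from kombu import Exchange, Queue

ANNOUNCEMENT_XCHG = Exchange("st2.announcement", type="topic")

# "#" matches any routing key on a topic exchange; exclusive + auto_delete
# means the broker drops the queue when the subscriber disconnects.
queue = Queue(
    None,  # let the broker generate a unique queue name
    ANNOUNCEMENT_XCHG,
    routing_key="#",
    exclusive=True,
    auto_delete=True,
)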
diff --git a/st2common/st2common/transport/bootstrap.py b/st2common/st2common/transport/bootstrap.py
index 4c75072fe9..20d9277fae 100644
--- a/st2common/st2common/transport/bootstrap.py
+++ b/st2common/st2common/transport/bootstrap.py
@@ -24,8 +24,9 @@ def _setup():
config.parse_args()
# 2. setup logging.
- logging.basicConfig(format='%(asctime)s %(levelname)s [-] %(message)s',
- level=logging.DEBUG)
+ logging.basicConfig(
+ format="%(asctime)s %(levelname)s [-] %(message)s", level=logging.DEBUG
+ )
def main():
@@ -34,5 +35,5 @@ def main():
# The script sets up Exchanges in RabbitMQ.
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/st2common/st2common/transport/bootstrap_utils.py b/st2common/st2common/transport/bootstrap_utils.py
index d787adc493..2eea9ad64b 100644
--- a/st2common/st2common/transport/bootstrap_utils.py
+++ b/st2common/st2common/transport/bootstrap_utils.py
@@ -50,15 +50,14 @@
from st2common.transport.queues import WORKFLOW_EXECUTION_WORK_QUEUE
from st2common.transport.queues import WORKFLOW_EXECUTION_RESUME_QUEUE
-LOG = logging.getLogger('st2common.transport.bootstrap')
+LOG = logging.getLogger("st2common.transport.bootstrap")
__all__ = [
- 'register_exchanges',
- 'register_exchanges_with_retry',
- 'register_kombu_serializers',
-
- 'EXCHANGES',
- 'QUEUES'
+ "register_exchanges",
+ "register_exchanges_with_retry",
+ "register_kombu_serializers",
+ "EXCHANGES",
+ "QUEUES",
]
# List of exchanges which are pre-declared on service set up.
@@ -72,7 +71,7 @@
TRIGGER_INSTANCE_XCHG,
SENSOR_CUD_XCHG,
WORKFLOW_EXECUTION_XCHG,
- WORKFLOW_EXECUTION_STATUS_MGMT_XCHG
+ WORKFLOW_EXECUTION_STATUS_MGMT_XCHG,
]
# List of queues which are pre-declared on service startup.
@@ -85,41 +84,40 @@
NOTIFIER_ACTIONUPDATE_WORK_QUEUE,
RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE,
RULESENGINE_WORK_QUEUE,
-
STREAM_ANNOUNCEMENT_WORK_QUEUE,
STREAM_EXECUTION_ALL_WORK_QUEUE,
STREAM_LIVEACTION_WORK_QUEUE,
STREAM_EXECUTION_OUTPUT_QUEUE,
-
WORKFLOW_EXECUTION_WORK_QUEUE,
WORKFLOW_EXECUTION_RESUME_QUEUE,
-
# These queues are created dynamically / late on some class init, but we still need to
# pre-declare them for the Redis Kombu backend to work.
- reactor.get_trigger_cud_queue(name='st2.preinit', routing_key='init'),
- reactor.get_sensor_cud_queue(name='st2.preinit', routing_key='init')
+ reactor.get_trigger_cud_queue(name="st2.preinit", routing_key="init"),
+ reactor.get_sensor_cud_queue(name="st2.preinit", routing_key="init"),
]
def _do_register_exchange(exchange, connection, channel, retry_wrapper):
try:
kwargs = {
- 'exchange': exchange.name,
- 'type': exchange.type,
- 'durable': exchange.durable,
- 'auto_delete': exchange.auto_delete,
- 'arguments': exchange.arguments,
- 'nowait': False,
- 'passive': False
+ "exchange": exchange.name,
+ "type": exchange.type,
+ "durable": exchange.durable,
+ "auto_delete": exchange.auto_delete,
+ "arguments": exchange.arguments,
+ "nowait": False,
+ "passive": False,
}
# Use the retry wrapper to increase resiliency against recoverable errors.
- retry_wrapper.ensured(connection=connection,
- obj=channel,
- to_ensure_func=channel.exchange_declare,
- **kwargs)
- LOG.debug('Registered exchange %s (%s).' % (exchange.name, str(kwargs)))
+ retry_wrapper.ensured(
+ connection=connection,
+ obj=channel,
+ to_ensure_func=channel.exchange_declare,
+ **kwargs,
+ )
+ LOG.debug("Registered exchange %s (%s)." % (exchange.name, str(kwargs)))
except Exception:
- LOG.exception('Failed to register exchange: %s.', exchange.name)
+ LOG.exception("Failed to register exchange: %s.", exchange.name)
def _do_predeclare_queue(channel, queue):
@@ -132,23 +130,31 @@ def _do_predeclare_queue(channel, queue):
bound_queue.declare(nowait=False)
LOG.debug('Predeclared queue for exchange "%s"' % (queue.exchange.name))
except Exception:
- LOG.exception('Failed to predeclare queue for exchange "%s"' % (queue.exchange.name))
+ LOG.exception(
+ 'Failed to predeclare queue for exchange "%s"' % (queue.exchange.name)
+ )
return bound_queue
def register_exchanges():
- LOG.debug('Registering exchanges...')
+ LOG.debug("Registering exchanges...")
connection_urls = transport_utils.get_messaging_urls()
with transport_utils.get_connection() as conn:
# Use ConnectionRetryWrapper to deal with rmq clustering etc.
- retry_wrapper = ConnectionRetryWrapper(cluster_size=len(connection_urls), logger=LOG)
+ retry_wrapper = ConnectionRetryWrapper(
+ cluster_size=len(connection_urls), logger=LOG
+ )
def wrapped_register_exchanges(connection, channel):
for exchange in EXCHANGES:
- _do_register_exchange(exchange=exchange, connection=connection, channel=channel,
- retry_wrapper=retry_wrapper)
+ _do_register_exchange(
+ exchange=exchange,
+ connection=connection,
+ channel=channel,
+ retry_wrapper=retry_wrapper,
+ )
retry_wrapper.run(connection=conn, wrapped_callback=wrapped_register_exchanges)
@@ -166,7 +172,7 @@ def retry_if_io_error(exception):
retrying_obj = retrying.Retrying(
retry_on_exception=retry_if_io_error,
wait_fixed=cfg.CONF.messaging.connection_retry_wait,
- stop_max_attempt_number=cfg.CONF.messaging.connection_retries
+ stop_max_attempt_number=cfg.CONF.messaging.connection_retries,
)
return retrying_obj.call(register_exchanges)
@@ -181,24 +187,33 @@ def register_kombu_serializers():
https://github.com/celery/kombu/blob/3.0/kombu/utils/encoding.py#L47
"""
+
def pickle_dumps(obj, dumper=pickle.dumps):
return dumper(obj, protocol=pickle_protocol)
if six.PY3:
+
def str_to_bytes(s):
if isinstance(s, str):
- return s.encode('utf-8')
+ return s.encode("utf-8")
return s
def unpickle(s):
return pickle_loads(str_to_bytes(s))
+
else:
- def str_to_bytes(s): # noqa
- if isinstance(s, unicode): # noqa # pylint: disable=E0602
- return s.encode('utf-8')
+
+ def str_to_bytes(s): # noqa
+ if isinstance(s, unicode): # noqa # pylint: disable=E0602
+ return s.encode("utf-8")
return s
+
unpickle = pickle_loads # noqa
- register('pickle', pickle_dumps, unpickle,
- content_type='application/x-python-serialize',
- content_encoding='binary')
+ register(
+ "pickle",
+ pickle_dumps,
+ unpickle,
+ content_type="application/x-python-serialize",
+ content_encoding="binary",
+ )
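register_kombu_serializers() above re-registers kombu's pickle serializer with a decoder that tolerates payloads arriving as unicode strings under Python 3. A condensed sketch of the same registration; it assumes register, pickle_protocol, and pickle_loads come from kombu.serialization, as the module's imports (outside this hunk) suggest.

import pickle

from kombu.serialization import register, pickle_protocol, pickle_loads

def pickle_dumps(obj, dumper=pickle.dumps):
    return dumper(obj, protocol=pickle_protocol)

def str_to_bytes(s):
    return s.encode("utf-8") if isinstance(s, str) else s

def unpickle(s):
    # Tolerate payloads delivered as str rather than bytes.
    return pickle_loads(str_to_bytes(s))

register(
    "pickle",
    pickle_dumps,
    unpickle,
    content_type="application/x-python-serialize",
    content_encoding="binary",
)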
diff --git a/st2common/st2common/transport/connection_retry_wrapper.py b/st2common/st2common/transport/connection_retry_wrapper.py
index d0c906fff6..492aa24f32 100644
--- a/st2common/st2common/transport/connection_retry_wrapper.py
+++ b/st2common/st2common/transport/connection_retry_wrapper.py
@@ -19,7 +19,7 @@
from st2common.util import concurrency
-__all__ = ['ConnectionRetryWrapper', 'ClusterRetryContext']
+__all__ = ["ConnectionRetryWrapper", "ClusterRetryContext"]
class ClusterRetryContext(object):
@@ -27,6 +27,7 @@ class ClusterRetryContext(object):
Stores retry context for cluster retries. It makes certain assumptions
on how cluster_size and retry should be determined.
"""
+
def __init__(self, cluster_size):
# No of nodes in a cluster
self.cluster_size = cluster_size
@@ -101,6 +102,7 @@ def wrapped_callback(connection, channel):
retry_wrapper.run(connection=connection, wrapped_callback=wrapped_callback)
"""
+
def __init__(self, cluster_size, logger, ensure_max_retries=3):
self._retry_context = ClusterRetryContext(cluster_size=cluster_size)
self._logger = logger
@@ -109,7 +111,7 @@ def __init__(self, cluster_size, logger, ensure_max_retries=3):
self._ensure_max_retries = ensure_max_retries
def errback(self, exc, interval):
- self._logger.error('Rabbitmq connection error: %s', exc.message)
+ self._logger.error("Rabbitmq connection error: %s", exc.message)
def run(self, connection, wrapped_callback):
"""
@@ -141,8 +143,10 @@ def run(self, connection, wrapped_callback):
raise
# -1, 0 and 1+ are handled properly by eventlet.sleep
- self._logger.debug('Received RabbitMQ server error, sleeping for %s seconds '
- 'before retrying: %s' % (wait, six.text_type(e)))
+ self._logger.debug(
+ "Received RabbitMQ server error, sleeping for %s seconds "
+ "before retrying: %s" % (wait, six.text_type(e))
+ )
concurrency.sleep(wait)
connection.close()
@@ -154,22 +158,28 @@ def run(self, connection, wrapped_callback):
def log_error_on_conn_failure(exc, interval):
self._logger.debug(
- 'Failed to re-establish connection to RabbitMQ server, '
- 'retrying in %s seconds: %s' % (interval, six.text_type(exc))
+ "Failed to re-establish connection to RabbitMQ server, "
+ "retrying in %s seconds: %s" % (interval, six.text_type(exc))
)
try:
# NOTE: This function blocks and tries to re-establish a connection
# indefinitely if the "max_retries" argument is not specified
- connection.ensure_connection(max_retries=self._ensure_max_retries,
- errback=log_error_on_conn_failure)
+ connection.ensure_connection(
+ max_retries=self._ensure_max_retries,
+ errback=log_error_on_conn_failure,
+ )
except Exception:
- self._logger.exception('Connections to RabbitMQ cannot be re-established: %s',
- six.text_type(e))
+ self._logger.exception(
+ "Connections to RabbitMQ cannot be re-established: %s",
+ six.text_type(e),
+ )
raise
except Exception as e:
- self._logger.exception('Connections to RabbitMQ cannot be re-established: %s',
- six.text_type(e))
+ self._logger.exception(
+ "Connections to RabbitMQ cannot be re-established: %s",
+ six.text_type(e),
+ )
# Not being able to publish a message could be a significant issue for an app.
raise
finally:
@@ -177,7 +187,7 @@ def log_error_on_conn_failure(exc, interval):
try:
channel.close()
except Exception:
- self._logger.warning('Error closing channel.', exc_info=True)
+ self._logger.warning("Error closing channel.", exc_info=True)
def ensured(self, connection, obj, to_ensure_func, **kwargs):
"""
@@ -191,7 +201,6 @@ def ensured(self, connection, obj, to_ensure_func, **kwargs):
:type obj: Must support mixin kombu.abstract.MaybeChannelBound
"""
ensuring_func = connection.ensure(
- obj, to_ensure_func,
- errback=self.errback,
- max_retries=3)
+ obj, to_ensure_func, errback=self.errback, max_retries=3
+ )
ensuring_func(**kwargs)
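ensured() above is a thin wrapper over kombu's Connection.ensure(), which wraps a channel-bound object's method so each call is retried through the supplied errback. A minimal usage sketch; the broker URL, exchange, and payload are illustrative.

from kombu import Connection, Exchange, Producer

def errback(exc, interval):
    print("publish failed: %s, retrying in %ss" % (exc, interval))

with Connection("amqp://guest:guest@localhost//") as connection:
    producer = Producer(connection.channel())
    exchange = Exchange("st2.demo", type="topic")
    # Returns a callable that re-runs producer.publish on recoverable errors.
    safe_publish = connection.ensure(
        producer, producer.publish, errback=errback, max_retries=3
    )
    safe_publish(
        {"hello": "world"},
        exchange=exchange,
        routing_key="create",
        declare=[exchange],  # make sure the exchange exists before publishing
    )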
diff --git a/st2common/st2common/transport/consumers.py b/st2common/st2common/transport/consumers.py
index 7f626f72a4..dd2f47cb55 100644
--- a/st2common/st2common/transport/consumers.py
+++ b/st2common/st2common/transport/consumers.py
@@ -25,12 +25,11 @@
from st2common.util import concurrency
__all__ = [
- 'QueueConsumer',
- 'StagedQueueConsumer',
- 'ActionsQueueConsumer',
-
- 'MessageHandler',
- 'StagedMessageHandler'
+ "QueueConsumer",
+ "StagedQueueConsumer",
+ "ActionsQueueConsumer",
+ "MessageHandler",
+ "StagedMessageHandler",
]
LOG = logging.getLogger(__name__)
@@ -47,7 +46,9 @@ def shutdown(self):
self._dispatcher.shutdown()
def get_consumers(self, Consumer, channel):
- consumer = Consumer(queues=self._queues, accept=['pickle'], callbacks=[self.process])
+ consumer = Consumer(
+ queues=self._queues, accept=["pickle"], callbacks=[self.process]
+ )
# use prefetch_count=1 for fair dispatch. This way workers that finish an item get the next
# task and the work does not get queued behind any single large item.
@@ -58,11 +59,15 @@ def get_consumers(self, Consumer, channel):
def process(self, body, message):
try:
if not isinstance(body, self._handler.message_type):
- raise TypeError('Received an unexpected type "%s" for payload.' % type(body))
+ raise TypeError(
+ 'Received an unexpected type "%s" for payload.' % type(body)
+ )
self._dispatcher.dispatch(self._process_message, body)
except:
- LOG.exception('%s failed to process message: %s', self.__class__.__name__, body)
+ LOG.exception(
+ "%s failed to process message: %s", self.__class__.__name__, body
+ )
finally:
# At this point we will always ack a message.
message.ack()
@@ -71,7 +76,9 @@ def _process_message(self, body):
try:
self._handler.process(body)
except:
- LOG.exception('%s failed to process message: %s', self.__class__.__name__, body)
+ LOG.exception(
+ "%s failed to process message: %s", self.__class__.__name__, body
+ )
class StagedQueueConsumer(QueueConsumer):
@@ -82,11 +89,15 @@ class StagedQueueConsumer(QueueConsumer):
def process(self, body, message):
try:
if not isinstance(body, self._handler.message_type):
- raise TypeError('Received an unexpected type "%s" for payload.' % type(body))
+ raise TypeError(
+ 'Received an unexpected type "%s" for payload.' % type(body)
+ )
response = self._handler.pre_ack_process(body)
self._dispatcher.dispatch(self._process_message, response)
except:
- LOG.exception('%s failed to process message: %s', self.__class__.__name__, body)
+ LOG.exception(
+ "%s failed to process message: %s", self.__class__.__name__, body
+ )
finally:
# At this point we will always ack a message.
message.ack()
@@ -110,17 +121,21 @@ def __init__(self, connection, queues, handler):
workflows_pool_size = cfg.CONF.actionrunner.workflows_pool_size
actions_pool_size = cfg.CONF.actionrunner.actions_pool_size
- self._workflows_dispatcher = BufferedDispatcher(dispatch_pool_size=workflows_pool_size,
- name='workflows-dispatcher')
- self._actions_dispatcher = BufferedDispatcher(dispatch_pool_size=actions_pool_size,
- name='actions-dispatcher')
+ self._workflows_dispatcher = BufferedDispatcher(
+ dispatch_pool_size=workflows_pool_size, name="workflows-dispatcher"
+ )
+ self._actions_dispatcher = BufferedDispatcher(
+ dispatch_pool_size=actions_pool_size, name="actions-dispatcher"
+ )
def process(self, body, message):
try:
if not isinstance(body, self._handler.message_type):
- raise TypeError('Received an unexpected type "%s" for payload.' % type(body))
+ raise TypeError(
+ 'Received an unexpected type "%s" for payload.' % type(body)
+ )
- action_is_workflow = getattr(body, 'action_is_workflow', False)
+ action_is_workflow = getattr(body, "action_is_workflow", False)
if action_is_workflow:
# Use workflow dispatcher queue
dispatcher = self._workflows_dispatcher
@@ -131,7 +146,9 @@ def process(self, body, message):
LOG.debug('Using BufferedDispatcher pool: "%s"', str(dispatcher))
dispatcher.dispatch(self._process_message, body)
except:
- LOG.exception('%s failed to process message: %s', self.__class__.__name__, body)
+ LOG.exception(
+ "%s failed to process message: %s", self.__class__.__name__, body
+ )
finally:
# At this point we will always ack a message.
message.ack()
@@ -149,11 +166,15 @@ class VariableMessageQueueConsumer(QueueConsumer):
def process(self, body, message):
try:
if not self._handler.message_types.get(type(body)):
- raise TypeError('Received an unexpected type "%s" for payload.' % type(body))
+ raise TypeError(
+ 'Received an unexpected type "%s" for payload.' % type(body)
+ )
self._dispatcher.dispatch(self._process_message, body)
except:
- LOG.exception('%s failed to process message: %s', self.__class__.__name__, body)
+ LOG.exception(
+ "%s failed to process message: %s", self.__class__.__name__, body
+ )
finally:
# At this point we will always ack a message.
message.ack()
@@ -164,12 +185,13 @@ class MessageHandler(object):
message_type = None
def __init__(self, connection, queues):
- self._queue_consumer = self.get_queue_consumer(connection=connection,
- queues=queues)
+ self._queue_consumer = self.get_queue_consumer(
+ connection=connection, queues=queues
+ )
self._consumer_thread = None
def start(self, wait=False):
- LOG.info('Starting %s...', self.__class__.__name__)
+ LOG.info("Starting %s...", self.__class__.__name__)
self._consumer_thread = concurrency.spawn(self._queue_consumer.run)
if wait:
@@ -179,7 +201,7 @@ def wait(self):
self._consumer_thread.wait()
def shutdown(self):
- LOG.info('Shutting down %s...', self.__class__.__name__)
+ LOG.info("Shutting down %s...", self.__class__.__name__)
self._queue_consumer.shutdown()
@abc.abstractmethod
@@ -224,4 +246,6 @@ class VariableMessageHandler(MessageHandler):
"""
def get_queue_consumer(self, connection, queues):
- return VariableMessageQueueConsumer(connection=connection, queues=queues, handler=self)
+ return VariableMessageQueueConsumer(
+ connection=connection, queues=queues, handler=self
+ )
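All of the consumers above follow kombu's ConsumerMixin shape: get_consumers() builds consumers over the handler's queues and, per the fair-dispatch comment, caps prefetch at one message. A stripped-down skeleton of that pattern:

from kombu.mixins import ConsumerMixin

class Worker(ConsumerMixin):
    def __init__(self, connection, queues):
        self.connection = connection
        self._queues = queues

    def get_consumers(self, Consumer, channel):
        consumer = Consumer(
            queues=self._queues, accept=["pickle"], callbacks=[self.process]
        )
        # Fair dispatch: only hand this worker a new message once it acks.
        consumer.qos(prefetch_count=1)
        return [consumer]

    def process(self, body, message):
        try:
            print("got: %r" % (body,))
        finally:
            message.ack()  # always ack, mirroring the consumers above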
diff --git a/st2common/st2common/transport/execution.py b/st2common/st2common/transport/execution.py
index e35279ac71..5d2880fd6f 100644
--- a/st2common/st2common/transport/execution.py
+++ b/st2common/st2common/transport/execution.py
@@ -20,15 +20,14 @@
from st2common.transport import publishers
__all__ = [
- 'ActionExecutionPublisher',
- 'ActionExecutionOutputPublisher',
-
- 'get_queue',
- 'get_output_queue'
+ "ActionExecutionPublisher",
+ "ActionExecutionOutputPublisher",
+ "get_queue",
+ "get_output_queue",
]
-EXECUTION_XCHG = Exchange('st2.execution', type='topic')
-EXECUTION_OUTPUT_XCHG = Exchange('st2.execution.output', type='topic')
+EXECUTION_XCHG = Exchange("st2.execution", type="topic")
+EXECUTION_OUTPUT_XCHG = Exchange("st2.execution.output", type="topic")
class ActionExecutionPublisher(publishers.CUDPublisher):
@@ -38,14 +37,26 @@ def __init__(self):
class ActionExecutionOutputPublisher(publishers.CUDPublisher):
def __init__(self):
- super(ActionExecutionOutputPublisher, self).__init__(exchange=EXECUTION_OUTPUT_XCHG)
+ super(ActionExecutionOutputPublisher, self).__init__(
+ exchange=EXECUTION_OUTPUT_XCHG
+ )
def get_queue(name=None, routing_key=None, exclusive=False, auto_delete=False):
- return Queue(name, EXECUTION_XCHG, routing_key=routing_key, exclusive=exclusive,
- auto_delete=auto_delete)
+ return Queue(
+ name,
+ EXECUTION_XCHG,
+ routing_key=routing_key,
+ exclusive=exclusive,
+ auto_delete=auto_delete,
+ )
def get_output_queue(name=None, routing_key=None, exclusive=False, auto_delete=False):
- return Queue(name, EXECUTION_OUTPUT_XCHG, routing_key=routing_key, exclusive=exclusive,
- auto_delete=auto_delete)
+ return Queue(
+ name,
+ EXECUTION_OUTPUT_XCHG,
+ routing_key=routing_key,
+ exclusive=exclusive,
+ auto_delete=auto_delete,
+ )
diff --git a/st2common/st2common/transport/liveaction.py b/st2common/st2common/transport/liveaction.py
index 97dd08400b..670c5ebb2e 100644
--- a/st2common/st2common/transport/liveaction.py
+++ b/st2common/st2common/transport/liveaction.py
@@ -21,23 +21,19 @@
from st2common.transport import publishers
-__all__ = [
- 'LiveActionPublisher',
+__all__ = ["LiveActionPublisher", "get_queue", "get_status_management_queue"]
- 'get_queue',
- 'get_status_management_queue'
-]
-
-LIVEACTION_XCHG = Exchange('st2.liveaction', type='topic')
-LIVEACTION_STATUS_MGMT_XCHG = Exchange('st2.liveaction.status', type='topic')
+LIVEACTION_XCHG = Exchange("st2.liveaction", type="topic")
+LIVEACTION_STATUS_MGMT_XCHG = Exchange("st2.liveaction.status", type="topic")
class LiveActionPublisher(publishers.CUDPublisher, publishers.StatePublisherMixin):
-
def __init__(self):
publishers.CUDPublisher.__init__(self, exchange=LIVEACTION_XCHG)
- publishers.StatePublisherMixin.__init__(self, exchange=LIVEACTION_STATUS_MGMT_XCHG)
+ publishers.StatePublisherMixin.__init__(
+ self, exchange=LIVEACTION_STATUS_MGMT_XCHG
+ )
def get_queue(name, routing_key):
diff --git a/st2common/st2common/transport/publishers.py b/st2common/st2common/transport/publishers.py
index 7942fdfffe..202220acb1 100644
--- a/st2common/st2common/transport/publishers.py
+++ b/st2common/st2common/transport/publishers.py
@@ -25,16 +25,16 @@
from st2common.transport.connection_retry_wrapper import ConnectionRetryWrapper
__all__ = [
- 'PoolPublisher',
- 'SharedPoolPublishers',
- 'CUDPublisher',
- 'StatePublisherMixin'
+ "PoolPublisher",
+ "SharedPoolPublishers",
+ "CUDPublisher",
+ "StatePublisherMixin",
]
-ANY_RK = '*'
-CREATE_RK = 'create'
-UPDATE_RK = 'update'
-DELETE_RK = 'delete'
+ANY_RK = "*"
+CREATE_RK = "create"
+UPDATE_RK = "update"
+DELETE_RK = "delete"
LOG = logging.getLogger(__name__)
@@ -47,19 +47,21 @@ def __init__(self, urls=None):
:type urls: ``list``
"""
urls = urls or transport_utils.get_messaging_urls()
- connection = transport_utils.get_connection(urls=urls,
- connection_kwargs={'failover_strategy':
- 'round-robin'})
+ connection = transport_utils.get_connection(
+ urls=urls, connection_kwargs={"failover_strategy": "round-robin"}
+ )
self.pool = connection.Pool(limit=10)
self.cluster_size = len(urls)
def errback(self, exc, interval):
- LOG.error('Rabbitmq connection error: %s', exc.message, exc_info=False)
+ LOG.error("Rabbitmq connection error: %s", exc.message, exc_info=False)
- def publish(self, payload, exchange, routing_key=''):
- with Timer(key='amqp.pool_publisher.publish_with_retries.' + exchange.name):
+ def publish(self, payload, exchange, routing_key=""):
+ with Timer(key="amqp.pool_publisher.publish_with_retries." + exchange.name):
with self.pool.acquire(block=True) as connection:
- retry_wrapper = ConnectionRetryWrapper(cluster_size=self.cluster_size, logger=LOG)
+ retry_wrapper = ConnectionRetryWrapper(
+ cluster_size=self.cluster_size, logger=LOG
+ )
def do_publish(connection, channel):
# ProducerPool ends up creating its own ConnectionPool which ends up
@@ -68,18 +70,18 @@ def do_publish(connection, channel):
# Producer for each publish.
producer = Producer(channel)
kwargs = {
- 'body': payload,
- 'exchange': exchange,
- 'routing_key': routing_key,
- 'serializer': 'pickle',
- 'content_encoding': 'utf-8'
+ "body": payload,
+ "exchange": exchange,
+ "routing_key": routing_key,
+ "serializer": "pickle",
+ "content_encoding": "utf-8",
}
retry_wrapper.ensured(
connection=connection,
obj=producer,
to_ensure_func=producer.publish,
- **kwargs
+ **kwargs,
)
retry_wrapper.run(connection=connection, wrapped_callback=do_publish)
@@ -91,6 +93,7 @@ class SharedPoolPublishers(object):
server is usually the same. This sharing allows the same PoolPublisher to be reused
for publishing purposes. Sharing publishers leads to shared connections.
"""
+
shared_publishers = {}
def get_publisher(self, urls):
@@ -99,7 +102,7 @@ def get_publisher(self, urls):
# ordering in supplied list.
urls_copy = copy.copy(urls)
urls_copy.sort()
- publisher_key = ''.join(urls_copy)
+ publisher_key = "".join(urls_copy)
publisher = self.shared_publishers.get(publisher_key, None)
if not publisher:
# Use original urls here to preserve order.
@@ -115,15 +118,15 @@ def __init__(self, exchange):
self._exchange = exchange
def publish_create(self, payload):
- with Timer(key='amqp.publish.create'):
+ with Timer(key="amqp.publish.create"):
self._publisher.publish(payload, self._exchange, CREATE_RK)
def publish_update(self, payload):
- with Timer(key='amqp.publish.update'):
+ with Timer(key="amqp.publish.update"):
self._publisher.publish(payload, self._exchange, UPDATE_RK)
def publish_delete(self, payload):
- with Timer(key='amqp.publish.delete'):
+ with Timer(key="amqp.publish.delete"):
self._publisher.publish(payload, self._exchange, DELETE_RK)
@@ -135,6 +138,6 @@ def __init__(self, exchange):
def publish_state(self, payload, state):
if not state:
- raise Exception('Unable to publish unassigned state.')
- with Timer(key='amqp.publish.state'):
+ raise Exception("Unable to publish unassigned state.")
+ with Timer(key="amqp.publish.state"):
self._state_publisher.publish(payload, self._state_exchange, state)
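PoolPublisher above holds a shared kombu connection pool (limit=10) and creates a short-lived Producer per publish, precisely to avoid ProducerPool opening its own extra connections. The core of that flow, with an illustrative broker URL and payload:

from kombu import Connection, Exchange, Producer

connection = Connection(
    "amqp://guest:guest@localhost//",  # illustrative URL
    failover_strategy="round-robin",
)
pool = connection.Pool(limit=10)

exchange = Exchange("st2.execution", type="topic")

with pool.acquire(block=True) as conn:
    producer = Producer(conn.channel())
    producer.publish(
        {"id": "abc123"},      # illustrative payload
        exchange=exchange,
        routing_key="update",  # UPDATE_RK in the module above
        serializer="pickle",
        content_encoding="utf-8",
        declare=[exchange],
    )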
diff --git a/st2common/st2common/transport/queues.py b/st2common/st2common/transport/queues.py
index faf6d27fbf..f6f9bcb4ef 100644
--- a/st2common/st2common/transport/queues.py
+++ b/st2common/st2common/transport/queues.py
@@ -34,120 +34,109 @@
from st2common.transport import workflow
__all__ = [
- 'ACTIONSCHEDULER_REQUEST_QUEUE',
-
- 'ACTIONRUNNER_WORK_QUEUE',
- 'ACTIONRUNNER_CANCEL_QUEUE',
- 'ACTIONRUNNER_PAUSE_QUEUE',
- 'ACTIONRUNNER_RESUME_QUEUE',
-
- 'EXPORTER_WORK_QUEUE',
-
- 'NOTIFIER_ACTIONUPDATE_WORK_QUEUE',
-
- 'RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE',
-
- 'RULESENGINE_WORK_QUEUE',
-
- 'STREAM_ANNOUNCEMENT_WORK_QUEUE',
- 'STREAM_EXECUTION_ALL_WORK_QUEUE',
- 'STREAM_EXECUTION_UPDATE_WORK_QUEUE',
- 'STREAM_LIVEACTION_WORK_QUEUE',
-
- 'WORKFLOW_EXECUTION_WORK_QUEUE',
- 'WORKFLOW_EXECUTION_RESUME_QUEUE'
+ "ACTIONSCHEDULER_REQUEST_QUEUE",
+ "ACTIONRUNNER_WORK_QUEUE",
+ "ACTIONRUNNER_CANCEL_QUEUE",
+ "ACTIONRUNNER_PAUSE_QUEUE",
+ "ACTIONRUNNER_RESUME_QUEUE",
+ "EXPORTER_WORK_QUEUE",
+ "NOTIFIER_ACTIONUPDATE_WORK_QUEUE",
+ "RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE",
+ "RULESENGINE_WORK_QUEUE",
+ "STREAM_ANNOUNCEMENT_WORK_QUEUE",
+ "STREAM_EXECUTION_ALL_WORK_QUEUE",
+ "STREAM_EXECUTION_UPDATE_WORK_QUEUE",
+ "STREAM_LIVEACTION_WORK_QUEUE",
+ "WORKFLOW_EXECUTION_WORK_QUEUE",
+ "WORKFLOW_EXECUTION_RESUME_QUEUE",
]
# Used by the action scheduler service
ACTIONSCHEDULER_REQUEST_QUEUE = liveaction.get_status_management_queue(
- 'st2.actionrunner.req',
- routing_key=action_constants.LIVEACTION_STATUS_REQUESTED)
+ "st2.actionrunner.req", routing_key=action_constants.LIVEACTION_STATUS_REQUESTED
+)
# Used by the action runner service
ACTIONRUNNER_WORK_QUEUE = liveaction.get_status_management_queue(
- 'st2.actionrunner.work',
- routing_key=action_constants.LIVEACTION_STATUS_SCHEDULED)
+ "st2.actionrunner.work", routing_key=action_constants.LIVEACTION_STATUS_SCHEDULED
+)
ACTIONRUNNER_CANCEL_QUEUE = liveaction.get_status_management_queue(
- 'st2.actionrunner.cancel',
- routing_key=action_constants.LIVEACTION_STATUS_CANCELING)
+ "st2.actionrunner.cancel", routing_key=action_constants.LIVEACTION_STATUS_CANCELING
+)
ACTIONRUNNER_PAUSE_QUEUE = liveaction.get_status_management_queue(
- 'st2.actionrunner.pause',
- routing_key=action_constants.LIVEACTION_STATUS_PAUSING)
+ "st2.actionrunner.pause", routing_key=action_constants.LIVEACTION_STATUS_PAUSING
+)
ACTIONRUNNER_RESUME_QUEUE = liveaction.get_status_management_queue(
- 'st2.actionrunner.resume',
- routing_key=action_constants.LIVEACTION_STATUS_RESUMING)
+ "st2.actionrunner.resume", routing_key=action_constants.LIVEACTION_STATUS_RESUMING
+)
# Used by the exporter service
EXPORTER_WORK_QUEUE = execution.get_queue(
- 'st2.exporter.work',
- routing_key=publishers.UPDATE_RK)
+ "st2.exporter.work", routing_key=publishers.UPDATE_RK
+)
# Used by the notifier service
NOTIFIER_ACTIONUPDATE_WORK_QUEUE = execution.get_queue(
- 'st2.notifiers.execution.work',
- routing_key=publishers.UPDATE_RK)
+ "st2.notifiers.execution.work", routing_key=publishers.UPDATE_RK
+)
# Used by the results tracker service
RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE = actionexecutionstate.get_queue(
- 'st2.resultstracker.work',
- routing_key=publishers.CREATE_RK)
+ "st2.resultstracker.work", routing_key=publishers.CREATE_RK
+)
# Used by the rules engine service
RULESENGINE_WORK_QUEUE = reactor.get_trigger_instances_queue(
- name='st2.trigger_instances_dispatch.rules_engine',
- routing_key='#')
+ name="st2.trigger_instances_dispatch.rules_engine", routing_key="#"
+)
# Used by the stream service
STREAM_ANNOUNCEMENT_WORK_QUEUE = announcement.get_queue(
- routing_key=publishers.ANY_RK,
- exclusive=True,
- auto_delete=True)
+ routing_key=publishers.ANY_RK, exclusive=True, auto_delete=True
+)
STREAM_EXECUTION_ALL_WORK_QUEUE = execution.get_queue(
- routing_key=publishers.ANY_RK,
- exclusive=True,
- auto_delete=True)
+ routing_key=publishers.ANY_RK, exclusive=True, auto_delete=True
+)
STREAM_EXECUTION_UPDATE_WORK_QUEUE = execution.get_queue(
- routing_key=publishers.UPDATE_RK,
- exclusive=True,
- auto_delete=True)
+ routing_key=publishers.UPDATE_RK, exclusive=True, auto_delete=True
+)
STREAM_LIVEACTION_WORK_QUEUE = Queue(
None,
liveaction.LIVEACTION_XCHG,
routing_key=publishers.ANY_RK,
exclusive=True,
- auto_delete=True)
+ auto_delete=True,
+)
# TODO: Perhaps we should use the pack.action name as the routing key
# so we can do more efficient filtering later, if needed
STREAM_EXECUTION_OUTPUT_QUEUE = execution.get_output_queue(
- name=None,
- routing_key=publishers.CREATE_RK,
- exclusive=True,
- auto_delete=True)
+ name=None, routing_key=publishers.CREATE_RK, exclusive=True, auto_delete=True
+)
# Used by the workflow engine service
WORKFLOW_EXECUTION_WORK_QUEUE = workflow.get_status_management_queue(
- name='st2.workflow.work',
- routing_key=action_constants.LIVEACTION_STATUS_REQUESTED)
+ name="st2.workflow.work", routing_key=action_constants.LIVEACTION_STATUS_REQUESTED
+)
WORKFLOW_EXECUTION_RESUME_QUEUE = workflow.get_status_management_queue(
- name='st2.workflow.resume',
- routing_key=action_constants.LIVEACTION_STATUS_RESUMING)
+ name="st2.workflow.resume", routing_key=action_constants.LIVEACTION_STATUS_RESUMING
+)
WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE = execution.get_queue(
- 'st2.workflow.action.update',
- routing_key=publishers.UPDATE_RK)
+ "st2.workflow.action.update", routing_key=publishers.UPDATE_RK
+)
diff --git a/st2common/st2common/transport/reactor.py b/st2common/st2common/transport/reactor.py
index 613a1d08ed..c9dc84725c 100644
--- a/st2common/st2common/transport/reactor.py
+++ b/st2common/st2common/transport/reactor.py
@@ -22,26 +22,24 @@
from st2common.transport import publishers
__all__ = [
- 'TriggerCUDPublisher',
- 'TriggerInstancePublisher',
-
- 'TriggerDispatcher',
-
- 'get_sensor_cud_queue',
- 'get_trigger_cud_queue',
- 'get_trigger_instances_queue'
+ "TriggerCUDPublisher",
+ "TriggerInstancePublisher",
+ "TriggerDispatcher",
+ "get_sensor_cud_queue",
+ "get_trigger_cud_queue",
+ "get_trigger_instances_queue",
]
LOG = logging.getLogger(__name__)
# Exchange for Trigger CUD events
-TRIGGER_CUD_XCHG = Exchange('st2.trigger', type='topic')
+TRIGGER_CUD_XCHG = Exchange("st2.trigger", type="topic")
# Exchange for TriggerInstance events
-TRIGGER_INSTANCE_XCHG = Exchange('st2.trigger_instances_dispatch', type='topic')
+TRIGGER_INSTANCE_XCHG = Exchange("st2.trigger_instances_dispatch", type="topic")
# Exchange for Sensor CUD events
-SENSOR_CUD_XCHG = Exchange('st2.sensor', type='topic')
+SENSOR_CUD_XCHG = Exchange("st2.sensor", type="topic")
class SensorCUDPublisher(publishers.CUDPublisher):
@@ -96,14 +94,12 @@ def dispatch(self, trigger, payload=None, trace_context=None):
assert isinstance(payload, (type(None), dict))
assert isinstance(trace_context, (type(None), TraceContext))
- payload = {
- 'trigger': trigger,
- 'payload': payload,
- TRACE_CONTEXT: trace_context
- }
- routing_key = 'trigger_instance'
+ payload = {"trigger": trigger, "payload": payload, TRACE_CONTEXT: trace_context}
+ routing_key = "trigger_instance"
- self._logger.debug('Dispatching trigger (trigger=%s,payload=%s)', trigger, payload)
+ self._logger.debug(
+ "Dispatching trigger (trigger=%s,payload=%s)", trigger, payload
+ )
self._publisher.publish_trigger(payload=payload, routing_key=routing_key)
diff --git a/st2common/st2common/transport/utils.py b/st2common/st2common/transport/utils.py
index bea2df1e57..e479713ddc 100644
--- a/st2common/st2common/transport/utils.py
+++ b/st2common/st2common/transport/utils.py
@@ -22,22 +22,18 @@
from st2common import log as logging
-__all__ = [
- 'get_connection',
-
- 'get_messaging_urls'
-]
+__all__ = ["get_connection", "get_messaging_urls"]
LOG = logging.getLogger(__name__)
def get_messaging_urls():
- '''
+ """
Determines the right messaging URLs to supply. If the `cluster_urls` config option is
specified, it is used; otherwise the single `url` property is used.
:rtype: ``list``
- '''
+ """
if cfg.CONF.messaging.cluster_urls:
return cfg.CONF.messaging.cluster_urls
return [cfg.CONF.messaging.url]
@@ -57,33 +53,41 @@ def get_connection(urls=None, connection_kwargs=None):
kwargs = {}
- ssl_kwargs = _get_ssl_kwargs(ssl=cfg.CONF.messaging.ssl,
- ssl_keyfile=cfg.CONF.messaging.ssl_keyfile,
- ssl_certfile=cfg.CONF.messaging.ssl_certfile,
- ssl_cert_reqs=cfg.CONF.messaging.ssl_cert_reqs,
- ssl_ca_certs=cfg.CONF.messaging.ssl_ca_certs,
- login_method=cfg.CONF.messaging.login_method)
+ ssl_kwargs = _get_ssl_kwargs(
+ ssl=cfg.CONF.messaging.ssl,
+ ssl_keyfile=cfg.CONF.messaging.ssl_keyfile,
+ ssl_certfile=cfg.CONF.messaging.ssl_certfile,
+ ssl_cert_reqs=cfg.CONF.messaging.ssl_cert_reqs,
+ ssl_ca_certs=cfg.CONF.messaging.ssl_ca_certs,
+ login_method=cfg.CONF.messaging.login_method,
+ )
# NOTE: "connection_kwargs" argument passed to this function has precedence over config values
- if len(ssl_kwargs) == 1 and ssl_kwargs['ssl'] is True:
- kwargs.update({'ssl': True})
+ if len(ssl_kwargs) == 1 and ssl_kwargs["ssl"] is True:
+ kwargs.update({"ssl": True})
elif len(ssl_kwargs) >= 2:
- ssl_kwargs.pop('ssl')
- kwargs.update({'ssl': ssl_kwargs})
+ ssl_kwargs.pop("ssl")
+ kwargs.update({"ssl": ssl_kwargs})
- kwargs['login_method'] = cfg.CONF.messaging.login_method
+ kwargs["login_method"] = cfg.CONF.messaging.login_method
kwargs.update(connection_kwargs)
# NOTE: This line contains no secret values so it's OK to log it
- LOG.debug('Using SSL context for RabbitMQ connection: %s' % (ssl_kwargs))
+ LOG.debug("Using SSL context for RabbitMQ connection: %s" % (ssl_kwargs))
connection = Connection(urls, **kwargs)
return connection
-def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None,
- ssl_ca_certs=None, login_method=None):
+def _get_ssl_kwargs(
+ ssl=False,
+ ssl_keyfile=None,
+ ssl_certfile=None,
+ ssl_cert_reqs=None,
+ ssl_ca_certs=None,
+ login_method=None,
+):
"""
Return SSL keyword arguments to be used with the kombu.Connection class.
"""
@@ -93,27 +97,27 @@ def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_req
# because user could still specify to use SSL by including "?ssl=true" query param at the
# end of the connection URL string
if ssl is True:
- ssl_kwargs['ssl'] = True
+ ssl_kwargs["ssl"] = True
if ssl_keyfile:
- ssl_kwargs['ssl'] = True
- ssl_kwargs['keyfile'] = ssl_keyfile
+ ssl_kwargs["ssl"] = True
+ ssl_kwargs["keyfile"] = ssl_keyfile
if ssl_certfile:
- ssl_kwargs['ssl'] = True
- ssl_kwargs['certfile'] = ssl_certfile
+ ssl_kwargs["ssl"] = True
+ ssl_kwargs["certfile"] = ssl_certfile
if ssl_cert_reqs:
- if ssl_cert_reqs == 'none':
+ if ssl_cert_reqs == "none":
ssl_cert_reqs = ssl_lib.CERT_NONE
- elif ssl_cert_reqs == 'optional':
+ elif ssl_cert_reqs == "optional":
ssl_cert_reqs = ssl_lib.CERT_OPTIONAL
- elif ssl_cert_reqs == 'required':
+ elif ssl_cert_reqs == "required":
ssl_cert_reqs = ssl_lib.CERT_REQUIRED
- ssl_kwargs['cert_reqs'] = ssl_cert_reqs
+ ssl_kwargs["cert_reqs"] = ssl_cert_reqs
if ssl_ca_certs:
- ssl_kwargs['ssl'] = True
- ssl_kwargs['ca_certs'] = ssl_ca_certs
+ ssl_kwargs["ssl"] = True
+ ssl_kwargs["ca_certs"] = ssl_ca_certs
return ssl_kwargs
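
`_get_ssl_kwargs` above folds individual SSL options into the single dict that kombu's `Connection` accepts. A minimal standalone sketch of that contract (the helper name and values here are illustrative, not part of the patch):

    import ssl as ssl_lib

    def build_ssl_kwargs(use_ssl=False, certfile=None, cert_reqs=None, ca_certs=None):
        # Mirrors _get_ssl_kwargs: any concrete SSL option implies ssl=True,
        # and the cert_reqs string maps to an ssl module constant.
        kwargs = {}
        if use_ssl:
            kwargs["ssl"] = True
        if certfile:
            kwargs.update({"ssl": True, "certfile": certfile})
        if cert_reqs:
            kwargs["cert_reqs"] = {
                "none": ssl_lib.CERT_NONE,
                "optional": ssl_lib.CERT_OPTIONAL,
                "required": ssl_lib.CERT_REQUIRED,
            }[cert_reqs]
        if ca_certs:
            kwargs.update({"ssl": True, "ca_certs": ca_certs})
        return kwargs

    # {"ssl": True} alone enables TLS with defaults; a richer dict is passed
    # through as Connection(..., ssl=<the dict>) instead.
    print(build_ssl_kwargs(use_ssl=True))
    print(build_ssl_kwargs(certfile="/etc/st2/client.crt", cert_reqs="required"))
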
diff --git a/st2common/st2common/transport/workflow.py b/st2common/st2common/transport/workflow.py
index 2b9815fcb7..0302611a36 100644
--- a/st2common/st2common/transport/workflow.py
+++ b/st2common/st2common/transport/workflow.py
@@ -21,22 +21,22 @@
from st2common.transport import publishers
-__all__ = [
- 'WorkflowExecutionPublisher',
+__all__ = ["WorkflowExecutionPublisher", "get_queue", "get_status_management_queue"]
- 'get_queue',
- 'get_status_management_queue'
-]
+WORKFLOW_EXECUTION_XCHG = kombu.Exchange("st2.workflow", type="topic")
+WORKFLOW_EXECUTION_STATUS_MGMT_XCHG = kombu.Exchange(
+ "st2.workflow.status", type="topic"
+)
-WORKFLOW_EXECUTION_XCHG = kombu.Exchange('st2.workflow', type='topic')
-WORKFLOW_EXECUTION_STATUS_MGMT_XCHG = kombu.Exchange('st2.workflow.status', type='topic')
-
-
-class WorkflowExecutionPublisher(publishers.CUDPublisher, publishers.StatePublisherMixin):
+class WorkflowExecutionPublisher(
+ publishers.CUDPublisher, publishers.StatePublisherMixin
+):
def __init__(self):
publishers.CUDPublisher.__init__(self, exchange=WORKFLOW_EXECUTION_XCHG)
- publishers.StatePublisherMixin.__init__(self, exchange=WORKFLOW_EXECUTION_STATUS_MGMT_XCHG)
+ publishers.StatePublisherMixin.__init__(
+ self, exchange=WORKFLOW_EXECUTION_STATUS_MGMT_XCHG
+ )
def get_queue(name, routing_key):
@@ -44,4 +44,6 @@ def get_queue(name, routing_key):
def get_status_management_queue(name, routing_key):
- return kombu.Queue(name, WORKFLOW_EXECUTION_STATUS_MGMT_XCHG, routing_key=routing_key)
+ return kombu.Queue(
+ name, WORKFLOW_EXECUTION_STATUS_MGMT_XCHG, routing_key=routing_key
+ )
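
For context, the queue helpers in this file are thin wrappers over kombu primitives. A minimal sketch of the same pattern, assuming kombu is installed (the exchange and queue names below are made up):

    import kombu

    # A topic exchange plus a queue bound to it with a routing key, the
    # same shape as get_status_management_queue above.
    example_xchg = kombu.Exchange("st2.example", type="topic")

    def get_example_queue(name, routing_key):
        return kombu.Queue(name, example_xchg, routing_key=routing_key)

    queue = get_example_queue("st2.example.work", routing_key="requested")
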
diff --git a/st2common/st2common/triggers.py b/st2common/st2common/triggers.py
index a18dadedb9..ec0dba378e 100644
--- a/st2common/st2common/triggers.py
+++ b/st2common/st2common/triggers.py
@@ -22,52 +22,63 @@
from oslo_config import cfg
from st2common import log as logging
-from st2common.constants.triggers import (INTERNAL_TRIGGER_TYPES, ACTION_SENSOR_TRIGGER)
+from st2common.constants.triggers import INTERNAL_TRIGGER_TYPES, ACTION_SENSOR_TRIGGER
from st2common.exceptions.db import StackStormDBObjectConflictError
from st2common.services.triggers import create_trigger_type_db
from st2common.services.triggers import create_shadow_trigger
from st2common.services.triggers import get_trigger_type_db
from st2common.models.system.common import ResourceReference
-__all__ = [
- 'register_internal_trigger_types'
-]
+__all__ = ["register_internal_trigger_types"]
LOG = logging.getLogger(__name__)
def _register_internal_trigger_type(trigger_definition):
try:
- trigger_type_db = create_trigger_type_db(trigger_type=trigger_definition,
- log_not_unique_error_as_debug=True)
+ trigger_type_db = create_trigger_type_db(
+ trigger_type=trigger_definition, log_not_unique_error_as_debug=True
+ )
except (NotUniqueError, StackStormDBObjectConflictError):
# We ignore the conflict error since this operation is idempotent and races are not an issue
- LOG.debug('Internal trigger type "%s" already exists, ignoring error...' %
- (trigger_definition['name']))
-
- ref = ResourceReference.to_string_reference(name=trigger_definition['name'],
- pack=trigger_definition['pack'])
+ LOG.debug(
+ 'Internal trigger type "%s" already exists, ignoring error...'
+ % (trigger_definition["name"])
+ )
+
+ ref = ResourceReference.to_string_reference(
+ name=trigger_definition["name"], pack=trigger_definition["pack"]
+ )
trigger_type_db = get_trigger_type_db(ref)
if trigger_type_db:
- LOG.debug('Registered internal trigger: %s.', trigger_definition['name'])
+ LOG.debug("Registered internal trigger: %s.", trigger_definition["name"])
# trigger types with parameters do not require a shadow trigger.
if trigger_type_db and not trigger_type_db.parameters_schema:
try:
- trigger_db = create_shadow_trigger(trigger_type_db,
- log_not_unique_error_as_debug=True)
-
- extra = {'trigger_db': trigger_db}
- LOG.audit('Trigger created for parameter-less internal TriggerType. Trigger.id=%s' %
- (trigger_db.id), extra=extra)
+ trigger_db = create_shadow_trigger(
+ trigger_type_db, log_not_unique_error_as_debug=True
+ )
+
+ extra = {"trigger_db": trigger_db}
+ LOG.audit(
+ "Trigger created for parameter-less internal TriggerType. Trigger.id=%s"
+ % (trigger_db.id),
+ extra=extra,
+ )
except (NotUniqueError, StackStormDBObjectConflictError):
- LOG.debug('Shadow trigger "%s" already exists. Ignoring.',
- trigger_type_db.get_reference().ref, exc_info=True)
+ LOG.debug(
+ 'Shadow trigger "%s" already exists. Ignoring.',
+ trigger_type_db.get_reference().ref,
+ exc_info=True,
+ )
except (ValidationError, ValueError):
- LOG.exception('Validation failed in shadow trigger. TriggerType=%s.',
- trigger_type_db.get_reference().ref)
+ LOG.exception(
+ "Validation failed in shadow trigger. TriggerType=%s.",
+ trigger_type_db.get_reference().ref,
+ )
raise
return trigger_type_db
@@ -89,16 +100,21 @@ def register_internal_trigger_types():
for _, trigger_definitions in six.iteritems(INTERNAL_TRIGGER_TYPES):
for trigger_definition in trigger_definitions:
- LOG.debug('Registering internal trigger: %s', trigger_definition['name'])
+ LOG.debug("Registering internal trigger: %s", trigger_definition["name"])
- is_action_trigger = trigger_definition['name'] == ACTION_SENSOR_TRIGGER['name']
+ is_action_trigger = (
+ trigger_definition["name"] == ACTION_SENSOR_TRIGGER["name"]
+ )
if is_action_trigger and not action_sensor_enabled:
continue
try:
trigger_type_db = _register_internal_trigger_type(
- trigger_definition=trigger_definition)
+ trigger_definition=trigger_definition
+ )
except Exception:
- LOG.exception('Failed registering internal trigger: %s.', trigger_definition)
+ LOG.exception(
+ "Failed registering internal trigger: %s.", trigger_definition
+ )
raise
else:
registered_trigger_types_db.append(trigger_type_db)
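
The try/except around `create_trigger_type_db` above is what makes registration idempotent: a duplicate-key error is swallowed and the existing object is looked up instead. A self-contained sketch of that create-or-fetch pattern, with an in-memory dict standing in for the real DB layer:

    class NotUniqueError(Exception):
        """Stand-in for the DB layer's duplicate-key error."""

    _registry = {}

    def create(name, value):
        if name in _registry:
            raise NotUniqueError(name)
        _registry[name] = value
        return value

    def register(name, value):
        # Same idempotent pattern as _register_internal_trigger_type:
        # a duplicate is not an error, we fetch the existing object.
        try:
            return create(name, value)
        except NotUniqueError:
            return _registry[name]

    register("core.st2.generic.actiontrigger", {"pack": "core"})
    register("core.st2.generic.actiontrigger", {"pack": "core"})  # no-op
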
diff --git a/st2common/st2common/util/action_db.py b/st2common/st2common/util/action_db.py
index 610b698c18..4880693348 100644
--- a/st2common/st2common/util/action_db.py
+++ b/st2common/st2common/util/action_db.py
@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except ImportError:
@@ -42,15 +43,15 @@
__all__ = [
- 'get_action_parameters_specs',
- 'get_runnertype_by_id',
- 'get_runnertype_by_name',
- 'get_action_by_id',
- 'get_action_by_ref',
- 'get_liveaction_by_id',
- 'update_liveaction_status',
- 'serialize_positional_argument',
- 'get_args'
+ "get_action_parameters_specs",
+ "get_runnertype_by_id",
+ "get_runnertype_by_name",
+ "get_action_by_id",
+ "get_action_by_ref",
+ "get_liveaction_by_id",
+ "update_liveaction_status",
+ "serialize_positional_argument",
+ "get_args",
]
@@ -71,11 +72,11 @@ def get_action_parameters_specs(action_ref):
if not action_db:
return parameters
- runner_type_name = action_db.runner_type['name']
+ runner_type_name = action_db.runner_type["name"]
runner_type_db = get_runnertype_by_name(runnertype_name=runner_type_name)
# Runner type parameters should be added before the action parameters.
- parameters.update(runner_type_db['runner_parameters'])
+ parameters.update(runner_type_db["runner_parameters"])
parameters.update(action_db.parameters)
return parameters
@@ -83,60 +84,76 @@ def get_action_parameters_specs(action_ref):
def get_runnertype_by_id(runnertype_id):
"""
- Get RunnerType by id.
+ Get RunnerType by id.
- On error, raise StackStormDBObjectNotFoundError
+ On error, raise StackStormDBObjectNotFoundError
"""
try:
runnertype = RunnerType.get_by_id(runnertype_id)
except (ValueError, ValidationError) as e:
- LOG.warning('Database lookup for runnertype with id="%s" resulted in '
- 'exception: %s', runnertype_id, e)
- raise StackStormDBObjectNotFoundError('Unable to find runnertype with '
- 'id="%s"' % runnertype_id)
+ LOG.warning(
+ 'Database lookup for runnertype with id="%s" resulted in ' "exception: %s",
+ runnertype_id,
+ e,
+ )
+ raise StackStormDBObjectNotFoundError(
+ "Unable to find runnertype with " 'id="%s"' % runnertype_id
+ )
return runnertype
def get_runnertype_by_name(runnertype_name):
"""
- Get an runnertype by name.
- On error, raise ST2ObjectNotFoundError.
+ Get a runnertype by name.
+ On error, raise ST2ObjectNotFoundError.
"""
try:
runnertypes = RunnerType.query(name=runnertype_name)
except (ValueError, ValidationError) as e:
- LOG.error('Database lookup for name="%s" resulted in exception: %s',
- runnertype_name, e)
- raise StackStormDBObjectNotFoundError('Unable to find runnertype with name="%s"'
- % runnertype_name)
+ LOG.error(
+ 'Database lookup for name="%s" resulted in exception: %s',
+ runnertype_name,
+ e,
+ )
+ raise StackStormDBObjectNotFoundError(
+ 'Unable to find runnertype with name="%s"' % runnertype_name
+ )
if not runnertypes:
- raise StackStormDBObjectNotFoundError('Unable to find RunnerType with name="%s"'
- % runnertype_name)
+ raise StackStormDBObjectNotFoundError(
+ 'Unable to find RunnerType with name="%s"' % runnertype_name
+ )
if len(runnertypes) > 1:
- LOG.warning('More than one RunnerType returned from DB lookup by name. '
- 'Result list is: %s', runnertypes)
+ LOG.warning(
+ "More than one RunnerType returned from DB lookup by name. "
+ "Result list is: %s",
+ runnertypes,
+ )
return runnertypes[0]
def get_action_by_id(action_id):
"""
- Get Action by id.
+ Get Action by id.
- On error, raise StackStormDBObjectNotFoundError
+ On error, raise StackStormDBObjectNotFoundError
"""
action = None
try:
action = Action.get_by_id(action_id)
except (ValueError, ValidationError) as e:
- LOG.warning('Database lookup for action with id="%s" resulted in '
- 'exception: %s', action_id, e)
- raise StackStormDBObjectNotFoundError('Unable to find action with '
- 'id="%s"' % action_id)
+ LOG.warning(
+ 'Database lookup for action with id="%s" resulted in ' "exception: %s",
+ action_id,
+ e,
+ )
+ raise StackStormDBObjectNotFoundError(
+ "Unable to find action with " 'id="%s"' % action_id
+ )
return action
@@ -153,56 +170,78 @@ def get_action_by_ref(ref):
try:
return Action.get_by_ref(ref)
except ValueError as e:
- LOG.debug('Database lookup for ref="%s" resulted ' +
- 'in exception : %s.', ref, e, exc_info=True)
+ LOG.debug(
+ 'Database lookup for ref="%s" resulted ' + "in exception : %s.",
+ ref,
+ e,
+ exc_info=True,
+ )
return None
def get_liveaction_by_id(liveaction_id):
"""
- Get LiveAction by id.
+ Get LiveAction by id.
- On error, raise ST2DBObjectNotFoundError.
+ On error, raise ST2DBObjectNotFoundError.
"""
liveaction = None
try:
liveaction = LiveAction.get_by_id(liveaction_id)
except (ValidationError, ValueError) as e:
- LOG.error('Database lookup for LiveAction with id="%s" resulted in '
- 'exception: %s', liveaction_id, e)
- raise StackStormDBObjectNotFoundError('Unable to find LiveAction with '
- 'id="%s"' % liveaction_id)
+ LOG.error(
+ 'Database lookup for LiveAction with id="%s" resulted in ' "exception: %s",
+ liveaction_id,
+ e,
+ )
+ raise StackStormDBObjectNotFoundError(
+ "Unable to find LiveAction with " 'id="%s"' % liveaction_id
+ )
return liveaction
-def update_liveaction_status(status=None, result=None, context=None, end_timestamp=None,
- liveaction_id=None, runner_info=None, liveaction_db=None,
- publish=True):
+def update_liveaction_status(
+ status=None,
+ result=None,
+ context=None,
+ end_timestamp=None,
+ liveaction_id=None,
+ runner_info=None,
+ liveaction_db=None,
+ publish=True,
+):
"""
- Update the status of the specified LiveAction to the value provided in
- new_status.
+ Update the status of the specified LiveAction to the value provided in
+ new_status.
- The LiveAction may be specified using either liveaction_id, or as an
- liveaction_db instance.
+ The LiveAction may be specified using either liveaction_id, or as an
+ liveaction_db instance.
"""
if (liveaction_id is None) and (liveaction_db is None):
- raise ValueError('Must specify an liveaction_id or an liveaction_db when '
- 'calling update_LiveAction_status')
+ raise ValueError(
+ "Must specify an liveaction_id or an liveaction_db when "
+ "calling update_LiveAction_status"
+ )
if liveaction_db is None:
liveaction_db = get_liveaction_by_id(liveaction_id)
if status not in LIVEACTION_STATUSES:
- raise ValueError('Attempting to set status for LiveAction "%s" '
- 'to unknown status string. Unknown status is "%s"'
- % (liveaction_db, status))
+ raise ValueError(
+ 'Attempting to set status for LiveAction "%s" '
+ 'to unknown status string. Unknown status is "%s"' % (liveaction_db, status)
+ )
- if result and cfg.CONF.system.validate_output_schema and status == LIVEACTION_STATUS_SUCCEEDED:
+ if (
+ result
+ and cfg.CONF.system.validate_output_schema
+ and status == LIVEACTION_STATUS_SUCCEEDED
+ ):
action_db = get_action_by_ref(liveaction_db.action)
- runner_db = get_runnertype_by_name(action_db.runner_type['name'])
+ runner_db = get_runnertype_by_name(action_db.runner_type["name"])
result, status = output_schema.validate_output(
runner_db.output_schema,
action_db.output_schema,
@@ -214,21 +253,33 @@ def update_liveaction_status(status=None, result=None, context=None, end_timesta
# If liveaction_db status is set then we need to decrement the counter
# because it is transitioning to a new state
if liveaction_db.status:
- get_driver().dec_counter('action.executions.%s' % (liveaction_db.status))
+ get_driver().dec_counter("action.executions.%s" % (liveaction_db.status))
# If status is provided then we need to increment the timer because the action
# is transitioning into this new state
if status:
- get_driver().inc_counter('action.executions.%s' % (status))
+ get_driver().inc_counter("action.executions.%s" % (status))
- extra = {'liveaction_db': liveaction_db}
- LOG.debug('Updating ActionExection: "%s" with status="%s"', liveaction_db.id, status,
- extra=extra)
+ extra = {"liveaction_db": liveaction_db}
+ LOG.debug(
+ 'Updating ActionExecution: "%s" with status="%s"',
+ liveaction_db.id,
+ status,
+ extra=extra,
+ )
# If liveaction is already canceled, then do not allow status to be updated.
- if liveaction_db.status == LIVEACTION_STATUS_CANCELED and status != LIVEACTION_STATUS_CANCELED:
- LOG.info('Unable to update ActionExecution "%s" with status="%s". '
- 'ActionExecution is already canceled.', liveaction_db.id, status, extra=extra)
+ if (
+ liveaction_db.status == LIVEACTION_STATUS_CANCELED
+ and status != LIVEACTION_STATUS_CANCELED
+ ):
+ LOG.info(
+ 'Unable to update ActionExecution "%s" with status="%s". '
+ "ActionExecution is already canceled.",
+ liveaction_db.id,
+ status,
+ extra=extra,
+ )
return liveaction_db
old_status = liveaction_db.status
@@ -250,11 +301,11 @@ def update_liveaction_status(status=None, result=None, context=None, end_timesta
# manipulated fields
liveaction_db = LiveAction.add_or_update(liveaction_db)
- LOG.debug('Updated status for LiveAction object.', extra=extra)
+ LOG.debug("Updated status for LiveAction object.", extra=extra)
if publish and status != old_status:
LiveAction.publish_status(liveaction_db)
- LOG.debug('Published status for LiveAction object.', extra=extra)
+ LOG.debug("Published status for LiveAction object.", extra=extra)
return liveaction_db
@@ -267,9 +318,9 @@ def serialize_positional_argument(argument_type, argument_value):
sense for shell script actions (only the outer / top level value is
serialized).
"""
- if argument_type in ['string', 'number', 'float']:
+ if argument_type in ["string", "number", "float"]:
if argument_value is None:
- argument_value = six.text_type('')
+ argument_value = six.text_type("")
return argument_value
if isinstance(argument_value, (int, float)):
@@ -277,25 +328,25 @@ def serialize_positional_argument(argument_type, argument_value):
if not isinstance(argument_value, six.text_type):
# cast string non-unicode values to unicode
- argument_value = argument_value.decode('utf-8')
- elif argument_type == 'boolean':
+ argument_value = argument_value.decode("utf-8")
+ elif argument_type == "boolean":
# Booleans are serialized as the strings "1" and "0"
if argument_value is not None:
- argument_value = '1' if bool(argument_value) else '0'
+ argument_value = "1" if bool(argument_value) else "0"
else:
- argument_value = ''
- elif argument_type in ['array', 'list']:
+ argument_value = ""
+ elif argument_type in ["array", "list"]:
# Lists are serialized as a comma-delimited string (foo,bar,baz)
- argument_value = ','.join(map(str, argument_value)) if argument_value else ''
- elif argument_type == 'object':
+ argument_value = ",".join(map(str, argument_value)) if argument_value else ""
+ elif argument_type == "object":
# Objects are serialized as JSON
- argument_value = json.dumps(argument_value) if argument_value else ''
- elif argument_type == 'null':
+ argument_value = json.dumps(argument_value) if argument_value else ""
+ elif argument_type == "null":
# None / null is serialized as an empty string
- argument_value = ''
+ argument_value = ""
else:
# Other values are simply cast to unicode string
- argument_value = six.text_type(argument_value) if argument_value else ''
+ argument_value = six.text_type(argument_value) if argument_value else ""
return argument_value
@@ -315,12 +366,13 @@ def get_args(action_parameters, action_db):
positional_args = []
positional_args_keys = set()
for _, arg in six.iteritems(position_args_dict):
- arg_type = action_db_parameters.get(arg, {}).get('type', None)
+ arg_type = action_db_parameters.get(arg, {}).get("type", None)
# Perform serialization for positional arguments
arg_value = action_parameters.get(arg, None)
- arg_value = serialize_positional_argument(argument_type=arg_type,
- argument_value=arg_value)
+ arg_value = serialize_positional_argument(
+ argument_type=arg_type, argument_value=arg_value
+ )
positional_args.append(arg_value)
positional_args_keys.add(arg)
@@ -340,7 +392,7 @@ def _get_position_arg_dict(action_parameters, action_db):
for param in action_db_params:
param_meta = action_db_params.get(param, None)
if param_meta is not None:
- pos = param_meta.get('position')
+ pos = param_meta.get("position")
if pos is not None:
args_dict[pos] = param
args_dict = OrderedDict(sorted(args_dict.items()))
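
The serialization rules in `serialize_positional_argument` are easiest to see with concrete inputs. A condensed, runnable restatement (simplified: it skips the six/unicode handling of the real function):

    import json

    def serialize(arg_type, value):
        # Condensed restatement of serialize_positional_argument's rules.
        if arg_type in ["string", "number", "float"]:
            return value if value is not None else ""
        if arg_type == "boolean":
            return ("1" if value else "0") if value is not None else ""
        if arg_type in ["array", "list"]:
            return ",".join(map(str, value)) if value else ""
        if arg_type == "object":
            return json.dumps(value) if value else ""
        if arg_type == "null":
            return ""
        return str(value) if value else ""

    assert serialize("boolean", True) == "1"
    assert serialize("array", ["foo", "bar", "baz"]) == "foo,bar,baz"
    assert serialize("object", {"k": "v"}) == '{"k": "v"}'
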
diff --git a/st2common/st2common/util/actionalias_helpstring.py b/st2common/st2common/util/actionalias_helpstring.py
index ddee088c8c..109328f926 100644
--- a/st2common/st2common/util/actionalias_helpstring.py
+++ b/st2common/st2common/util/actionalias_helpstring.py
@@ -18,9 +18,7 @@
from st2common.util.actionalias_matching import normalise_alias_format_string
-__all__ = [
- 'generate_helpstring_result'
-]
+__all__ = ["generate_helpstring_result"]
def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset=0):
@@ -44,7 +42,7 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset=
matches = []
count = 0
if not (isinstance(limit, int) and isinstance(offset, int)):
- raise TypeError('limit or offset argument is not an integer')
+ raise TypeError("limit or offset argument is not an integer")
for alias in aliases:
# Skip disabled aliases.
if not alias.enabled:
@@ -56,7 +54,7 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset=
display, _, _ = normalise_alias_format_string(format_)
if display:
# Skip help strings not containing keyword.
- if not re.search(filter or '', display, flags=re.IGNORECASE):
+ if not re.search(filter or "", display, flags=re.IGNORECASE):
continue
# Skip over help strings not within the requested offset/limit range.
if (offset == 0 and limit > 0) and count >= limit:
@@ -65,14 +63,18 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset=
elif (offset > 0 and limit == 0) and count < offset:
count += 1
continue
- elif (offset > 0 and limit > 0) and (count < offset or count >= offset + limit):
+ elif (offset > 0 and limit > 0) and (
+ count < offset or count >= offset + limit
+ ):
count += 1
continue
- matches.append({
- "pack": alias.pack,
- "display": display,
- "description": alias.description
- })
+ matches.append(
+ {
+ "pack": alias.pack,
+ "display": display,
+ "description": alias.description,
+ }
+ )
count += 1
return {"available": count, "helpstrings": matches}
diff --git a/st2common/st2common/util/actionalias_matching.py b/st2common/st2common/util/actionalias_matching.py
index 3827b12d93..1b20fad414 100644
--- a/st2common/st2common/util/actionalias_matching.py
+++ b/st2common/st2common/util/actionalias_matching.py
@@ -24,15 +24,15 @@
from st2common.models.utils.action_alias_utils import extract_parameters
__all__ = [
- 'list_format_strings_from_aliases',
- 'normalise_alias_format_string',
- 'match_command_to_alias',
- 'get_matching_alias',
+ "list_format_strings_from_aliases",
+ "normalise_alias_format_string",
+ "match_command_to_alias",
+ "get_matching_alias",
]
def list_format_strings_from_aliases(aliases, match_multiple=False):
- '''
+ """
List patterns from a collection of alias objects
:param aliases: The list of aliases
@@ -40,34 +40,40 @@ def list_format_strings_from_aliases(aliases, match_multiple=False):
:return: A description of potential execution patterns in a list of aliases.
:rtype: ``list`` of ``list``
- '''
+ """
patterns = []
for alias in aliases:
for format_ in alias.formats:
- display, representations, _match_multiple = normalise_alias_format_string(format_)
+ display, representations, _match_multiple = normalise_alias_format_string(
+ format_
+ )
if display and len(representations) == 0:
- patterns.append({
- 'alias': alias,
- 'format': format_,
- 'display': display,
- 'representation': '',
- })
- else:
- patterns.extend([
+ patterns.append(
{
- 'alias': alias,
- 'format': format_,
- 'display': display,
- 'representation': representation,
- 'match_multiple': _match_multiple,
+ "alias": alias,
+ "format": format_,
+ "display": display,
+ "representation": "",
}
- for representation in representations
- ])
+ )
+ else:
+ patterns.extend(
+ [
+ {
+ "alias": alias,
+ "format": format_,
+ "display": display,
+ "representation": representation,
+ "match_multiple": _match_multiple,
+ }
+ for representation in representations
+ ]
+ )
return patterns
def normalise_alias_format_string(alias_format):
- '''
+ """
StackStorm action aliases come in two forms;
1. A string holding the format, which is also used as the help string.
2. A dictionary containing "display" and/or "representation" keys.
@@ -80,7 +86,7 @@ def normalise_alias_format_string(alias_format):
:return: The representation of the alias
:rtype: ``tuple`` of (``str``, ``str``)
- '''
+ """
display = None
representation = []
match_multiple = False
@@ -89,14 +95,16 @@ def normalise_alias_format_string(alias_format):
display = alias_format
representation.append(alias_format)
elif isinstance(alias_format, dict):
- display = alias_format.get('display')
- representation = alias_format.get('representation') or []
+ display = alias_format.get("display")
+ representation = alias_format.get("representation") or []
if isinstance(representation, six.string_types):
representation = [representation]
- match_multiple = alias_format.get('match_multiple', match_multiple)
+ match_multiple = alias_format.get("match_multiple", match_multiple)
else:
- raise TypeError("alias_format '%s' is neither a dictionary or string type."
- % repr(alias_format))
+ raise TypeError(
+ "alias_format '%s' is neither a dictionary or string type."
+ % repr(alias_format)
+ )
return (display, representation, match_multiple)
@@ -110,8 +118,9 @@ def match_command_to_alias(command, aliases, match_multiple=False):
formats = list_format_strings_from_aliases([alias], match_multiple)
for format_ in formats:
try:
- extract_parameters(format_str=format_['representation'],
- param_stream=command)
+ extract_parameters(
+ format_str=format_["representation"], param_stream=command
+ )
except ParseException:
continue
@@ -125,35 +134,41 @@ def get_matching_alias(command):
"""
# 1. Get aliases
action_alias_dbs = ActionAlias.query(
- Q(formats__match_multiple=None) | Q(formats__match_multiple=False),
- enabled=True)
+ Q(formats__match_multiple=None) | Q(formats__match_multiple=False), enabled=True
+ )
# 2. Match alias(es) to command
matches = match_command_to_alias(command=command, aliases=action_alias_dbs)
if len(matches) > 1:
- raise ActionAliasAmbiguityException("Command '%s' matched more than 1 pattern" %
- command,
- matches=matches,
- command=command)
+ raise ActionAliasAmbiguityException(
+ "Command '%s' matched more than 1 pattern" % command,
+ matches=matches,
+ command=command,
+ )
elif len(matches) == 0:
match_multiple_action_alias_dbs = ActionAlias.query(
- formats__match_multiple=True,
- enabled=True)
+ formats__match_multiple=True, enabled=True
+ )
- matches = match_command_to_alias(command=command, aliases=match_multiple_action_alias_dbs,
- match_multiple=True)
+ matches = match_command_to_alias(
+ command=command,
+ aliases=match_multiple_action_alias_dbs,
+ match_multiple=True,
+ )
if len(matches) > 1:
- raise ActionAliasAmbiguityException("Command '%s' matched more than 1 (multi) pattern" %
- command,
- matches=matches,
- command=command)
+ raise ActionAliasAmbiguityException(
+ "Command '%s' matched more than 1 (multi) pattern" % command,
+ matches=matches,
+ command=command,
+ )
if len(matches) == 0:
- raise ActionAliasAmbiguityException("Command '%s' matched no patterns" %
- command,
- matches=[],
- command=command)
+ raise ActionAliasAmbiguityException(
+ "Command '%s' matched no patterns" % command,
+ matches=[],
+ command=command,
+ )
return matches[0]
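
`get_matching_alias` performs a two-phase lookup: exact (single-match) aliases first, `match_multiple` aliases only as a fallback, with more than one hit in a phase treated as ambiguous. A toy version, with hypothetical regexes standing in for alias formats:

    import re

    PATTERNS = [
        {"name": "deploy one", "re": re.compile(r"^deploy \S+$"), "multi": False},
        {"name": "deploy many", "re": re.compile(r"^deploy .+$"), "multi": True},
    ]

    def pick_unique(command):
        # Phase 1: single-match aliases; phase 2: multi-match fallback.
        hits = [p for p in PATTERNS if not p["multi"] and p["re"].match(command)]
        if not hits:
            hits = [p for p in PATTERNS if p["multi"] and p["re"].match(command)]
        if len(hits) > 1:
            raise ValueError("Command '%s' matched more than 1 pattern" % command)
        if not hits:
            raise ValueError("Command '%s' matched no patterns" % command)
        return hits[0]

    assert pick_unique("deploy web1")["name"] == "deploy one"
    assert pick_unique("deploy web1 web2")["name"] == "deploy many"
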
diff --git a/st2common/st2common/util/api.py b/st2common/st2common/util/api.py
index 4e0e3f4938..2c378ad726 100644
--- a/st2common/st2common/util/api.py
+++ b/st2common/st2common/util/api.py
@@ -21,8 +21,8 @@
from st2common.util.url import get_url_without_trailing_slash
__all__ = [
- 'get_base_public_api_url',
- 'get_full_public_api_url',
+ "get_base_public_api_url",
+ "get_full_public_api_url",
]
LOG = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ def get_base_public_api_url():
api_url = get_url_without_trailing_slash(cfg.CONF.auth.api_url)
else:
LOG.warn('"auth.api_url" configuration option is not configured')
- api_url = 'http://%s:%s' % (cfg.CONF.api.host, cfg.CONF.api.port)
+ api_url = "http://%s:%s" % (cfg.CONF.api.host, cfg.CONF.api.port)
return api_url
@@ -52,5 +52,5 @@ def get_full_public_api_url(api_version=DEFAULT_API_VERSION):
:rtype: ``str``
"""
api_url = get_base_public_api_url()
- api_url = '%s/%s' % (api_url, api_version)
+ api_url = "%s/%s" % (api_url, api_version)
return api_url
diff --git a/st2common/st2common/util/argument_parser.py b/st2common/st2common/util/argument_parser.py
index 28645ad15f..757f171661 100644
--- a/st2common/st2common/util/argument_parser.py
+++ b/st2common/st2common/util/argument_parser.py
@@ -16,9 +16,7 @@
from __future__ import absolute_import
import argparse
-__all__ = [
- 'generate_argument_parser_for_metadata'
-]
+__all__ = ["generate_argument_parser_for_metadata"]
def generate_argument_parser_for_metadata(metadata):
@@ -32,37 +30,37 @@ def generate_argument_parser_for_metadata(metadata):
:return: Generated argument parser instance.
:rtype: :class:`argparse.ArgumentParser`
"""
- parameters = metadata['parameters']
+ parameters = metadata["parameters"]
- parser = argparse.ArgumentParser(description=metadata['description'])
+ parser = argparse.ArgumentParser(description=metadata["description"])
for parameter_name, parameter_options in parameters.items():
- name = parameter_name.replace('_', '-')
- description = parameter_options['description']
- _type = parameter_options['type']
- required = parameter_options.get('required', False)
- default_value = parameter_options.get('default', None)
- immutable = parameter_options.get('immutable', False)
+ name = parameter_name.replace("_", "-")
+ description = parameter_options["description"]
+ _type = parameter_options["type"]
+ required = parameter_options.get("required", False)
+ default_value = parameter_options.get("default", None)
+ immutable = parameter_options.get("immutable", False)
# Immutable arguments can't be controlled by the user
if immutable:
continue
- args = ['--%s' % (name)]
- kwargs = {'help': description, 'required': required}
+ args = ["--%s" % (name)]
+ kwargs = {"help": description, "required": required}
if default_value is not None:
- kwargs['default'] = default_value
+ kwargs["default"] = default_value
- if _type == 'string':
- kwargs['type'] = str
- elif _type == 'integer':
- kwargs['type'] = int
- elif _type == 'boolean':
+ if _type == "string":
+ kwargs["type"] = str
+ elif _type == "integer":
+ kwargs["type"] = int
+ elif _type == "boolean":
if default_value is False:
- kwargs['action'] = 'store_false'
+ kwargs["action"] = "store_false"
else:
- kwargs['action'] = 'store_true'
+ kwargs["action"] = "store_true"
parser.add_argument(*args, **kwargs)
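
To see what `generate_argument_parser_for_metadata` produces, here is the same type mapping applied to a hypothetical action metadata dict (the boolean branch is simplified to the conventional store_true/store_false split):

    import argparse

    metadata = {
        "description": "Example action",
        "parameters": {
            "host": {"description": "Target host", "type": "string", "required": True},
            "retries": {"description": "Retry count", "type": "integer", "default": 3},
            "verbose": {"description": "Verbose output", "type": "boolean", "default": False},
        },
    }

    parser = argparse.ArgumentParser(description=metadata["description"])
    for name, opts in metadata["parameters"].items():
        kwargs = {"help": opts["description"], "required": opts.get("required", False)}
        if "default" in opts:
            kwargs["default"] = opts["default"]
        if opts["type"] == "string":
            kwargs["type"] = str
        elif opts["type"] == "integer":
            kwargs["type"] = int
        elif opts["type"] == "boolean":
            # Flags default to False and the switch flips them to True.
            kwargs["action"] = "store_true" if not opts.get("default") else "store_false"
        parser.add_argument("--%s" % name.replace("_", "-"), **kwargs)

    args = parser.parse_args(["--host", "db1", "--verbose"])
    assert args.host == "db1" and args.retries == 3 and args.verbose is True
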
diff --git a/st2common/st2common/util/auth.py b/st2common/st2common/util/auth.py
index 38294c92a7..90e81d938e 100644
--- a/st2common/st2common/util/auth.py
+++ b/st2common/st2common/util/auth.py
@@ -28,11 +28,11 @@
from st2common.util import hash as hash_utils
__all__ = [
- 'validate_token',
- 'validate_token_and_source',
- 'generate_api_key',
- 'validate_api_key',
- 'validate_api_key_and_source'
+ "validate_token",
+ "validate_token_and_source",
+ "generate_api_key",
+ "validate_api_key",
+ "validate_api_key_and_source",
]
LOG = logging.getLogger(__name__)
@@ -53,7 +53,7 @@ def validate_token(token_string):
if token.expiry <= date_utils.get_datetime_utc_now():
# TODO: purge expired tokens
LOG.audit('Token with id "%s" has expired.' % (token.id))
- raise exceptions.TokenExpiredError('Token has expired.')
+ raise exceptions.TokenExpiredError("Token has expired.")
LOG.audit('Token with id "%s" is validated.' % (token.id))
@@ -74,14 +74,14 @@ def validate_token_and_source(token_in_headers, token_in_query_params):
:rtype: :class:`.TokenDB`
"""
if not token_in_headers and not token_in_query_params:
- LOG.audit('Token is not found in header or query parameters.')
- raise exceptions.TokenNotProvidedError('Token is not provided.')
+ LOG.audit("Token is not found in header or query parameters.")
+ raise exceptions.TokenNotProvidedError("Token is not provided.")
if token_in_headers:
- LOG.audit('Token provided in headers')
+ LOG.audit("Token provided in headers")
if token_in_query_params:
- LOG.audit('Token provided in query parameters')
+ LOG.audit("Token provided in query parameters")
return validate_token(token_in_headers or token_in_query_params)
@@ -103,7 +103,8 @@ def generate_api_key():
base64_encoded = base64.b64encode(
six.b(hashed_seed),
- six.b(random.choice(['rA', 'aZ', 'gQ', 'hH', 'hG', 'aR', 'DD']))).rstrip(b'==')
+ six.b(random.choice(["rA", "aZ", "gQ", "hH", "hG", "aR", "DD"])),
+ ).rstrip(b"==")
base64_encoded = base64_encoded.decode()
return base64_encoded
@@ -127,7 +128,7 @@ def validate_api_key(api_key):
api_key_db = ApiKey.get(api_key)
if not api_key_db.enabled:
- raise exceptions.ApiKeyDisabledError('API key is disabled.')
+ raise exceptions.ApiKeyDisabledError("API key is disabled.")
LOG.audit('API key with id "%s" is validated.' % (api_key_db.id))
@@ -148,13 +149,13 @@ def validate_api_key_and_source(api_key_in_headers, api_key_query_params):
:rtype: :class:`.ApiKeyDB`
"""
if not api_key_in_headers and not api_key_query_params:
- LOG.audit('API key is not found in header or query parameters.')
- raise exceptions.ApiKeyNotProvidedError('API key is not provided.')
+ LOG.audit("API key is not found in header or query parameters.")
+ raise exceptions.ApiKeyNotProvidedError("API key is not provided.")
if api_key_in_headers:
- LOG.audit('API key provided in headers')
+ LOG.audit("API key provided in headers")
if api_key_query_params:
- LOG.audit('API key provided in query parameters')
+ LOG.audit("API key provided in query parameters")
return validate_api_key(api_key_in_headers or api_key_query_params)
diff --git a/st2common/st2common/util/casts.py b/st2common/st2common/util/casts.py
index fa94272e47..aadad8a4a1 100644
--- a/st2common/st2common/util/casts.py
+++ b/st2common/st2common/util/casts.py
@@ -89,12 +89,12 @@ def _cast_none(x):
# These types as they appear in json schema.
CASTS = {
- 'array': _cast_object,
- 'boolean': _cast_boolean,
- 'integer': _cast_integer,
- 'number': _cast_number,
- 'object': _cast_object,
- 'string': _cast_string
+ "array": _cast_object,
+ "boolean": _cast_boolean,
+ "integer": _cast_integer,
+ "number": _cast_number,
+ "object": _cast_object,
+ "string": _cast_string,
}
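
The `CASTS` dict is a plain dispatch table keyed by JSON-schema type names. A minimal sketch of how such a table is consumed (the casters here are simplified stand-ins for the `_cast_*` helpers):

    import json

    EXAMPLE_CASTS = {
        "integer": int,
        "number": float,
        "boolean": lambda x: x in (True, "true", "1", 1),
        "object": json.loads,
        "array": json.loads,
        "string": str,
    }

    def cast_param(schema_type, raw):
        # Look the caster up by schema type; unknown types pass through.
        caster = EXAMPLE_CASTS.get(schema_type)
        return caster(raw) if caster else raw

    assert cast_param("integer", "42") == 42
    assert cast_param("array", "[1, 2]") == [1, 2]
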
diff --git a/st2common/st2common/util/compat.py b/st2common/st2common/util/compat.py
index 9288f5f3a0..1926f97dba 100644
--- a/st2common/st2common/util/compat.py
+++ b/st2common/st2common/util/compat.py
@@ -24,16 +24,15 @@
__all__ = [
- 'mock_open_name',
-
- 'to_unicode',
- 'to_ascii',
+ "mock_open_name",
+ "to_unicode",
+ "to_ascii",
]
if six.PY3:
- mock_open_name = 'builtins.open'
+ mock_open_name = "builtins.open"
else:
- mock_open_name = '__builtin__.open'
+ mock_open_name = "__builtin__.open"
def to_unicode(value):
@@ -63,4 +62,4 @@ def to_ascii(value):
if six.PY3:
value = value.encode()
- return value.decode('ascii', errors='ignore')
+ return value.decode("ascii", errors="ignore")
diff --git a/st2common/st2common/util/concurrency.py b/st2common/st2common/util/concurrency.py
index 50312fa78f..239407ade0 100644
--- a/st2common/st2common/util/concurrency.py
+++ b/st2common/st2common/util/concurrency.py
@@ -31,34 +31,30 @@
except ImportError:
gevent = None
-CONCURRENCY_LIBRARY = 'eventlet'
+CONCURRENCY_LIBRARY = "eventlet"
__all__ = [
- 'set_concurrency_library',
- 'get_concurrency_library',
-
- 'get_subprocess_module',
- 'subprocess_popen',
-
- 'spawn',
- 'wait',
- 'cancel',
- 'kill',
- 'sleep',
-
- 'get_greenlet_exit_exception_class',
-
- 'get_green_pool_class',
- 'is_green_pool_free',
- 'green_pool_wait_all'
+ "set_concurrency_library",
+ "get_concurrency_library",
+ "get_subprocess_module",
+ "subprocess_popen",
+ "spawn",
+ "wait",
+ "cancel",
+ "kill",
+ "sleep",
+ "get_greenlet_exit_exception_class",
+ "get_green_pool_class",
+ "is_green_pool_free",
+ "green_pool_wait_all",
]
def set_concurrency_library(library):
global CONCURRENCY_LIBRARY
- if library not in ['eventlet', 'gevent']:
- raise ValueError('Unsupported concurrency library: %s' % (library))
+ if library not in ["eventlet", "gevent"]:
+ raise ValueError("Unsupported concurrency library: %s" % (library))
CONCURRENCY_LIBRARY = library
@@ -69,107 +65,111 @@ def get_concurrency_library():
def get_subprocess_module():
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
from eventlet.green import subprocess # pylint: disable=import-error
+
return subprocess
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
from gevent import subprocess # pylint: disable=import-error
+
return subprocess
def subprocess_popen(*args, **kwargs):
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
from eventlet.green import subprocess # pylint: disable=import-error
+
return subprocess.Popen(*args, **kwargs)
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
from gevent import subprocess # pylint: disable=import-error
+
return subprocess.Popen(*args, **kwargs)
def spawn(func, *args, **kwargs):
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return eventlet.spawn(func, *args, **kwargs)
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
return gevent.spawn(func, *args, **kwargs)
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
def wait(green_thread, *args, **kwargs):
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return green_thread.wait(*args, **kwargs)
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
return green_thread.join(*args, **kwargs)
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
def cancel(green_thread, *args, **kwargs):
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return green_thread.cancel(*args, **kwargs)
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
return green_thread.kill(*args, **kwargs)
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
def kill(green_thread, *args, **kwargs):
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return green_thread.kill(*args, **kwargs)
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
return green_thread.kill(*args, **kwargs)
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
def sleep(*args, **kwargs):
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return eventlet.sleep(*args, **kwargs)
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
return gevent.sleep(*args, **kwargs)
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
def get_greenlet_exit_exception_class():
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return eventlet.support.greenlets.GreenletExit
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
return gevent.GreenletExit
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
def get_green_pool_class():
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return eventlet.GreenPool
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
return gevent.pool.Pool
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
def is_green_pool_free(pool):
"""
Return True if the provided green pool is free, False otherwise.
"""
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return pool.free()
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
return not pool.full()
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
def green_pool_wait_all(pool):
"""
Wait for all the green threads in the pool to finish.
"""
- if CONCURRENCY_LIBRARY == 'eventlet':
+ if CONCURRENCY_LIBRARY == "eventlet":
return pool.waitall()
- elif CONCURRENCY_LIBRARY == 'gevent':
+ elif CONCURRENCY_LIBRARY == "gevent":
# NOTE: This mimics eventlet.waitall() functionality better than
# pool.join()
return all(gl.ready() for gl in pool.greenlets)
else:
- raise ValueError('Unsupported concurrency library')
+ raise ValueError("Unsupported concurrency library")
diff --git a/st2common/st2common/util/config_loader.py b/st2common/st2common/util/config_loader.py
index 620707e643..30db039bdc 100644
--- a/st2common/st2common/util/config_loader.py
+++ b/st2common/st2common/util/config_loader.py
@@ -30,9 +30,7 @@
from st2common.util.config_parser import ContentPackConfigParser
from st2common.exceptions.db import StackStormDBObjectNotFoundError
-__all__ = [
- 'ContentPackConfigLoader'
-]
+__all__ = ["ContentPackConfigLoader"]
LOG = logging.getLogger(__name__)
@@ -79,15 +77,16 @@ def get_config(self):
# 2. Retrieve values from "global" pack config file (if available) and resolve them if
# necessary
- config = self._get_values_for_config(config_schema_db=config_schema_db,
- config_db=config_db)
+ config = self._get_values_for_config(
+ config_schema_db=config_schema_db, config_db=config_db
+ )
result.update(config)
return result
def _get_values_for_config(self, config_schema_db, config_db):
- schema_values = getattr(config_schema_db, 'attributes', {})
- config_values = getattr(config_db, 'values', {})
+ schema_values = getattr(config_schema_db, "attributes", {})
+ config_values = getattr(config_db, "values", {})
config = copy.deepcopy(config_values or {})
@@ -131,24 +130,34 @@ def _assign_dynamic_config_values(self, schema, config, parent_keys=None):
# Inspect nested object properties
if is_dictionary:
parent_keys += [str(config_item_key)]
- self._assign_dynamic_config_values(schema=schema_item.get('properties', {}),
- config=config[config_item_key],
- parent_keys=parent_keys)
+ self._assign_dynamic_config_values(
+ schema=schema_item.get("properties", {}),
+ config=config[config_item_key],
+ parent_keys=parent_keys,
+ )
# Inspect nested list items
elif is_list:
parent_keys += [str(config_item_key)]
- self._assign_dynamic_config_values(schema=schema_item.get('items', {}),
- config=config[config_item_key],
- parent_keys=parent_keys)
+ self._assign_dynamic_config_values(
+ schema=schema_item.get("items", {}),
+ config=config[config_item_key],
+ parent_keys=parent_keys,
+ )
else:
- is_jinja_expression = jinja_utils.is_jinja_expression(value=config_item_value)
+ is_jinja_expression = jinja_utils.is_jinja_expression(
+ value=config_item_value
+ )
if is_jinja_expression:
# Resolve / render the Jinja template expression
- full_config_item_key = '.'.join(parent_keys + [str(config_item_key)])
- value = self._get_datastore_value_for_expression(key=full_config_item_key,
+ full_config_item_key = ".".join(
+ parent_keys + [str(config_item_key)]
+ )
+ value = self._get_datastore_value_for_expression(
+ key=full_config_item_key,
value=config_item_value,
- config_schema_item=schema_item)
+ config_schema_item=schema_item,
+ )
config[config_item_key] = value
else:
@@ -167,12 +176,12 @@ def _assign_default_values(self, schema, config):
:rtype: ``dict``
"""
for schema_item_key, schema_item in six.iteritems(schema):
- has_default_value = 'default' in schema_item
+ has_default_value = "default" in schema_item
has_config_value = schema_item_key in config
- default_value = schema_item.get('default', None)
- is_object = schema_item.get('type', None) == 'object'
- has_properties = schema_item.get('properties', None)
+ default_value = schema_item.get("default", None)
+ is_object = schema_item.get("type", None) == "object"
+ has_properties = schema_item.get("properties", None)
if has_default_value and not has_config_value:
# Config value is not provided, but default value is, use a default value
@@ -183,8 +192,9 @@ def _assign_default_values(self, schema, config):
if not config.get(schema_item_key, None):
config[schema_item_key] = {}
- self._assign_default_values(schema=schema_item['properties'],
- config=config[schema_item_key])
+ self._assign_default_values(
+ schema=schema_item["properties"], config=config[schema_item_key]
+ )
return config
@@ -198,18 +208,21 @@ def _get_datastore_value_for_expression(self, key, value, config_schema_item=Non
from st2common.services.config import deserialize_key_value
config_schema_item = config_schema_item or {}
- secret = config_schema_item.get('secret', False)
+ secret = config_schema_item.get("secret", False)
try:
- value = render_template_with_system_and_user_context(value=value,
- user=self.user)
+ value = render_template_with_system_and_user_context(
+ value=value, user=self.user
+ )
except Exception as e:
# Throw a more user-friendly exception on failed render
exc_class = type(e)
original_msg = six.text_type(e)
- msg = ('Failed to render dynamic configuration value for key "%s" with value '
- '"%s" for pack "%s" config: %s %s ' % (key, value, self.pack_name,
- exc_class, original_msg))
+ msg = (
+ 'Failed to render dynamic configuration value for key "%s" with value '
+ '"%s" for pack "%s" config: %s %s '
+ % (key, value, self.pack_name, exc_class, original_msg)
+ )
raise RuntimeError(msg)
if value:
@@ -222,21 +235,17 @@ def _get_datastore_value_for_expression(self, key, value, config_schema_item=Non
def get_config(pack, user):
- """Returns config for given pack and user.
- """
+ """Returns config for given pack and user."""
LOG.debug('Attempting to get config for pack "%s" and user "%s"' % (pack, user))
if pack and user:
- LOG.debug('Pack and user found. Loading config.')
- config_loader = ContentPackConfigLoader(
- pack_name=pack,
- user=user
- )
+ LOG.debug("Pack and user found. Loading config.")
+ config_loader = ContentPackConfigLoader(pack_name=pack, user=user)
config = config_loader.get_config()
else:
config = {}
- LOG.debug('Config: %s', config)
+ LOG.debug("Config: %s", config)
return config
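
`_assign_default_values` walks the config schema recursively: a schema-level `default` applies only when the config carries no value, and nested `object` schemas are descended into. A compact sketch of that recursion:

    def assign_defaults(schema, config):
        # A schema "default" fills in only when the key is absent, and
        # "object" schemas with "properties" are walked the same way.
        for key, item in schema.items():
            if "default" in item and key not in config:
                config[key] = item["default"]
            if item.get("type") == "object" and item.get("properties"):
                config.setdefault(key, {})
                assign_defaults(item["properties"], config[key])
        return config

    schema = {
        "host": {"type": "string", "default": "localhost"},
        "auth": {
            "type": "object",
            "properties": {"port": {"type": "integer", "default": 9100}},
        },
    }
    assert assign_defaults(schema, {}) == {"host": "localhost", "auth": {"port": 9100}}
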
diff --git a/st2common/st2common/util/config_parser.py b/st2common/st2common/util/config_parser.py
index 247dca88fa..40c9e30313 100644
--- a/st2common/st2common/util/config_parser.py
+++ b/st2common/st2common/util/config_parser.py
@@ -21,10 +21,7 @@
from st2common.content import utils
-__all__ = [
- 'ContentPackConfigParser',
- 'ContentPackConfig'
-]
+__all__ = ["ContentPackConfigParser", "ContentPackConfig"]
class ContentPackConfigParser(object):
@@ -32,8 +29,8 @@ class ContentPackConfigParser(object):
Class responsible for obtaining and parsing content pack configs.
"""
- GLOBAL_CONFIG_NAME = 'config.yaml'
- LOCAL_CONFIG_SUFFIX = '_config.yaml'
+ GLOBAL_CONFIG_NAME = "config.yaml"
+ LOCAL_CONFIG_SUFFIX = "_config.yaml"
def __init__(self, pack_name):
self.pack_name = pack_name
@@ -85,8 +82,7 @@ def get_global_config_path(self):
if not self.pack_path:
return None
- global_config_path = os.path.join(self.pack_path,
- self.GLOBAL_CONFIG_NAME)
+ global_config_path = os.path.join(self.pack_path, self.GLOBAL_CONFIG_NAME)
return global_config_path
@classmethod
@@ -95,7 +91,7 @@ def get_and_parse_config(cls, config_path):
return None
if os.path.exists(config_path) and os.path.isfile(config_path):
- with io.open(config_path, 'r', encoding='utf8') as fp:
+ with io.open(config_path, "r", encoding="utf8") as fp:
config = yaml.safe_load(fp.read())
return ContentPackConfig(file_path=config_path, config=config)
diff --git a/st2common/st2common/util/crypto.py b/st2common/st2common/util/crypto.py
index 230c4ada8e..d01e20557b 100644
--- a/st2common/st2common/util/crypto.py
+++ b/st2common/st2common/util/crypto.py
@@ -51,23 +51,18 @@
from cryptography.hazmat.backends import default_backend
__all__ = [
- 'KEYCZAR_HEADER_SIZE',
- 'KEYCZAR_AES_BLOCK_SIZE',
- 'KEYCZAR_HLEN',
-
- 'read_crypto_key',
-
- 'symmetric_encrypt',
- 'symmetric_decrypt',
-
- 'cryptography_symmetric_encrypt',
- 'cryptography_symmetric_decrypt',
-
+ "KEYCZAR_HEADER_SIZE",
+ "KEYCZAR_AES_BLOCK_SIZE",
+ "KEYCZAR_HLEN",
+ "read_crypto_key",
+ "symmetric_encrypt",
+ "symmetric_decrypt",
+ "cryptography_symmetric_encrypt",
+ "cryptography_symmetric_decrypt",
# NOTE: Keyczar functions are here for testing reasons - they are only used by tests
- 'keyczar_symmetric_encrypt',
- 'keyczar_symmetric_decrypt',
-
- 'AESKey'
+ "keyczar_symmetric_encrypt",
+ "keyczar_symmetric_decrypt",
+ "AESKey",
]
# Keyczar related constants
@@ -94,13 +89,19 @@ class AESKey(object):
mode = None
size = None
- def __init__(self, aes_key_string, hmac_key_string, hmac_key_size, mode='CBC',
- size=DEFAULT_AES_KEY_SIZE):
- if mode not in ['CBC']:
- raise ValueError('Unsupported mode: %s' % (mode))
+ def __init__(
+ self,
+ aes_key_string,
+ hmac_key_string,
+ hmac_key_size,
+ mode="CBC",
+ size=DEFAULT_AES_KEY_SIZE,
+ ):
+ if mode not in ["CBC"]:
+ raise ValueError("Unsupported mode: %s" % (mode))
if size < MINIMUM_AES_KEY_SIZE:
- raise ValueError('Unsafe key size: %s' % (size))
+ raise ValueError("Unsafe key size: %s" % (size))
self.aes_key_string = aes_key_string
self.hmac_key_string = hmac_key_string
@@ -121,7 +122,7 @@ def generate(self, key_size=DEFAULT_AES_KEY_SIZE):
:rtype: :class:`AESKey`
"""
if key_size < MINIMUM_AES_KEY_SIZE:
- raise ValueError('Unsafe key size: %s' % (key_size))
+ raise ValueError("Unsafe key size: %s" % (key_size))
aes_key_bytes = os.urandom(int(key_size / 8))
aes_key_string = Base64WSEncode(aes_key_bytes)
@@ -129,8 +130,13 @@ def generate(self, key_size=DEFAULT_AES_KEY_SIZE):
hmac_key_bytes = os.urandom(int(key_size / 8))
hmac_key_string = Base64WSEncode(hmac_key_bytes)
- return AESKey(aes_key_string=aes_key_string, hmac_key_string=hmac_key_string,
- hmac_key_size=key_size, mode='CBC', size=key_size)
+ return AESKey(
+ aes_key_string=aes_key_string,
+ hmac_key_string=hmac_key_string,
+ hmac_key_size=key_size,
+ mode="CBC",
+ size=key_size,
+ )
def to_json(self):
"""
@@ -140,19 +146,22 @@ def to_json(self):
:rtype: ``str``
"""
data = {
- 'hmacKey': {
- 'hmacKeyString': self.hmac_key_string,
- 'size': self.hmac_key_size
+ "hmacKey": {
+ "hmacKeyString": self.hmac_key_string,
+ "size": self.hmac_key_size,
},
- 'aesKeyString': self.aes_key_string,
- 'mode': self.mode.upper(),
- 'size': int(self.size)
+ "aesKeyString": self.aes_key_string,
+ "mode": self.mode.upper(),
+ "size": int(self.size),
}
return json.dumps(data)
def __repr__(self):
- return ('<AESKey hmac_key_size=%s,mode=%s,size=%s>' % (self.hmac_key_size, self.mode,
- self.size))
+ return "<AESKey hmac_key_size=%s,mode=%s,size=%s>" % (
+ self.hmac_key_size,
+ self.mode,
+ self.size,
+ )
def read_crypto_key(key_path):
@@ -164,17 +173,19 @@ def read_crypto_key(key_path):
:rtype: :class:`AESKey`
"""
- with open(key_path, 'r') as fp:
+ with open(key_path, "r") as fp:
content = fp.read()
content = json.loads(content)
try:
- aes_key = AESKey(aes_key_string=content['aesKeyString'],
- hmac_key_string=content['hmacKey']['hmacKeyString'],
- hmac_key_size=content['hmacKey']['size'],
- mode=content['mode'].upper(),
- size=content['size'])
+ aes_key = AESKey(
+ aes_key_string=content["aesKeyString"],
+ hmac_key_string=content["hmacKey"]["hmacKeyString"],
+ hmac_key_size=content["hmacKey"]["size"],
+ mode=content["mode"].upper(),
+ size=content["size"],
+ )
except KeyError as e:
msg = 'Invalid or malformed key file "%s": %s' % (key_path, six.text_type(e))
raise KeyError(msg)
@@ -187,7 +198,9 @@ def symmetric_encrypt(encrypt_key, plaintext):
def symmetric_decrypt(decrypt_key, ciphertext):
- return cryptography_symmetric_decrypt(decrypt_key=decrypt_key, ciphertext=ciphertext)
+ return cryptography_symmetric_decrypt(
+ decrypt_key=decrypt_key, ciphertext=ciphertext
+ )
def cryptography_symmetric_encrypt(encrypt_key, plaintext):
@@ -206,9 +219,12 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext):
NOTE: Header itself is unused, but it's added so the format is compatible with keyczar format.
"""
- assert isinstance(encrypt_key, AESKey), 'encrypt_key needs to be AESKey class instance'
- assert isinstance(plaintext, (six.text_type, six.string_types, six.binary_type)), \
- 'plaintext needs to either be a string/unicode or bytes'
+ assert isinstance(
+ encrypt_key, AESKey
+ ), "encrypt_key needs to be AESKey class instance"
+ assert isinstance(
+ plaintext, (six.text_type, six.string_types, six.binary_type)
+ ), "plaintext needs to either be a string/unicode or bytes"
aes_key_bytes = encrypt_key.aes_key_bytes
hmac_key_bytes = encrypt_key.hmac_key_bytes
@@ -218,7 +234,7 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext):
if isinstance(plaintext, (six.text_type, six.string_types)):
# Convert data to bytes
- data = plaintext.encode('utf-8')
+ data = plaintext.encode("utf-8")
else:
data = plaintext
@@ -234,7 +250,7 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext):
# NOTE: We don't care about actual Keyczar header value, we only care about the length (5
# bytes) so we simply add 5 0's
- header_bytes = b'00000'
+ header_bytes = b"00000"
ciphertext_bytes = encryptor.update(data) + encryptor.finalize()
msg_bytes = header_bytes + iv_bytes + ciphertext_bytes
@@ -263,9 +279,12 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext):
NOTE 2: This function is loosely based on keyczar AESKey.Decrypt() (Apache 2.0 license).
"""
- assert isinstance(decrypt_key, AESKey), 'decrypt_key needs to be AESKey class instance'
- assert isinstance(ciphertext, (six.text_type, six.string_types, six.binary_type)), \
- 'ciphertext needs to either be a string/unicode or bytes'
+ assert isinstance(
+ decrypt_key, AESKey
+ ), "decrypt_key needs to be AESKey class instance"
+ assert isinstance(
+ ciphertext, (six.text_type, six.string_types, six.binary_type)
+ ), "ciphertext needs to either be a string/unicode or bytes"
aes_key_bytes = decrypt_key.aes_key_bytes
hmac_key_bytes = decrypt_key.hmac_key_bytes
@@ -280,10 +299,12 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext):
# Verify ciphertext contains IV + HMAC signature
if len(data_bytes) < (KEYCZAR_AES_BLOCK_SIZE + KEYCZAR_HLEN):
- raise ValueError('Invalid or malformed ciphertext (too short)')
+ raise ValueError("Invalid or malformed ciphertext (too short)")
iv_bytes = data_bytes[:KEYCZAR_AES_BLOCK_SIZE] # first block is IV
- ciphertext_bytes = data_bytes[KEYCZAR_AES_BLOCK_SIZE:-KEYCZAR_HLEN] # strip IV and signature
+ ciphertext_bytes = data_bytes[
+ KEYCZAR_AES_BLOCK_SIZE:-KEYCZAR_HLEN
+ ] # strip IV and signature
signature_bytes = data_bytes[-KEYCZAR_HLEN:] # last 20 bytes are signature
# Verify HMAC signature
@@ -302,6 +323,7 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext):
decrypted = pkcs5_unpad(decrypted)
return decrypted
+
###
# NOTE: The methods below are deprecated and only used for testing purposes
##
@@ -329,11 +351,12 @@ def keyczar_symmetric_encrypt(encrypt_key, plaintext):
from keyczar.keys import HmacKey as KeyczarHmacKey # pylint: disable=import-error
from keyczar.keyinfo import GetMode # pylint: disable=import-error
- encrypt_key = KeyczarAesKey(encrypt_key.aes_key_string,
- KeyczarHmacKey(encrypt_key.hmac_key_string,
- encrypt_key.hmac_key_size),
- encrypt_key.size,
- GetMode(encrypt_key.mode))
+ encrypt_key = KeyczarAesKey(
+ encrypt_key.aes_key_string,
+ KeyczarHmacKey(encrypt_key.hmac_key_string, encrypt_key.hmac_key_size),
+ encrypt_key.size,
+ GetMode(encrypt_key.mode),
+ )
return binascii.hexlify(encrypt_key.Encrypt(plaintext)).upper()
@@ -356,11 +379,12 @@ def keyczar_symmetric_decrypt(decrypt_key, ciphertext):
from keyczar.keys import HmacKey as KeyczarHmacKey # pylint: disable=import-error
from keyczar.keyinfo import GetMode # pylint: disable=import-error
- decrypt_key = KeyczarAesKey(decrypt_key.aes_key_string,
- KeyczarHmacKey(decrypt_key.hmac_key_string,
- decrypt_key.hmac_key_size),
- decrypt_key.size,
- GetMode(decrypt_key.mode))
+ decrypt_key = KeyczarAesKey(
+ decrypt_key.aes_key_string,
+ KeyczarHmacKey(decrypt_key.hmac_key_string, decrypt_key.hmac_key_size),
+ decrypt_key.size,
+ GetMode(decrypt_key.mode),
+ )
return decrypt_key.Decrypt(binascii.unhexlify(ciphertext))
@@ -370,7 +394,7 @@ def pkcs5_pad(data):
Pad data using PKCS5
"""
pad = KEYCZAR_AES_BLOCK_SIZE - len(data) % KEYCZAR_AES_BLOCK_SIZE
- data = data + pad * chr(pad).encode('utf-8')
+ data = data + pad * chr(pad).encode("utf-8")
return data
@@ -380,7 +404,7 @@ def pkcs5_unpad(data):
"""
if isinstance(data, six.binary_type):
# Make sure we are operating with a string type
- data = data.decode('utf-8')
+ data = data.decode("utf-8")
pad = ord(data[-1])
data = data[:-pad]
@@ -404,9 +428,9 @@ def Base64WSEncode(s):
"""
if isinstance(s, six.text_type):
# Make sure input string is always converted to bytes (if not already)
- s = s.encode('utf-8')
+ s = s.encode("utf-8")
- return base64.urlsafe_b64encode(s).decode('utf-8').replace("=", "")
+ return base64.urlsafe_b64encode(s).decode("utf-8").replace("=", "")
def Base64WSDecode(s):
@@ -427,12 +451,12 @@ def Base64WSDecode(s):
NOTE: Taken from keyczar (Apache 2.0 license)
"""
- s = ''.join(s.splitlines())
+ s = "".join(s.splitlines())
s = str(s.replace(" ", "")) # kill whitespace, make string (not unicode)
d = len(s) % 4
if d == 1:
- raise ValueError('Base64 decoding errors')
+ raise ValueError("Base64 decoding errors")
elif d == 2:
s += "=="
elif d == 3:
@@ -442,4 +466,4 @@ def Base64WSDecode(s):
return base64.urlsafe_b64decode(s)
except TypeError as e:
# Decoding raises TypeError if s contains invalid characters.
- raise ValueError('Base64 decoding error: %s' % (six.text_type(e)))
+ raise ValueError("Base64 decoding error: %s" % (six.text_type(e)))
diff --git a/st2common/st2common/util/date.py b/st2common/st2common/util/date.py
index 979c3e8eb3..8df0e4659f 100644
--- a/st2common/st2common/util/date.py
+++ b/st2common/st2common/util/date.py
@@ -24,12 +24,7 @@
import dateutil.parser
-__all__ = [
- 'get_datetime_utc_now',
- 'add_utc_tz',
- 'convert_to_utc',
- 'parse'
-]
+__all__ = ["get_datetime_utc_now", "add_utc_tz", "convert_to_utc", "parse"]
def get_datetime_utc_now():
@@ -45,14 +40,14 @@ def get_datetime_utc_now():
def append_milliseconds_to_time(date, millis):
"""
- Return time UTC datetime object offset by provided milliseconds.
+ Return a UTC datetime object offset by the provided milliseconds.
"""
return convert_to_utc(date + datetime.timedelta(milliseconds=millis))
def add_utc_tz(dt):
if dt.tzinfo and dt.tzinfo.utcoffset(dt) != datetime.timedelta(0):
- raise ValueError('datetime already contains a non UTC timezone')
+ raise ValueError("datetime already contains a non UTC timezone")
return dt.replace(tzinfo=dateutil.tz.tzutc())
diff --git a/st2common/st2common/util/debugging.py b/st2common/st2common/util/debugging.py
index dd5d74d2a2..66abbbe1ad 100644
--- a/st2common/st2common/util/debugging.py
+++ b/st2common/st2common/util/debugging.py
@@ -25,11 +25,7 @@
from st2common.logging.misc import set_log_level_for_all_loggers
-__all__ = [
- 'enable_debugging',
- 'disable_debugging',
- 'is_enabled'
-]
+__all__ = ["enable_debugging", "disable_debugging", "is_enabled"]
ENABLE_DEBUGGING = False
diff --git a/st2common/st2common/util/deprecation.py b/st2common/st2common/util/deprecation.py
index 160423a5e2..a178a9473d 100644
--- a/st2common/st2common/util/deprecation.py
+++ b/st2common/st2common/util/deprecation.py
@@ -23,10 +23,14 @@ def deprecated(func):
as deprecated. It will result in a warning being emitted
when the function is used.
"""
+
def new_func(*args, **kwargs):
- warnings.warn("Call to deprecated function {}.".format(func.__name__),
- category=DeprecationWarning)
+ warnings.warn(
+ "Call to deprecated function {}.".format(func.__name__),
+ category=DeprecationWarning,
+ )
return func(*args, **kwargs)
+
new_func.__name__ = func.__name__
new_func.__doc__ = func.__doc__
new_func.__dict__.update(func.__dict__)
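A quick usage sketch for the decorator reformatted above, assuming it is importable from st2common.util.deprecation:

    import warnings

    from st2common.util.deprecation import deprecated

    @deprecated
    def old_entry_point():
        return 42

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert old_entry_point() == 42  # still callable, result unchanged
        assert caught[0].category is DeprecationWarning
        assert "old_entry_point" in str(caught[0].message)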
diff --git a/st2common/st2common/util/driver_loader.py b/st2common/st2common/util/driver_loader.py
index 285c22ed79..50f5044c41 100644
--- a/st2common/st2common/util/driver_loader.py
+++ b/st2common/st2common/util/driver_loader.py
@@ -21,15 +21,11 @@
from st2common import log as logging
-__all__ = [
- 'get_available_backends',
- 'get_backend_driver',
- 'get_backend_instance'
-]
+__all__ = ["get_available_backends", "get_backend_driver", "get_backend_instance"]
LOG = logging.getLogger(__name__)
-BACKENDS_NAMESPACE = 'st2common.rbac.backend'
+BACKENDS_NAMESPACE = "st2common.rbac.backend"
def get_available_backends(namespace, invoke_on_load=False):
@@ -62,8 +58,9 @@ def get_backend_driver(namespace, name, invoke_on_load=False):
LOG.debug('Retrieving driver for backend "%s"' % (name))
try:
- manager = DriverManager(namespace=namespace, name=name,
- invoke_on_load=invoke_on_load)
+ manager = DriverManager(
+ namespace=namespace, name=name, invoke_on_load=invoke_on_load
+ )
except RuntimeError:
message = 'Invalid "%s" backend specified: %s' % (namespace, name)
LOG.exception(message)
@@ -79,7 +76,9 @@ def get_backend_instance(namespace, name, invoke_on_load=False):
:param name: Backend name.
:type name: ``str``
"""
- cls = get_backend_driver(namespace=namespace, name=name, invoke_on_load=invoke_on_load)
+ cls = get_backend_driver(
+ namespace=namespace, name=name, invoke_on_load=invoke_on_load
+ )
cls_instance = cls()
return cls_instance
diff --git a/st2common/st2common/util/enum.py b/st2common/st2common/util/enum.py
index ddcc138ea5..84a6e968f5 100644
--- a/st2common/st2common/util/enum.py
+++ b/st2common/st2common/util/enum.py
@@ -16,15 +16,16 @@
from __future__ import absolute_import
import inspect
-__all__ = [
- 'Enum'
-]
+__all__ = ["Enum"]
class Enum(object):
@classmethod
def get_valid_values(cls):
keys = list(cls.__dict__.keys())
- values = [getattr(cls, key) for key in keys if (not key.startswith('_') and
- not inspect.ismethod(getattr(cls, key)))]
+ values = [
+ getattr(cls, key)
+ for key in keys
+ if (not key.startswith("_") and not inspect.ismethod(getattr(cls, key)))
+ ]
return values
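For the Enum helper above, get_valid_values() collects the public class attributes while skipping underscore-prefixed names and methods; a small sketch:

    from st2common.util.enum import Enum

    class Color(Enum):
        RED = "red"
        BLUE = "blue"

    # __module__ and friends start with "_" and inherited classmethods are
    # filtered out by inspect.ismethod, so only the attribute values remain.
    assert sorted(Color.get_valid_values()) == ["blue", "red"]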
diff --git a/st2common/st2common/util/file_system.py b/st2common/st2common/util/file_system.py
index d6d2458aec..e26adaedfd 100644
--- a/st2common/st2common/util/file_system.py
+++ b/st2common/st2common/util/file_system.py
@@ -26,10 +26,7 @@
import six
-__all__ = [
- 'get_file_list',
- 'recursive_chown'
-]
+__all__ = ["get_file_list", "recursive_chown"]
def get_file_list(directory, exclude_patterns=None):
@@ -48,9 +45,9 @@ def get_file_list(directory, exclude_patterns=None):
:rtype: ``list``
"""
result = []
- if not directory.endswith('/'):
+ if not directory.endswith("/"):
# Make sure trailing slash is present
- directory = directory + '/'
+ directory = directory + "/"
def include_file(file_path):
if not exclude_patterns:
@@ -63,7 +60,7 @@ def include_file(file_path):
return True
for (dirpath, dirnames, filenames) in os.walk(directory):
- base_path = dirpath.replace(directory, '')
+ base_path = dirpath.replace(directory, "")
for filename in filenames:
if base_path:
diff --git a/st2common/st2common/util/green/shell.py b/st2common/st2common/util/green/shell.py
index 4fd71ef7cf..4b6d79935b 100644
--- a/st2common/st2common/util/green/shell.py
+++ b/st2common/st2common/util/green/shell.py
@@ -27,20 +27,31 @@
from st2common import log as logging
from st2common.util import concurrency
-__all__ = [
- 'run_command'
-]
+__all__ = ["run_command"]
TIMEOUT_EXIT_CODE = -9
LOG = logging.getLogger(__name__)
-def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False,
- cwd=None, env=None, timeout=60, preexec_func=None, kill_func=None,
- read_stdout_func=None, read_stderr_func=None,
- read_stdout_buffer=None, read_stderr_buffer=None, stdin_value=None,
- bufsize=0):
+def run_command(
+ cmd,
+ stdin=None,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=False,
+ cwd=None,
+ env=None,
+ timeout=60,
+ preexec_func=None,
+ kill_func=None,
+ read_stdout_func=None,
+ read_stderr_func=None,
+ read_stdout_buffer=None,
+ read_stderr_buffer=None,
+ stdin_value=None,
+ bufsize=0,
+):
"""
Run the provided command in a subprocess and wait until it completes.
@@ -89,59 +100,77 @@ def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
:rtype: ``tuple`` (exit_code, stdout, stderr, timed_out)
"""
- LOG.debug('Entering st2common.util.green.run_command.')
+ LOG.debug("Entering st2common.util.green.run_command.")
assert isinstance(cmd, (list, tuple) + six.string_types)
- if (read_stdout_func and not read_stderr_func) or (read_stderr_func and not read_stdout_func):
- raise ValueError('Both read_stdout_func and read_stderr_func arguments need '
- 'to be provided.')
+ if (read_stdout_func and not read_stderr_func) or (
+ read_stderr_func and not read_stdout_func
+ ):
+ raise ValueError(
+ "Both read_stdout_func and read_stderr_func arguments need "
+ "to be provided."
+ )
if read_stdout_func and not (read_stdout_buffer or read_stderr_buffer):
- raise ValueError('read_stdout_buffer and read_stderr_buffer arguments need to be provided '
- 'when read_stdout_func is provided')
+ raise ValueError(
+ "read_stdout_buffer and read_stderr_buffer arguments need to be provided "
+ "when read_stdout_func is provided"
+ )
if not env:
- LOG.debug('env argument not provided. using process env (os.environ).')
+ LOG.debug("env argument not provided. using process env (os.environ).")
env = os.environ.copy()
subprocess = concurrency.get_subprocess_module()
# Note: We are using an eventlet / gevent friendly implementation of subprocess which uses
# GreenPipe so it doesn't block
- LOG.debug('Creating subprocess.')
- process = concurrency.subprocess_popen(args=cmd, stdin=stdin, stdout=stdout, stderr=stderr,
- env=env, cwd=cwd, shell=shell, preexec_fn=preexec_func,
- bufsize=bufsize)
+ LOG.debug("Creating subprocess.")
+ process = concurrency.subprocess_popen(
+ args=cmd,
+ stdin=stdin,
+ stdout=stdout,
+ stderr=stderr,
+ env=env,
+ cwd=cwd,
+ shell=shell,
+ preexec_fn=preexec_func,
+ bufsize=bufsize,
+ )
if read_stdout_func:
- LOG.debug('Spawning read_stdout_func function')
- read_stdout_thread = concurrency.spawn(read_stdout_func, process.stdout, read_stdout_buffer)
+ LOG.debug("Spawning read_stdout_func function")
+ read_stdout_thread = concurrency.spawn(
+ read_stdout_func, process.stdout, read_stdout_buffer
+ )
if read_stderr_func:
- LOG.debug('Spawning read_stderr_func function')
- read_stderr_thread = concurrency.spawn(read_stderr_func, process.stderr, read_stderr_buffer)
+ LOG.debug("Spawning read_stderr_func function")
+ read_stderr_thread = concurrency.spawn(
+ read_stderr_func, process.stderr, read_stderr_buffer
+ )
def on_timeout_expired(timeout):
global timed_out
try:
- LOG.debug('Starting process wait inside timeout handler.')
+ LOG.debug("Starting process wait inside timeout handler.")
process.wait(timeout=timeout)
except subprocess.TimeoutExpired:
# Command has timed out, kill the process and propagate the error.
# Note: We explicitly set the returncode to indicate the timeout.
- LOG.debug('Command execution timeout reached.')
+ LOG.debug("Command execution timeout reached.")
# NOTE: It's important we set returncode twice - here and below to avoid a race in this
# function because "kill_func()" is async and "process.kill()" is not.
process.returncode = TIMEOUT_EXIT_CODE
if kill_func:
- LOG.debug('Calling kill_func.')
+ LOG.debug("Calling kill_func.")
kill_func(process=process)
else:
- LOG.debug('Killing process.')
+ LOG.debug("Killing process.")
process.kill()
# NOTE: It's important to set returncode here as well, since the call to process.kill() sets
@@ -149,25 +178,27 @@ def on_timeout_expired(timeout):
process.returncode = TIMEOUT_EXIT_CODE
if read_stdout_func and read_stderr_func:
- LOG.debug('Killing read_stdout_thread and read_stderr_thread')
+ LOG.debug("Killing read_stdout_thread and read_stderr_thread")
concurrency.kill(read_stdout_thread)
concurrency.kill(read_stderr_thread)
- LOG.debug('Spawning timeout handler thread.')
+ LOG.debug("Spawning timeout handler thread.")
timeout_thread = concurrency.spawn(on_timeout_expired, timeout)
- LOG.debug('Attaching to process.')
+ LOG.debug("Attaching to process.")
if stdin_value:
if six.PY3:
- stdin_value = stdin_value.encode('utf-8')
+ stdin_value = stdin_value.encode("utf-8")
process.stdin.write(stdin_value)
if read_stdout_func and read_stderr_func:
- LOG.debug('Using real-time stdout and stderr read mode, calling process.wait()')
+ LOG.debug("Using real-time stdout and stderr read mode, calling process.wait()")
process.wait()
else:
- LOG.debug('Using delayed stdout and stderr read mode, calling process.communicate()')
+ LOG.debug(
+ "Using delayed stdout and stderr read mode, calling process.communicate()"
+ )
stdout, stderr = process.communicate()
concurrency.cancel(timeout_thread)
@@ -182,11 +213,11 @@ def on_timeout_expired(timeout):
stderr = read_stderr_buffer.getvalue()
if exit_code == TIMEOUT_EXIT_CODE:
- LOG.debug('Timeout.')
+ LOG.debug("Timeout.")
timed_out = True
else:
- LOG.debug('No timeout.')
+ LOG.debug("No timeout.")
timed_out = False
- LOG.debug('Returning.')
+ LOG.debug("Returning.")
return (exit_code, stdout, stderr, timed_out)
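A hedged usage sketch for run_command above; it assumes an eventlet-capable environment (the concurrency module selects the green subprocess implementation) and shows the four-element return tuple:

    from st2common.util.green.shell import run_command

    # Run `echo hello` with a 10 second timeout and no custom stdout/stderr readers.
    exit_code, stdout, stderr, timed_out = run_command(cmd=["echo", "hello"], timeout=10)

    assert exit_code == 0 and timed_out is False
    assert stdout.startswith(b"hello")  # stdout/stderr come back as bytes from communicate()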
diff --git a/st2common/st2common/util/greenpooldispatch.py b/st2common/st2common/util/greenpooldispatch.py
index d85ebfbf5d..156d530116 100644
--- a/st2common/st2common/util/greenpooldispatch.py
+++ b/st2common/st2common/util/greenpooldispatch.py
@@ -21,9 +21,7 @@
from st2common import log as logging
-__all__ = [
- 'BufferedDispatcher'
-]
+__all__ = ["BufferedDispatcher"]
# If the thread pool has been occupied with no empty threads for more than this number of seconds
# a message will be logged
@@ -38,14 +36,20 @@
class BufferedDispatcher(object):
-
- def __init__(self, dispatch_pool_size=50, monitor_thread_empty_q_sleep_time=5,
- monitor_thread_no_workers_sleep_time=1, name=None):
+ def __init__(
+ self,
+ dispatch_pool_size=50,
+ monitor_thread_empty_q_sleep_time=5,
+ monitor_thread_no_workers_sleep_time=1,
+ name=None,
+ ):
self._pool_limit = dispatch_pool_size
self._dispatcher_pool = eventlet.GreenPool(dispatch_pool_size)
self._dispatch_monitor_thread = eventlet.greenthread.spawn(self._flush)
self._monitor_thread_empty_q_sleep_time = monitor_thread_empty_q_sleep_time
- self._monitor_thread_no_workers_sleep_time = monitor_thread_no_workers_sleep_time
+ self._monitor_thread_no_workers_sleep_time = (
+ monitor_thread_no_workers_sleep_time
+ )
self._name = name
self._work_buffer = six.moves.queue.Queue()
@@ -77,7 +81,9 @@ def _flush_now(self):
now = time.time()
if (now - self._pool_last_free_ts) >= POOL_BUSY_THRESHOLD_SECONDS:
- LOG.info(POOL_BUSY_LOG_MESSAGE % (self.name, POOL_BUSY_THRESHOLD_SECONDS))
+ LOG.info(
+ POOL_BUSY_LOG_MESSAGE % (self.name, POOL_BUSY_THRESHOLD_SECONDS)
+ )
return
@@ -90,8 +96,15 @@ def _flush_now(self):
def __repr__(self):
free_count = self._dispatcher_pool.free()
- values = (self.name, self._pool_limit, free_count, self._monitor_thread_empty_q_sleep_time,
- self._monitor_thread_no_workers_sleep_time)
- return ('<BufferedDispatcher name=%s, pool_limit=%s, free=%s, empty_q_sleep=%s, no_workers_sleep=%s>' %
- values)
+ values = (
+ self.name,
+ self._pool_limit,
+ free_count,
+ self._monitor_thread_empty_q_sleep_time,
+ self._monitor_thread_no_workers_sleep_time,
+ )
+ return (
+ ""
+ % values
+ )
diff --git a/st2common/st2common/util/gunicorn_workers.py b/st2common/st2common/util/gunicorn_workers.py
index 61eebe84e4..69942ac309 100644
--- a/st2common/st2common/util/gunicorn_workers.py
+++ b/st2common/st2common/util/gunicorn_workers.py
@@ -20,9 +20,7 @@
import six
from gunicorn.workers.sync import SyncWorker
-__all__ = [
- 'EventletSyncWorker'
-]
+__all__ = ["EventletSyncWorker"]
class EventletSyncWorker(SyncWorker):
@@ -44,7 +42,7 @@ def handle_quit(self, sig, frame):
except AssertionError as e:
msg = six.text_type(e)
- if 'do not call blocking functions from the mainloop' in msg:
+ if "do not call blocking functions from the mainloop" in msg:
# Workaround for "do not call blocking functions from the mainloop" issue
sys.exit(0)
diff --git a/st2common/st2common/util/hash.py b/st2common/st2common/util/hash.py
index f0a5596379..3d7c83328c 100644
--- a/st2common/st2common/util/hash.py
+++ b/st2common/st2common/util/hash.py
@@ -19,12 +19,10 @@
import hashlib
-__all__ = [
- 'hash'
-]
+__all__ = ["hash"]
-FIXED_SALT = 'saltnpepper'
+FIXED_SALT = "saltnpepper"
def hash(value, salt=FIXED_SALT):
diff --git a/st2common/st2common/util/http.py b/st2common/st2common/util/http.py
index e11a277be6..26aa6d445d 100644
--- a/st2common/st2common/util/http.py
+++ b/st2common/st2common/util/http.py
@@ -18,17 +18,20 @@
http_client = six.moves.http_client
-__all__ = [
- 'HTTP_SUCCESS',
- 'parse_content_type_header'
+__all__ = ["HTTP_SUCCESS", "parse_content_type_header"]
+
+HTTP_SUCCESS = [
+ http_client.OK,
+ http_client.CREATED,
+ http_client.ACCEPTED,
+ http_client.NON_AUTHORITATIVE_INFORMATION,
+ http_client.NO_CONTENT,
+ http_client.RESET_CONTENT,
+ http_client.PARTIAL_CONTENT,
+ http_client.MULTI_STATUS,
+ http_client.IM_USED,
]
-HTTP_SUCCESS = [http_client.OK, http_client.CREATED, http_client.ACCEPTED,
- http_client.NON_AUTHORITATIVE_INFORMATION, http_client.NO_CONTENT,
- http_client.RESET_CONTENT, http_client.PARTIAL_CONTENT,
- http_client.MULTI_STATUS, http_client.IM_USED,
- ]
-
def parse_content_type_header(content_type):
"""
@@ -37,13 +40,13 @@ def parse_content_type_header(content_type):
:rtype: ``tuple``
"""
- if ';' in content_type:
- split = content_type.split(';')
+ if ";" in content_type:
+ split = content_type.split(";")
media = split[0]
options = {}
for pair in split[1:]:
- split_pair = pair.split('=', 1)
+ split_pair = pair.split("=", 1)
if len(split_pair) != 2:
continue
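An illustrative call for parse_content_type_header above. The tail of the function falls outside this hunk, so the exact return shape and option normalization are assumptions here (a (media, options) tuple with whitespace-stripped keys):

    media, options = parse_content_type_header("application/json; charset=utf-8")
    # media   -> "application/json"
    # options -> {"charset": "utf-8"}  (assuming the unshown tail strips whitespace)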
diff --git a/st2common/st2common/util/ip_utils.py b/st2common/st2common/util/ip_utils.py
index 4e2a00357a..53253432d8 100644
--- a/st2common/st2common/util/ip_utils.py
+++ b/st2common/st2common/util/ip_utils.py
@@ -21,11 +21,7 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'is_ipv4',
- 'is_ipv6',
- 'split_host_port'
-]
+__all__ = ["is_ipv4", "is_ipv6", "split_host_port"]
BRACKET_PATTERN = r"^\[.*\]" # IPv6 bracket pattern to specify port
COMPILED_BRACKET_PATTERN = re.compile(BRACKET_PATTERN)
@@ -91,30 +87,32 @@ def split_host_port(host_str):
# Check if it's square bracket style.
match = COMPILED_BRACKET_PATTERN.match(host_str)
if match:
- LOG.debug('Square bracket style.')
+ LOG.debug("Square bracket style.")
# Check if square bracket style no port.
match = COMPILED_HOST_ONLY_IN_BRACKET_PATTERN.match(host_str)
if match:
- hostname = match.group().strip('[]')
+ hostname = match.group().strip("[]")
return (hostname, port)
- hostname, separator, port = hostname.rpartition(':')
+ hostname, separator, port = hostname.rpartition(":")
try:
- LOG.debug('host_str: %s, hostname: %s port: %s' % (host_str, hostname, port))
+ LOG.debug(
+ "host_str: %s, hostname: %s port: %s" % (host_str, hostname, port)
+ )
port = int(port)
- hostname = hostname.strip('[]')
+ hostname = hostname.strip("[]")
return (hostname, port)
except:
- raise Exception('Invalid port %s specified.' % port)
+ raise Exception("Invalid port %s specified." % port)
else:
- LOG.debug('Non-bracket address. host_str: %s' % host_str)
- if ':' in host_str:
- LOG.debug('Non-bracket with port.')
- hostname, separator, port = hostname.rpartition(':')
+ LOG.debug("Non-bracket address. host_str: %s" % host_str)
+ if ":" in host_str:
+ LOG.debug("Non-bracket with port.")
+ hostname, separator, port = hostname.rpartition(":")
try:
port = int(port)
return (hostname, port)
except:
- raise Exception('Invalid port %s specified.' % port)
+ raise Exception("Invalid port %s specified." % port)
return (hostname, port)
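The behavior of split_host_port above, sketched for the branches shown; the initial hostname/port defaults come from the part of the function outside this hunk:

    from st2common.util.ip_utils import split_host_port

    split_host_port("localhost:8080")     # -> ("localhost", 8080)
    split_host_port("[2001:db8::1]:443")  # -> ("2001:db8::1", 443)  bracket style
    split_host_port("[2001:db8::1]:abc")  # raises Exception("Invalid port abc specified.")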
diff --git a/st2common/st2common/util/isotime.py b/st2common/st2common/util/isotime.py
index 0830393bf8..0c6ca1c4d4 100644
--- a/st2common/st2common/util/isotime.py
+++ b/st2common/st2common/util/isotime.py
@@ -25,17 +25,14 @@
from st2common.util import date as date_utils
import six
-__all__ = [
- 'format',
- 'validate',
- 'parse'
-]
+__all__ = ["format", "validate", "parse"]
-ISO8601_FORMAT = '%Y-%m-%dT%H:%M:%S'
-ISO8601_FORMAT_MICROSECOND = '%Y-%m-%dT%H:%M:%S.%f'
-ISO8601_UTC_REGEX = \
- r'^\d{4}\-\d{2}\-\d{2}(\s|T)\d{2}:\d{2}:\d{2}(\.\d{3,6})?(Z|\+00|\+0000|\+00:00)$'
+ISO8601_FORMAT = "%Y-%m-%dT%H:%M:%S"
+ISO8601_FORMAT_MICROSECOND = "%Y-%m-%dT%H:%M:%S.%f"
+ISO8601_UTC_REGEX = (
+ r"^\d{4}\-\d{2}\-\d{2}(\s|T)\d{2}:\d{2}:\d{2}(\.\d{3,6})?(Z|\+00|\+0000|\+00:00)$"
+)
def format(dt, usec=True, offset=True):
@@ -53,20 +50,21 @@ def format(dt, usec=True, offset=True):
fmt = ISO8601_FORMAT_MICROSECOND if usec else ISO8601_FORMAT
if offset:
- ost = dt.strftime('%z')
- ost = (ost[:3] + ':' + ost[3:]) if ost else '+00:00'
+ ost = dt.strftime("%z")
+ ost = (ost[:3] + ":" + ost[3:]) if ost else "+00:00"
else:
- tz = dt.tzinfo.tzname(dt) if dt.tzinfo else 'UTC'
- ost = 'Z' if tz == 'UTC' else tz
+ tz = dt.tzinfo.tzname(dt) if dt.tzinfo else "UTC"
+ ost = "Z" if tz == "UTC" else tz
return dt.strftime(fmt) + ost
def validate(value, raise_exception=True):
- if (isinstance(value, datetime.datetime) or
- (type(value) in [str, six.text_type] and re.match(ISO8601_UTC_REGEX, value))):
+ if isinstance(value, datetime.datetime) or (
+ type(value) in [str, six.text_type] and re.match(ISO8601_UTC_REGEX, value)
+ ):
return True
if raise_exception:
- raise ValueError('Datetime value does not match expected format.')
+ raise ValueError("Datetime value does not match expected format.")
return False
diff --git a/st2common/st2common/util/jinja.py b/st2common/st2common/util/jinja.py
index 9234986f9f..4472246908 100644
--- a/st2common/st2common/util/jinja.py
+++ b/st2common/st2common/util/jinja.py
@@ -22,21 +22,14 @@
from st2common.util.compat import to_unicode
-__all__ = [
- 'get_jinja_environment',
- 'render_values',
- 'is_jinja_expression'
-]
+__all__ = ["get_jinja_environment", "render_values", "is_jinja_expression"]
-JINJA_EXPRESSIONS_START_MARKERS = [
- '{{',
- '{%'
-]
+JINJA_EXPRESSIONS_START_MARKERS = ["{{", "{%"]
-JINJA_REGEX = '({{(.*)}})'
+JINJA_REGEX = "({{(.*)}})"
JINJA_REGEX_PTRN = re.compile(JINJA_REGEX)
-JINJA_BLOCK_REGEX = '({%(.*)%})'
+JINJA_BLOCK_REGEX = "({%(.*)%})"
JINJA_BLOCK_REGEX_PTRN = re.compile(JINJA_BLOCK_REGEX)
@@ -53,59 +46,52 @@ def get_filters():
from st2common.expressions.functions import path
return {
- 'decrypt_kv': datastore.decrypt_kv,
-
- 'from_json_string': data.from_json_string,
- 'from_yaml_string': data.from_yaml_string,
- 'json_escape': data.json_escape,
- 'jsonpath_query': data.jsonpath_query,
- 'to_complex': data.to_complex,
- 'to_json_string': data.to_json_string,
- 'to_yaml_string': data.to_yaml_string,
-
- 'regex_match': regex.regex_match,
- 'regex_replace': regex.regex_replace,
- 'regex_search': regex.regex_search,
- 'regex_substring': regex.regex_substring,
-
- 'to_human_time_from_seconds': time.to_human_time_from_seconds,
-
- 'version_compare': version.version_compare,
- 'version_more_than': version.version_more_than,
- 'version_less_than': version.version_less_than,
- 'version_equal': version.version_equal,
- 'version_match': version.version_match,
- 'version_bump_major': version.version_bump_major,
- 'version_bump_minor': version.version_bump_minor,
- 'version_bump_patch': version.version_bump_patch,
- 'version_strip_patch': version.version_strip_patch,
- 'use_none': data.use_none,
-
- 'basename': path.basename,
- 'dirname': path.dirname
+ "decrypt_kv": datastore.decrypt_kv,
+ "from_json_string": data.from_json_string,
+ "from_yaml_string": data.from_yaml_string,
+ "json_escape": data.json_escape,
+ "jsonpath_query": data.jsonpath_query,
+ "to_complex": data.to_complex,
+ "to_json_string": data.to_json_string,
+ "to_yaml_string": data.to_yaml_string,
+ "regex_match": regex.regex_match,
+ "regex_replace": regex.regex_replace,
+ "regex_search": regex.regex_search,
+ "regex_substring": regex.regex_substring,
+ "to_human_time_from_seconds": time.to_human_time_from_seconds,
+ "version_compare": version.version_compare,
+ "version_more_than": version.version_more_than,
+ "version_less_than": version.version_less_than,
+ "version_equal": version.version_equal,
+ "version_match": version.version_match,
+ "version_bump_major": version.version_bump_major,
+ "version_bump_minor": version.version_bump_minor,
+ "version_bump_patch": version.version_bump_patch,
+ "version_strip_patch": version.version_strip_patch,
+ "use_none": data.use_none,
+ "basename": path.basename,
+ "dirname": path.dirname,
}
def get_jinja_environment(allow_undefined=False, trim_blocks=True, lstrip_blocks=True):
- '''
+ """
jinja2.Environment object that is set up with the right behaviors and custom filters.
:param allow_undefined: Whether to allow undefined variables in templates
:type allow_undefined: ``bool``
- '''
+ """
# Late import to avoid very expensive in-direct import (~1 second) when this function
# is not called / used
import jinja2
undefined = jinja2.Undefined if allow_undefined else jinja2.StrictUndefined
env = jinja2.Environment( # nosec
- undefined=undefined,
- trim_blocks=trim_blocks,
- lstrip_blocks=lstrip_blocks
+ undefined=undefined, trim_blocks=trim_blocks, lstrip_blocks=lstrip_blocks
)
env.filters.update(get_filters())
- env.tests['in'] = lambda item, list: item in list
+ env.tests["in"] = lambda item, list: item in list
return env
@@ -130,7 +116,7 @@ def render_values(mapping=None, context=None, allow_undefined=False):
# This means __context is a reserved keyword, although backwards compat is preserved by making
# sure that the real context is updated later and therefore will override the __context value.
super_context = {}
- super_context['__context'] = context
+ super_context["__context"] = context
super_context.update(context)
env = get_jinja_environment(allow_undefined=allow_undefined)
@@ -150,7 +136,7 @@ def render_values(mapping=None, context=None, allow_undefined=False):
v = str(v)
try:
- LOG.info('Rendering string %s. Super context=%s', v, super_context)
+ LOG.info("Rendering string %s. Super context=%s", v, super_context)
rendered_v = env.from_string(v).render(super_context)
except Exception as e:
# Attach key and value which failed the rendering
@@ -166,7 +152,12 @@ def render_values(mapping=None, context=None, allow_undefined=False):
if reverse_json_dumps:
rendered_v = json.loads(rendered_v)
rendered_mapping[k] = rendered_v
- LOG.info('Mapping: %s, rendered_mapping: %s, context: %s', mapping, rendered_mapping, context)
+ LOG.info(
+ "Mapping: %s, rendered_mapping: %s, context: %s",
+ mapping,
+ rendered_mapping,
+ context,
+ )
return rendered_mapping
@@ -194,6 +185,6 @@ def convert_jinja_to_raw_block(value):
if isinstance(value, six.string_types):
if JINJA_REGEX_PTRN.findall(value) or JINJA_BLOCK_REGEX_PTRN.findall(value):
- return '{% raw %}' + value + '{% endraw %}'
+ return "{% raw %}" + value + "{% endraw %}"
return value
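For convert_jinja_to_raw_block above, strings containing Jinja expressions or blocks get wrapped so they survive a later render pass untouched:

    from st2common.util.jinja import convert_jinja_to_raw_block

    convert_jinja_to_raw_block("{{ st2kv.system.api_key }}")
    # -> "{% raw %}{{ st2kv.system.api_key }}{% endraw %}"

    convert_jinja_to_raw_block("plain text")  # -> "plain text" (returned unchanged)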
diff --git a/st2common/st2common/util/jsonify.py b/st2common/st2common/util/jsonify.py
index 1f47cec1b0..16a95dde99 100644
--- a/st2common/st2common/util/jsonify.py
+++ b/st2common/st2common/util/jsonify.py
@@ -25,18 +25,12 @@
import six
-__all__ = [
- 'json_encode',
- 'json_loads',
- 'try_loads',
-
- 'get_json_type_for_python_value'
-]
+__all__ = ["json_encode", "json_loads", "try_loads", "get_json_type_for_python_value"]
class GenericJSON(JSONEncoder):
def default(self, obj): # pylint: disable=method-hidden
- if hasattr(obj, '__json__') and six.callable(obj.__json__):
+ if hasattr(obj, "__json__") and six.callable(obj.__json__):
return obj.__json__()
else:
return JSONEncoder.default(self, obj)
@@ -47,7 +41,7 @@ def json_encode(obj, indent=4):
def load_file(path):
- with open(path, 'r') as fd:
+ with open(path, "r") as fd:
return json.load(fd)
@@ -92,16 +86,16 @@ def get_json_type_for_python_value(value):
:rtype: ``str``
"""
if isinstance(value, six.text_type):
- return 'string'
+ return "string"
elif isinstance(value, (int, float)):
- return 'number'
+ return "number"
elif isinstance(value, dict):
- return 'object'
+ return "object"
elif isinstance(value, (list, tuple)):
- return 'array'
+ return "array"
elif isinstance(value, bool):
- return 'boolean'
+ return "boolean"
elif value is None:
- return 'null'
+ return "null"
else:
- return 'unknown'
+ return "unknown"
diff --git a/st2common/st2common/util/keyvalue.py b/st2common/st2common/util/keyvalue.py
index 05246d2d32..cad32250a8 100644
--- a/st2common/st2common/util/keyvalue.py
+++ b/st2common/st2common/util/keyvalue.py
@@ -24,22 +24,23 @@
from st2common.rbac.backends import get_rbac_backend
from st2common.persistence.keyvalue import KeyValuePair
from st2common.services.config import deserialize_key_value
-from st2common.constants.keyvalue import (FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, USER_SCOPE,
- ALLOWED_SCOPES)
+from st2common.constants.keyvalue import (
+ FULL_SYSTEM_SCOPE,
+ FULL_USER_SCOPE,
+ USER_SCOPE,
+ ALLOWED_SCOPES,
+)
from st2common.models.db.auth import UserDB
from st2common.exceptions.rbac import AccessDeniedError
-__all__ = [
- 'get_datastore_full_scope',
- 'get_key'
-]
+__all__ = ["get_datastore_full_scope", "get_key"]
LOG = logging.getLogger(__name__)
def _validate_scope(scope):
if scope not in ALLOWED_SCOPES:
- msg = 'Scope %s is not in allowed scopes list: %s.' % (scope, ALLOWED_SCOPES)
+ msg = "Scope %s is not in allowed scopes list: %s." % (scope, ALLOWED_SCOPES)
raise ValueError(msg)
@@ -48,9 +49,9 @@ def _validate_decrypt_query_parameter(decrypt, scope, is_admin, user_db):
Validate that the provided user is either an admin or requesting to decrypt a value for
themselves.
"""
- is_user_scope = (scope == USER_SCOPE or scope == FULL_USER_SCOPE)
+ is_user_scope = scope == USER_SCOPE or scope == FULL_USER_SCOPE
if decrypt and (not is_user_scope and not is_admin):
- msg = 'Decrypt option requires administrator access'
+ msg = "Decrypt option requires administrator access"
raise AccessDeniedError(message=msg, user_db=user_db)
@@ -61,7 +62,7 @@ def get_datastore_full_scope(scope):
if DATASTORE_PARENT_SCOPE in scope:
return scope
- return '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, scope)
+ return "%s%s%s" % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, scope)
def _derive_scope_and_key(key, user, scope=None):
@@ -75,10 +76,10 @@ def _derive_scope_and_key(key, user, scope=None):
if scope is not None:
return scope, key
- if key.startswith('system.'):
- return FULL_SYSTEM_SCOPE, key[key.index('.') + 1:]
+ if key.startswith("system."):
+ return FULL_SYSTEM_SCOPE, key[key.index(".") + 1 :]
- return FULL_USER_SCOPE, '%s:%s' % (user, key)
+ return FULL_USER_SCOPE, "%s:%s" % (user, key)
def get_key(key=None, user_db=None, scope=None, decrypt=False):
@@ -86,10 +87,10 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False):
Retrieve key from KVP store
"""
if not isinstance(key, six.string_types):
- raise TypeError('Given key is not typeof string.')
+ raise TypeError("Given key is not typeof string.")
if not isinstance(decrypt, bool):
- raise TypeError('Decrypt parameter is not typeof bool.')
+ raise TypeError("Decrypt parameter is not typeof bool.")
if not user_db:
# Use system user
@@ -98,9 +99,10 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False):
scope, key_id = _derive_scope_and_key(key=key, user=user_db.name, scope=scope)
scope = get_datastore_full_scope(scope)
- LOG.debug('get_key key_id: %s, scope: %s, user: %s, decrypt: %s' % (key_id, scope,
- str(user_db.name),
- decrypt))
+ LOG.debug(
+ "get_key key_id: %s, scope: %s, user: %s, decrypt: %s"
+ % (key_id, scope, str(user_db.name), decrypt)
+ )
_validate_scope(scope=scope)
@@ -108,8 +110,9 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False):
is_admin = rbac_utils.user_is_admin(user_db=user_db)
# User needs to be either admin or requesting item for itself
- _validate_decrypt_query_parameter(decrypt=decrypt, scope=scope, is_admin=is_admin,
- user_db=user_db)
+ _validate_decrypt_query_parameter(
+ decrypt=decrypt, scope=scope, is_admin=is_admin, user_db=user_db
+ )
# Get the key value pair by scope and name.
kvp = KeyValuePair.get_by_scope_and_name(scope, key_id)
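How the _derive_scope_and_key helper above resolves keys when no explicit scope is passed (a private helper, shown here purely for illustration):

    _derive_scope_and_key("system.debug", user="stanley")
    # -> (FULL_SYSTEM_SCOPE, "debug")  the "system." prefix wins

    _derive_scope_and_key("my_key", user="stanley")
    # -> (FULL_USER_SCOPE, "stanley:my_key")  everything else is user-scoped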
diff --git a/st2common/st2common/util/loader.py b/st2common/st2common/util/loader.py
index 0e27a0da32..1c5a5a4b54 100644
--- a/st2common/st2common/util/loader.py
+++ b/st2common/st2common/util/loader.py
@@ -28,19 +28,14 @@
from st2common.exceptions.plugins import IncompatiblePluginException
from st2common import log as logging
-__all__ = [
- 'register_plugin',
- 'register_plugin_class',
-
- 'load_meta_file'
-]
+__all__ = ["register_plugin", "register_plugin_class", "load_meta_file"]
LOG = logging.getLogger(__name__)
-PYTHON_EXTENSION = '.py'
-ALLOWED_EXTS = ['.json', '.yaml', '.yml']
-PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load}
+PYTHON_EXTENSION = ".py"
+ALLOWED_EXTS = [".json", ".yaml", ".yml"]
+PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load}
# Cache for dynamically loaded runner modules
RUNNER_MODULES_CACHE = defaultdict(dict)
@@ -48,7 +43,9 @@
def _register_plugin_path(plugin_dir_abs_path):
if not os.path.isdir(plugin_dir_abs_path):
- raise Exception('Directory "%s" with plugins doesn\'t exist' % (plugin_dir_abs_path))
+ raise Exception(
+ 'Directory "%s" with plugins doesn\'t exist' % (plugin_dir_abs_path)
+ )
for x in sys.path:
if plugin_dir_abs_path in (x, x + os.sep):
@@ -59,15 +56,21 @@ def _register_plugin_path(plugin_dir_abs_path):
def _get_plugin_module(plugin_file_path):
plugin_module = os.path.basename(plugin_file_path)
if plugin_module.endswith(PYTHON_EXTENSION):
- plugin_module = plugin_module[:plugin_module.rfind('.py')]
+ plugin_module = plugin_module[: plugin_module.rfind(".py")]
else:
plugin_module = None
return plugin_module
def _get_classes_in_module(module):
- return [kls for name, kls in inspect.getmembers(module,
- lambda member: inspect.isclass(member) and member.__module__ == module.__name__)]
+ return [
+ kls
+ for name, kls in inspect.getmembers(
+ module,
+ lambda member: inspect.isclass(member)
+ and member.__module__ == module.__name__,
+ )
+ ]
def _get_plugin_classes(module_name):
@@ -92,7 +95,7 @@ def _get_plugin_methods(plugin_klass):
method_names = []
for name, method in methods:
method_properties = method.__dict__
- is_abstract = method_properties.get('__isabstractmethod__', False)
+ is_abstract = method_properties.get("__isabstractmethod__", False)
if is_abstract:
continue
@@ -102,16 +105,18 @@ def _get_plugin_methods(plugin_klass):
def _validate_methods(plugin_base_class, plugin_klass):
- '''
+ """
XXX: This is hacky, but we'd like to validate that the methods
in plugin_impl include at least all the *abstract* methods in
plugin_base_class.
- '''
+ """
expected_methods = plugin_base_class.__abstractmethods__
plugin_methods = _get_plugin_methods(plugin_klass)
for method in expected_methods:
if method not in plugin_methods:
- message = 'Class "%s" doesn\'t implement required "%s" method from the base class'
+ message = (
+ 'Class "%s" doesn\'t implement required "%s" method from the base class'
+ )
raise IncompatiblePluginException(message % (plugin_klass.__name__, method))
@@ -147,8 +152,10 @@ def register_plugin_class(base_class, file_path, class_name):
klass = getattr(module, class_name, None)
if not klass:
- raise Exception('Plugin file "%s" doesn\'t expose class named "%s"' %
- (file_path, class_name))
+ raise Exception(
+ 'Plugin file "%s" doesn\'t expose class named "%s"'
+ % (file_path, class_name)
+ )
_register_plugin(base_class, klass)
return klass
@@ -173,12 +180,14 @@ def register_plugin(plugin_base_class, plugin_abs_file_path):
registered_plugins.append(klass)
except Exception as e:
LOG.exception(e)
- LOG.debug('Skipping class %s as it doesn\'t match specs.', klass)
+ LOG.debug("Skipping class %s as it doesn't match specs.", klass)
continue
if len(registered_plugins) == 0:
- raise Exception('Found no classes in plugin file "%s" matching requirements.' %
- (plugin_abs_file_path))
+ raise Exception(
+ 'Found no classes in plugin file "%s" matching requirements.'
+ % (plugin_abs_file_path)
+ )
return registered_plugins
@@ -189,16 +198,17 @@ def load_meta_file(file_path):
file_name, file_ext = os.path.splitext(file_path)
if file_ext not in ALLOWED_EXTS:
- raise Exception('Unsupported meta type %s, file %s. Allowed: %s' %
- (file_ext, file_path, ALLOWED_EXTS))
+ raise Exception(
+ "Unsupported meta type %s, file %s. Allowed: %s"
+ % (file_ext, file_path, ALLOWED_EXTS)
+ )
- with open(file_path, 'r') as f:
+ with open(file_path, "r") as f:
return PARSER_FUNCS[file_ext](f)
def get_available_plugins(namespace):
- """Return names of the available / installed plugins for a given namespace.
- """
+ """Return names of the available / installed plugins for a given namespace."""
from stevedore.extension import ExtensionManager
manager = ExtensionManager(namespace=namespace, invoke_on_load=False)
@@ -206,9 +216,10 @@ def get_available_plugins(namespace):
def get_plugin_instance(namespace, name, invoke_on_load=True):
- """Return class instance for the provided plugin name and namespace.
- """
+ """Return class instance for the provided plugin name and namespace."""
from stevedore.driver import DriverManager
- manager = DriverManager(namespace=namespace, name=name, invoke_on_load=invoke_on_load)
+ manager = DriverManager(
+ namespace=namespace, name=name, invoke_on_load=invoke_on_load
+ )
return manager.driver
diff --git a/st2common/st2common/util/misc.py b/st2common/st2common/util/misc.py
index 28773abedb..6a1027e9fe 100644
--- a/st2common/st2common/util/misc.py
+++ b/st2common/st2common/util/misc.py
@@ -26,18 +26,17 @@
import six
__all__ = [
- 'prefix_dict_keys',
- 'compare_path_file_name',
- 'get_field_name_from_mongoengine_error',
-
- 'sanitize_output',
- 'strip_shell_chars',
- 'rstrip_last_char',
- 'lowercase_value'
+ "prefix_dict_keys",
+ "compare_path_file_name",
+ "get_field_name_from_mongoengine_error",
+ "sanitize_output",
+ "strip_shell_chars",
+ "rstrip_last_char",
+ "lowercase_value",
]
-def prefix_dict_keys(dictionary, prefix='_'):
+def prefix_dict_keys(dictionary, prefix="_"):
"""
Prefix dictionary keys with a provided prefix.
@@ -52,7 +51,7 @@ def prefix_dict_keys(dictionary, prefix='_'):
result = {}
for key, value in six.iteritems(dictionary):
- result['%s%s' % (prefix, key)] = value
+ result["%s%s" % (prefix, key)] = value
return result
@@ -89,7 +88,7 @@ def sanitize_output(input_str, uses_pty=False):
output = strip_shell_chars(input_str)
if uses_pty:
- output = output.replace('\r\n', '\n')
+ output = output.replace("\r\n", "\n")
return output
@@ -105,8 +104,8 @@ def strip_shell_chars(input_str):
:rtype: ``str``
"""
- stripped_str = rstrip_last_char(input_str, '\n')
- stripped_str = rstrip_last_char(stripped_str, '\r')
+ stripped_str = rstrip_last_char(input_str, "\n")
+ stripped_str = rstrip_last_char(stripped_str, "\r")
return stripped_str
@@ -127,7 +126,7 @@ def rstrip_last_char(input_str, char_to_strip):
return input_str
if input_str.endswith(char_to_strip):
- return input_str[:-len(char_to_strip)]
+ return input_str[: -len(char_to_strip)]
return input_str
@@ -153,10 +152,10 @@ def get_normalized_file_path(file_path):
:rtype: ``str``
"""
- if hasattr(sys, 'frozen'): # support for py2exe
- file_path = 'logging%s__init__%s' % (os.sep, file_path[-4:])
- elif file_path[-4:].lower() in ['.pyc', '.pyo']:
- file_path = file_path[:-4] + '.py'
+ if hasattr(sys, "frozen"): # support for py2exe
+ file_path = "logging%s__init__%s" % (os.sep, file_path[-4:])
+ elif file_path[-4:].lower() in [".pyc", ".pyo"]:
+ file_path = file_path[:-4] + ".py"
else:
file_path = file_path
@@ -193,7 +192,7 @@ def get_field_name_from_mongoengine_error(exc):
"""
msg = str(exc)
- match = re.match("Cannot resolve field \"(.+?)\"", msg)
+ match = re.match('Cannot resolve field "(.+?)"', msg)
if match:
return match.groups()[0]
@@ -201,7 +200,9 @@ def get_field_name_from_mongoengine_error(exc):
return msg
-def ignore_and_log_exception(exc_classes=(Exception,), logger=None, level=logging.WARNING):
+def ignore_and_log_exception(
+ exc_classes=(Exception,), logger=None, level=logging.WARNING
+):
"""
Decorator which catches the provided exception classes and logs them instead of letting them
bubble all the way up.
@@ -214,13 +215,14 @@ def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except exc_classes as e:
- if len(args) >= 1 and getattr(args[0], '__class__', None):
- func_name = '%s.%s' % (args[0].__class__.__name__, func.__name__)
+ if len(args) >= 1 and getattr(args[0], "__class__", None):
+ func_name = "%s.%s" % (args[0].__class__.__name__, func.__name__)
else:
func_name = func.__name__
- message = ('Exception in function "%s": %s' % (func_name, str(e)))
+ message = 'Exception in function "%s": %s' % (func_name, str(e))
logger.log(level, message)
return wrapper
+
return decorator
diff --git a/st2common/st2common/util/mongoescape.py b/st2common/st2common/util/mongoescape.py
index 6d42b4972c..d75d9502f4 100644
--- a/st2common/st2common/util/mongoescape.py
+++ b/st2common/st2common/util/mongoescape.py
@@ -21,17 +21,22 @@
from st2common.util.ujson import fast_deepcopy
# Note: Because of old rule escaping code, two different characters can be translated back to dot
-RULE_CRITERIA_UNESCAPED = ['.']
-RULE_CRITERIA_ESCAPED = [u'\u2024']
-RULE_CRITERIA_ESCAPE_TRANSLATION = dict(list(zip(RULE_CRITERIA_UNESCAPED, RULE_CRITERIA_ESCAPED)))
-RULE_CRITERIA_UNESCAPE_TRANSLATION = dict(list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED)))
+RULE_CRITERIA_UNESCAPED = ["."]
+RULE_CRITERIA_ESCAPED = ["\u2024"]
+RULE_CRITERIA_ESCAPE_TRANSLATION = dict(
+ list(zip(RULE_CRITERIA_UNESCAPED, RULE_CRITERIA_ESCAPED))
+)
+RULE_CRITERIA_UNESCAPE_TRANSLATION = dict(
+ list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED))
+)
# http://docs.mongodb.org/manual/faq/developers/#faq-dollar-sign-escaping
-UNESCAPED = ['.', '$']
-ESCAPED = [u'\uFF0E', u'\uFF04']
+UNESCAPED = [".", "$"]
+ESCAPED = ["\uFF0E", "\uFF04"]
ESCAPE_TRANSLATION = dict(list(zip(UNESCAPED, ESCAPED)))
UNESCAPE_TRANSLATION = dict(
- list(zip(ESCAPED, UNESCAPED)) + list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED))
+ list(zip(ESCAPED, UNESCAPED))
+ + list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED))
)
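The translation tables above exist because MongoDB forbids "." and "$" in field names; a character-level sketch (the module's actual escape/unescape helpers, which walk dictionary keys, are outside this hunk):

    from st2common.util.mongoescape import ESCAPE_TRANSLATION, UNESCAPE_TRANSLATION

    escaped = "".join(ESCAPE_TRANSLATION.get(c, c) for c in "user.name$id")
    # -> "user\uff0ename\uff04id" (fullwidth full stop / dollar sign)

    restored = "".join(UNESCAPE_TRANSLATION.get(c, c) for c in escaped)
    assert restored == "user.name$id"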
diff --git a/st2common/st2common/util/monkey_patch.py b/st2common/st2common/util/monkey_patch.py
index 5a042fd656..76b4a191de 100644
--- a/st2common/st2common/util/monkey_patch.py
+++ b/st2common/st2common/util/monkey_patch.py
@@ -22,13 +22,13 @@
import sys
__all__ = [
- 'monkey_patch',
- 'use_select_poll_workaround',
- 'is_use_debugger_flag_provided'
+ "monkey_patch",
+ "use_select_poll_workaround",
+ "is_use_debugger_flag_provided",
]
-USE_DEBUGGER_FLAG = '--use-debugger'
-PARENT_ARGS_FLAG = '--parent-args='
+USE_DEBUGGER_FLAG = "--use-debugger"
+PARENT_ARGS_FLAG = "--parent-args="
def monkey_patch(patch_thread=None):
@@ -48,7 +48,9 @@ def monkey_patch(patch_thread=None):
if patch_thread is None:
patch_thread = not is_use_debugger_flag_provided()
- eventlet.monkey_patch(os=True, select=True, socket=True, thread=patch_thread, time=True)
+ eventlet.monkey_patch(
+ os=True, select=True, socket=True, thread=patch_thread, time=True
+ )
def use_select_poll_workaround(nose_only=True):
@@ -80,20 +82,20 @@ def use_select_poll_workaround(nose_only=True):
import eventlet
# Work around to get tests to pass with eventlet >= 0.20.0
- if not nose_only or (nose_only and 'nose' in sys.modules.keys()):
+ if not nose_only or (nose_only and "nose" in sys.modules.keys()):
# Add back blocking poll() to eventlet monkeypatched select
- original_poll = eventlet.patcher.original('select').poll
+ original_poll = eventlet.patcher.original("select").poll
select.poll = original_poll
- sys.modules['select'] = select
+ sys.modules["select"] = select
subprocess.select = select
if sys.version_info >= (3, 6, 5):
# If we also don't patch selectors.select, it will fail with Python >= 3.6.5
import selectors # pylint: disable=import-error
- sys.modules['selectors'] = selectors
- selectors.select = sys.modules['select']
+ sys.modules["selectors"] = selectors
+ selectors.select = sys.modules["select"]
def is_use_debugger_flag_provided():
diff --git a/st2common/st2common/util/output_schema.py b/st2common/st2common/util/output_schema.py
index 607f1af0bb..2bde19c3c0 100644
--- a/st2common/st2common/util/output_schema.py
+++ b/st2common/st2common/util/output_schema.py
@@ -26,37 +26,36 @@
def _validate_runner(runner_schema, result):
- LOG.debug('Validating runner output: %s', runner_schema)
+ LOG.debug("Validating runner output: %s", runner_schema)
runner_schema = {
"type": "object",
"properties": runner_schema,
- "additionalProperties": False
+ "additionalProperties": False,
}
- schema.validate(result, runner_schema, cls=schema.get_validator('custom'))
+ schema.validate(result, runner_schema, cls=schema.get_validator("custom"))
def _validate_action(action_schema, result, output_key):
- LOG.debug('Validating action output: %s', action_schema)
+ LOG.debug("Validating action output: %s", action_schema)
final_result = result[output_key]
action_schema = {
"type": "object",
"properties": action_schema,
- "additionalProperties": False
+ "additionalProperties": False,
}
- schema.validate(final_result, action_schema, cls=schema.get_validator('custom'))
+ schema.validate(final_result, action_schema, cls=schema.get_validator("custom"))
def validate_output(runner_schema, action_schema, result, status, output_key):
- """ Validate output of action with runner and action schema.
- """
+ """Validate output of action with runner and action schema."""
try:
- LOG.debug('Validating action output: %s', result)
- LOG.debug('Output Key: %s', output_key)
+ LOG.debug("Validating action output: %s", result)
+ LOG.debug("Output Key: %s", output_key)
if runner_schema:
_validate_runner(runner_schema, result)
@@ -64,26 +63,26 @@ def validate_output(runner_schema, action_schema, result, status, output_key):
_validate_action(action_schema, result, output_key)
except jsonschema.ValidationError:
- LOG.exception('Failed to validate output.')
+ LOG.exception("Failed to validate output.")
_, ex, _ = sys.exc_info()
# mark execution as failed.
status = action_constants.LIVEACTION_STATUS_FAILED
# include the error message and traceback to try and provide some hints.
result = {
- 'error': str(ex),
- 'message': 'Error validating output. See error output for more details.',
+ "error": str(ex),
+ "message": "Error validating output. See error output for more details.",
}
return (result, status)
except:
- LOG.exception('Failed to validate output.')
+ LOG.exception("Failed to validate output.")
_, ex, tb = sys.exc_info()
# mark execution as failed.
status = action_constants.LIVEACTION_STATUS_FAILED
# include the error message and traceback to try and provide some hints.
result = {
- 'traceback': ''.join(traceback.format_tb(tb, 20)),
- 'error': str(ex),
- 'message': 'Error validating output. See error output for more details.',
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ "error": str(ex),
+ "message": "Error validating output. See error output for more details.",
}
return (result, status)
diff --git a/st2common/st2common/util/pack.py b/st2common/st2common/util/pack.py
index 6ac4e4fc48..43dde60051 100644
--- a/st2common/st2common/util/pack.py
+++ b/st2common/st2common/util/pack.py
@@ -30,27 +30,28 @@
from st2common.util import jinja as jinja_utils
__all__ = [
- 'get_pack_ref_from_metadata',
- 'get_pack_metadata',
- 'get_pack_warnings',
-
- 'get_pack_common_libs_path_for_pack_ref',
- 'get_pack_common_libs_path_for_pack_db',
-
- 'validate_config_against_schema',
-
- 'normalize_pack_version'
+ "get_pack_ref_from_metadata",
+ "get_pack_metadata",
+ "get_pack_warnings",
+ "get_pack_common_libs_path_for_pack_ref",
+ "get_pack_common_libs_path_for_pack_db",
+ "validate_config_against_schema",
+ "normalize_pack_version",
]
# Common format for python 2.7 warning
if six.PY2:
- PACK_PYTHON2_WARNING = "DEPRECATION WARNING: Pack %s only supports Python 2.x. " \
- "Python 2 support will be dropped in future releases. " \
- "Please consider updating your packs to work with Python 3.x"
+ PACK_PYTHON2_WARNING = (
+ "DEPRECATION WARNING: Pack %s only supports Python 2.x. "
+ "Python 2 support will be dropped in future releases. "
+ "Please consider updating your packs to work with Python 3.x"
+ )
else:
- PACK_PYTHON2_WARNING = "DEPRECATION WARNING: Pack %s only supports Python 2.x. " \
- "Python 2 support has been removed since st2 v3.4.0. " \
- "Please update your packs to work with Python 3.x"
+ PACK_PYTHON2_WARNING = (
+ "DEPRECATION WARNING: Pack %s only supports Python 2.x. "
+ "Python 2 support has been removed since st2 v3.4.0. "
+ "Please update your packs to work with Python 3.x"
+ )
def get_pack_ref_from_metadata(metadata, pack_directory_name=None):
@@ -69,19 +70,23 @@ def get_pack_ref_from_metadata(metadata, pack_directory_name=None):
# which are in sub-directories)
# 2. If attribute is not available, but pack name is and pack name meets the valid name
# criteria, we use that
- if metadata.get('ref', None):
- pack_ref = metadata['ref']
- elif pack_directory_name and re.match(PACK_REF_WHITELIST_REGEX, pack_directory_name):
+ if metadata.get("ref", None):
+ pack_ref = metadata["ref"]
+ elif pack_directory_name and re.match(
+ PACK_REF_WHITELIST_REGEX, pack_directory_name
+ ):
pack_ref = pack_directory_name
else:
- if re.match(PACK_REF_WHITELIST_REGEX, metadata['name']):
- pack_ref = metadata['name']
+ if re.match(PACK_REF_WHITELIST_REGEX, metadata["name"]):
+ pack_ref = metadata["name"]
else:
- msg = ('Pack name "%s" contains invalid characters and "ref" attribute is not '
- 'available. You either need to add "ref" attribute which contains only word '
- 'characters to the pack metadata file or update name attribute to contain only '
- 'word characters.')
- raise ValueError(msg % (metadata['name']))
+ msg = (
+ 'Pack name "%s" contains invalid characters and "ref" attribute is not '
+ 'available. You either need to add "ref" attribute which contains only word '
+ "characters to the pack metadata file or update name attribute to contain only"
+ "word characters."
+ )
+ raise ValueError(msg % (metadata["name"]))
return pack_ref
@@ -95,7 +100,9 @@ def get_pack_metadata(pack_dir):
manifest_path = os.path.join(pack_dir, MANIFEST_FILE_NAME)
if not os.path.isfile(manifest_path):
- raise ValueError('Pack "%s" is missing %s file' % (pack_dir, MANIFEST_FILE_NAME))
+ raise ValueError(
+ 'Pack "%s" is missing %s file' % (pack_dir, MANIFEST_FILE_NAME)
+ )
meta_loader = MetaLoader()
content = meta_loader.load(manifest_path)
@@ -112,15 +119,16 @@ def get_pack_warnings(pack_metadata):
:rtype: ``str``
"""
warning = None
- versions = pack_metadata.get('python_versions', None)
- pack_name = pack_metadata.get('name', None)
- if versions and set(versions) == set(['2']):
+ versions = pack_metadata.get("python_versions", None)
+ pack_name = pack_metadata.get("name", None)
+ if versions and set(versions) == set(["2"]):
warning = PACK_PYTHON2_WARNING % pack_name
return warning
-def validate_config_against_schema(config_schema, config_object, config_path,
- pack_name=None):
+def validate_config_against_schema(
+ config_schema, config_object, config_path, pack_name=None
+):
"""
Validate provided config dictionary against the provided config schema
dictionary.
@@ -128,35 +136,49 @@ def validate_config_against_schema(config_schema, config_object, config_path,
# NOTE: Lazy import to avoid performance overhead of importing this module when it's not used
import jsonschema
- pack_name = pack_name or 'unknown'
+ pack_name = pack_name or "unknown"
- schema = util_schema.get_schema_for_resource_parameters(parameters_schema=config_schema,
- allow_additional_properties=True)
+ schema = util_schema.get_schema_for_resource_parameters(
+ parameters_schema=config_schema, allow_additional_properties=True
+ )
instance = config_object
try:
- cleaned = util_schema.validate(instance=instance, schema=schema,
- cls=util_schema.CustomValidator, use_default=True,
- allow_default_none=True)
+ cleaned = util_schema.validate(
+ instance=instance,
+ schema=schema,
+ cls=util_schema.CustomValidator,
+ use_default=True,
+ allow_default_none=True,
+ )
for key in cleaned:
- if (jinja_utils.is_jinja_expression(value=cleaned.get(key)) and
- "decrypt_kv" in cleaned.get(key) and config_schema.get(key).get('secret')):
- raise ValueValidationException('Values specified as "secret: True" in config '
- 'schema are automatically decrypted by default. Use '
- 'of "decrypt_kv" jinja filter is not allowed for '
- 'such values. Please check the specified values in '
- 'the config or the default values in the schema.')
+ if (
+ jinja_utils.is_jinja_expression(value=cleaned.get(key))
+ and "decrypt_kv" in cleaned.get(key)
+ and config_schema.get(key).get("secret")
+ ):
+ raise ValueValidationException(
+ 'Values specified as "secret: True" in config '
+ "schema are automatically decrypted by default. Use "
+ 'of "decrypt_kv" jinja filter is not allowed for '
+ "such values. Please check the specified values in "
+ "the config or the default values in the schema."
+ )
except jsonschema.ValidationError as e:
- attribute = getattr(e, 'path', [])
+ attribute = getattr(e, "path", [])
if isinstance(attribute, (tuple, list, collections.Iterable)):
attribute = [str(item) for item in attribute]
- attribute = '.'.join(attribute)
+ attribute = ".".join(attribute)
else:
attribute = str(attribute)
- msg = ('Failed validating attribute "%s" in config for pack "%s" (%s): %s' %
- (attribute, pack_name, config_path, six.text_type(e)))
+ msg = 'Failed validating attribute "%s" in config for pack "%s" (%s): %s' % (
+ attribute,
+ pack_name,
+ config_path,
+ six.text_type(e),
+ )
raise jsonschema.ValidationError(msg)
return cleaned
@@ -183,12 +205,12 @@ def get_pack_common_libs_path_for_pack_db(pack_db):
:rtype: ``str``
"""
- pack_dir = getattr(pack_db, 'path', None)
+ pack_dir = getattr(pack_db, "path", None)
if not pack_dir:
return None
- libs_path = os.path.join(pack_dir, 'lib')
+ libs_path = os.path.join(pack_dir, "lib")
return libs_path
@@ -202,8 +224,8 @@ def normalize_pack_version(version):
"""
version = str(version)
- version_seperator_count = version.count('.')
+ version_seperator_count = version.count(".")
if version_seperator_count == 1:
- version = version + '.0'
+ version = version + ".0"
return version
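The normalize_pack_version change above keeps the existing behavior: two-component versions are padded to valid semver, everything else passes through:

    from st2common.util.pack import normalize_pack_version

    assert normalize_pack_version("1.0") == "1.0.0"    # one dot -> pad with ".0"
    assert normalize_pack_version("1.2.3") == "1.2.3"  # already three components
    assert normalize_pack_version(2) == "2"            # str()'d, zero dots, unchanged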
diff --git a/st2common/st2common/util/pack_management.py b/st2common/st2common/util/pack_management.py
index 48b9457203..0fde5b1d86 100644
--- a/st2common/st2common/util/pack_management.py
+++ b/st2common/st2common/util/pack_management.py
@@ -48,29 +48,33 @@
from st2common.util.versioning import get_python_version
__all__ = [
- 'download_pack',
-
- 'get_repo_url',
- 'eval_repo_url',
-
- 'apply_pack_owner_group',
- 'apply_pack_permissions',
-
- 'get_and_set_proxy_config'
+ "download_pack",
+ "get_repo_url",
+ "eval_repo_url",
+ "apply_pack_owner_group",
+ "apply_pack_permissions",
+ "get_and_set_proxy_config",
]
LOG = logging.getLogger(__name__)
-CONFIG_FILE = 'config.yaml'
+CONFIG_FILE = "config.yaml"
CURRENT_STACKSTORM_VERSION = get_stackstorm_version()
CURRENT_PYTHON_VERSION = get_python_version()
-SUDO_BINARY = find_executable('sudo')
+SUDO_BINARY = find_executable("sudo")
-def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, force=False,
- proxy_config=None, force_owner_group=True, force_permissions=True,
- logger=LOG):
+def download_pack(
+ pack,
+ abs_repo_base="/opt/stackstorm/packs",
+ verify_ssl=True,
+ force=False,
+ proxy_config=None,
+ force_owner_group=True,
+ force_permissions=True,
+ logger=LOG,
+):
"""
Download the pack and move it to /opt/stackstorm/packs.
@@ -105,11 +109,11 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True,
result = [pack_url, None, None]
temp_dir_name = hashlib.md5(pack_url.encode()).hexdigest()
- lock_file = LockFile('/tmp/%s' % (temp_dir_name))
+ lock_file = LockFile("/tmp/%s" % (temp_dir_name))
lock_file_path = lock_file.lock_file
if force:
- logger.debug('Force mode is enabled, deleting lock file...')
+ logger.debug("Force mode is enabled, deleting lock file...")
try:
os.unlink(lock_file_path)
@@ -119,31 +123,42 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True,
with lock_file:
try:
- user_home = os.path.expanduser('~')
+ user_home = os.path.expanduser("~")
abs_local_path = os.path.join(user_home, temp_dir_name)
- if pack_url.startswith('file://'):
+ if pack_url.startswith("file://"):
# Local pack
- local_pack_directory = os.path.abspath(os.path.join(pack_url.split('file://')[1]))
+ local_pack_directory = os.path.abspath(
+ os.path.join(pack_url.split("file://")[1])
+ )
else:
local_pack_directory = None
# If it's a local pack which is not a git repository, just copy the directory content
# over
if local_pack_directory and not os.path.isdir(
- os.path.join(local_pack_directory, '.git')):
+ os.path.join(local_pack_directory, ".git")
+ ):
if not os.path.isdir(local_pack_directory):
- raise ValueError('Local pack directory "%s" doesn\'t exist' %
- (local_pack_directory))
+ raise ValueError(
+ 'Local pack directory "%s" doesn\'t exist'
+ % (local_pack_directory)
+ )
- logger.debug('Detected local pack directory which is not a git repository, just '
- 'copying files over...')
+ logger.debug(
+ "Detected local pack directory which is not a git repository, just "
+ "copying files over..."
+ )
shutil.copytree(local_pack_directory, abs_local_path)
else:
# 1. Clone / download the repo
- clone_repo(temp_dir=abs_local_path, repo_url=pack_url, verify_ssl=verify_ssl,
- ref=pack_version)
+ clone_repo(
+ temp_dir=abs_local_path,
+ repo_url=pack_url,
+ verify_ssl=verify_ssl,
+ ref=pack_version,
+ )
pack_metadata = get_pack_metadata(pack_dir=abs_local_path)
pack_ref = get_pack_ref(pack_dir=abs_local_path)
@@ -154,12 +169,15 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True,
verify_pack_version(pack_metadata=pack_metadata)
# 3. Move pack to the final location
- move_result = move_pack(abs_repo_base=abs_repo_base, pack_name=pack_ref,
- abs_local_path=abs_local_path,
- pack_metadata=pack_metadata,
- force_owner_group=force_owner_group,
- force_permissions=force_permissions,
- logger=logger)
+ move_result = move_pack(
+ abs_repo_base=abs_repo_base,
+ pack_name=pack_ref,
+ abs_local_path=abs_local_path,
+ pack_metadata=pack_metadata,
+ force_owner_group=force_owner_group,
+ force_permissions=force_permissions,
+ logger=logger,
+ )
result[2] = move_result
finally:
cleanup_repo(abs_local_path=abs_local_path)
@@ -167,21 +185,21 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True,
return tuple(result)
-def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'):
+def clone_repo(temp_dir, repo_url, verify_ssl=True, ref="master"):
# Switch to non-interactive mode
- os.environ['GIT_TERMINAL_PROMPT'] = '0'
- os.environ['GIT_ASKPASS'] = '/bin/echo'
+ os.environ["GIT_TERMINAL_PROMPT"] = "0"
+ os.environ["GIT_ASKPASS"] = "/bin/echo"
# Disable SSL cert checking if explicitly asked
if not verify_ssl:
- os.environ['GIT_SSL_NO_VERIFY'] = 'true'
+ os.environ["GIT_SSL_NO_VERIFY"] = "true"
# Clone the repo from git; we don't use shallow copying
# because we want the user to work with the repo in the
# future.
repo = Repo.clone_from(repo_url, temp_dir)
- is_local_repo = repo_url.startswith('file://')
+ is_local_repo = repo_url.startswith("file://")
try:
active_branch = repo.active_branch
@@ -194,18 +212,20 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'):
# Special case for local git repos - we allow users to install from repos which are checked out
# at a specific commit (aka detached HEAD)
if is_local_repo and not active_branch and not ref:
- LOG.debug('Installing pack from git repo on disk, skipping branch checkout')
+ LOG.debug("Installing pack from git repo on disk, skipping branch checkout")
return temp_dir
use_branch = False
# Special case when a default repo branch is not "master"
# No ref was provided, so we just use the default active branch
- if (not ref or ref == active_branch.name) and repo.active_branch.object == repo.head.commit:
+ if (
+ not ref or ref == active_branch.name
+ ) and repo.active_branch.object == repo.head.commit:
gitref = repo.active_branch.object
else:
# Try to match the reference to a branch name (i.e. "master")
- gitref = get_gitref(repo, 'origin/%s' % ref)
+ gitref = get_gitref(repo, "origin/%s" % ref)
if gitref:
use_branch = True
@@ -215,7 +235,7 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'):
# Try to match the reference to a "vX.Y.Z" tag
if not gitref and re.match(PACK_VERSION_REGEX, ref):
- gitref = get_gitref(repo, 'v%s' % ref)
+ gitref = get_gitref(repo, "v%s" % ref)
# Giving up ¯\_(ツ)_/¯
if not gitref:
@@ -224,43 +244,52 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'):
valid_versions = get_valid_versions_for_repo(repo=repo)
if len(valid_versions) >= 1:
- valid_versions_string = ', '.join(valid_versions)
+ valid_versions_string = ", ".join(valid_versions)
- msg += ' Available versions are: %s.'
+ msg += " Available versions are: %s."
format_values.append(valid_versions_string)
raise ValueError(msg % tuple(format_values))
# We're trying to figure out which branch the ref is actually on,
# since there's no direct way to check for this in git-python.
- branches = repo.git.branch('-a', '--contains', gitref.hexsha) # pylint: disable=no-member
+ branches = repo.git.branch(
+ "-a", "--contains", gitref.hexsha
+ ) # pylint: disable=no-member
# Git tags aren't necessarily on a branch.
# If this is the case, gitref will be the tag name, but branches will be
# empty.
# We also need to checkout tags slightly differently than branches.
if branches:
- branches = branches.replace('*', '').split()
+ branches = branches.replace("*", "").split()
if active_branch.name not in branches or use_branch:
- branch = 'origin/%s' % ref if use_branch else branches[0]
- short_branch = ref if use_branch else branches[0].split('/')[-1]
- repo.git.checkout('-b', short_branch, branch)
+ branch = "origin/%s" % ref if use_branch else branches[0]
+ short_branch = ref if use_branch else branches[0].split("/")[-1]
+ repo.git.checkout("-b", short_branch, branch)
branch = repo.head.reference
else:
branch = repo.active_branch.name
repo.git.checkout(gitref.hexsha) # pylint: disable=no-member
- repo.git.branch('-f', branch, gitref.hexsha) # pylint: disable=no-member
+ repo.git.branch("-f", branch, gitref.hexsha) # pylint: disable=no-member
repo.git.checkout(branch)
else:
- repo.git.checkout('v%s' % ref) # pylint: disable=no-member
+ repo.git.checkout("v%s" % ref) # pylint: disable=no-member
return temp_dir
-def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_owner_group=True,
- force_permissions=True, logger=LOG):
+def move_pack(
+ abs_repo_base,
+ pack_name,
+ abs_local_path,
+ pack_metadata,
+ force_owner_group=True,
+ force_permissions=True,
+ logger=LOG,
+):
"""
Move pack directory into the final location.
"""
@@ -270,8 +299,9 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own
to = abs_repo_base
dest_pack_path = os.path.join(abs_repo_base, pack_name)
if os.path.exists(dest_pack_path):
- logger.debug('Removing existing pack %s in %s to replace.', pack_name,
- dest_pack_path)
+ logger.debug(
+ "Removing existing pack %s in %s to replace.", pack_name, dest_pack_path
+ )
# Ensure any existing configuration is preserved
old_config_file = os.path.join(dest_pack_path, CONFIG_FILE)
@@ -282,7 +312,7 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own
shutil.rmtree(dest_pack_path)
- logger.debug('Moving pack from %s to %s.', abs_local_path, to)
+ logger.debug("Moving pack from %s to %s.", abs_local_path, to)
shutil.move(abs_local_path, dest_pack_path)
# After the move, fix all permissions
@@ -299,9 +329,9 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own
if warning:
logger.warning(warning)
- message = 'Success.'
+ message = "Success."
elif message:
- message = 'Failure : %s' % message
+ message = "Failure : %s" % message
return (desired, message)
@@ -316,20 +346,25 @@ def apply_pack_owner_group(pack_path):
pack_group = utils.get_pack_group()
if pack_group:
- LOG.debug('Changing owner group of "{}" directory to {}'.format(pack_path, pack_group))
+ LOG.debug(
+ 'Changing owner group of "{}" directory to {}'.format(pack_path, pack_group)
+ )
if SUDO_BINARY:
- args = ['sudo', 'chgrp', '-R', pack_group, pack_path]
+ args = ["sudo", "chgrp", "-R", pack_group, pack_path]
else:
# Environments where sudo is not available (e.g. docker)
- args = ['chgrp', '-R', pack_group, pack_path]
+ args = ["chgrp", "-R", pack_group, pack_path]
exit_code, _, stderr, _ = shell.run_command(args)
if exit_code != 0:
# Non-fatal, but we still log it
- LOG.debug('Failed to change owner group on directory "{}" to "{}": {}'
- .format(pack_path, pack_group, stderr))
+ LOG.debug(
+ 'Failed to change owner group on directory "{}" to "{}": {}'.format(
+ pack_path, pack_group, stderr
+ )
+ )
return True
@@ -370,13 +405,13 @@ def get_repo_url(pack, proxy_config=None):
name_or_url = pack_and_version[0]
version = pack_and_version[1] if len(pack_and_version) > 1 else None
- if len(name_or_url.split('/')) == 1:
+ if len(name_or_url.split("/")) == 1:
pack = get_pack_from_index(name_or_url, proxy_config=proxy_config)
if not pack:
raise Exception('No record of the "%s" pack in the index.' % (name_or_url))
- return (pack['repo_url'], version or pack['version'])
+ return (pack["repo_url"], version or pack["version"])
else:
return (eval_repo_url(name_or_url), version)
@@ -386,12 +421,12 @@ def eval_repo_url(repo_url):
Allow passing short GitHub or GitLab SSH style URLs.
"""
if not repo_url:
- raise Exception('No valid repo_url provided or could be inferred.')
+ raise Exception("No valid repo_url provided or could be inferred.")
if repo_url.startswith("gitlab@") or repo_url.startswith("file://"):
return repo_url
else:
- if len(repo_url.split('/')) == 2 and 'git@' not in repo_url:
- url = 'https://github.com/{}'.format(repo_url)
+ if len(repo_url.split("/")) == 2 and "git@" not in repo_url:
+ url = "https://github.com/{}".format(repo_url)
else:
url = repo_url
return url
@@ -400,50 +435,65 @@ def eval_repo_url(repo_url):
def is_desired_pack(abs_pack_path, pack_name):
# path has to exist.
if not os.path.exists(abs_pack_path):
- return (False, 'Pack "%s" not found or it\'s missing a "pack.yaml" file.' %
- (pack_name))
+ return (
+ False,
+ 'Pack "%s" not found or it\'s missing a "pack.yaml" file.' % (pack_name),
+ )
# should not include reserved characters
for character in PACK_RESERVED_CHARACTERS:
if character in pack_name:
- return (False, 'Pack name "%s" contains reserved character "%s"' %
- (pack_name, character))
+ return (
+ False,
+ 'Pack name "%s" contains reserved character "%s"'
+ % (pack_name, character),
+ )
# must contain a manifest file. Empty file is ok for now.
if not os.path.isfile(os.path.join(abs_pack_path, MANIFEST_FILE_NAME)):
- return (False, 'Pack is missing a manifest file (%s).' % (MANIFEST_FILE_NAME))
+ return (False, "Pack is missing a manifest file (%s)." % (MANIFEST_FILE_NAME))
- return (True, '')
+ return (True, "")
def verify_pack_version(pack_metadata):
"""
Verify that the pack works with the currently running StackStorm version.
"""
- pack_name = pack_metadata.get('name', None)
- required_stackstorm_version = pack_metadata.get('stackstorm_version', None)
- supported_python_versions = pack_metadata.get('python_versions', None)
+ pack_name = pack_metadata.get("name", None)
+ required_stackstorm_version = pack_metadata.get("stackstorm_version", None)
+ supported_python_versions = pack_metadata.get("python_versions", None)
# If the stackstorm_version attribute is specified, verify that the pack works with the
# currently running version of StackStorm
if required_stackstorm_version:
- if not complex_semver_match(CURRENT_STACKSTORM_VERSION, required_stackstorm_version):
- msg = ('Pack "%s" requires StackStorm "%s", but current version is "%s". '
- 'You can override this restriction by providing the "force" flag, but '
- 'the pack is not guaranteed to work.' %
- (pack_name, required_stackstorm_version, CURRENT_STACKSTORM_VERSION))
+ if not complex_semver_match(
+ CURRENT_STACKSTORM_VERSION, required_stackstorm_version
+ ):
+ msg = (
+ 'Pack "%s" requires StackStorm "%s", but current version is "%s". '
+ 'You can override this restriction by providing the "force" flag, but '
+ "the pack is not guaranteed to work."
+ % (pack_name, required_stackstorm_version, CURRENT_STACKSTORM_VERSION)
+ )
raise ValueError(msg)
if supported_python_versions:
- if set(supported_python_versions) == set(['2']) and (not six.PY2):
- msg = ('Pack "%s" requires Python 2.x, but current Python version is "%s". '
- 'You can override this restriction by providing the "force" flag, but '
- 'the pack is not guaranteed to work.' % (pack_name, CURRENT_PYTHON_VERSION))
+ if set(supported_python_versions) == set(["2"]) and (not six.PY2):
+ msg = (
+ 'Pack "%s" requires Python 2.x, but current Python version is "%s". '
+ 'You can override this restriction by providing the "force" flag, but '
+ "the pack is not guaranteed to work."
+ % (pack_name, CURRENT_PYTHON_VERSION)
+ )
raise ValueError(msg)
- elif set(supported_python_versions) == set(['3']) and (not six.PY3):
- msg = ('Pack "%s" requires Python 3.x, but current Python version is "%s". '
- 'You can override this restriction by providing the "force" flag, but '
- 'the pack is not guaranteed to work.' % (pack_name, CURRENT_PYTHON_VERSION))
+ elif set(supported_python_versions) == set(["3"]) and (not six.PY3):
+ msg = (
+ 'Pack "%s" requires Python 3.x, but current Python version is "%s". '
+ 'You can override this restriction by providing the "force" flag, but '
+ "the pack is not guaranteed to work."
+ % (pack_name, CURRENT_PYTHON_VERSION)
+ )
raise ValueError(msg)
else:
# Pack supports Python 2.x and 3.x so no check is needed, or
@@ -474,7 +524,7 @@ def get_valid_versions_for_repo(repo):
valid_versions = []
for tag in repo.tags:
- if tag.name.startswith('v') and re.match(PACK_VERSION_REGEX, tag.name[1:]):
+ if tag.name.startswith("v") and re.match(PACK_VERSION_REGEX, tag.name[1:]):
# Note: We strip leading "v" from the version number
valid_versions.append(tag.name[1:])
@@ -486,39 +536,38 @@ def get_pack_ref(pack_dir):
Read pack reference from the metadata file and sanitize it.
"""
metadata = get_pack_metadata(pack_dir=pack_dir)
- pack_ref = get_pack_ref_from_metadata(metadata=metadata,
- pack_directory_name=None)
+ pack_ref = get_pack_ref_from_metadata(metadata=metadata, pack_directory_name=None)
return pack_ref
def get_and_set_proxy_config():
- https_proxy = os.environ.get('https_proxy', None)
- http_proxy = os.environ.get('http_proxy', None)
- proxy_ca_bundle_path = os.environ.get('proxy_ca_bundle_path', None)
- no_proxy = os.environ.get('no_proxy', None)
+ https_proxy = os.environ.get("https_proxy", None)
+ http_proxy = os.environ.get("http_proxy", None)
+ proxy_ca_bundle_path = os.environ.get("proxy_ca_bundle_path", None)
+ no_proxy = os.environ.get("no_proxy", None)
proxy_config = {}
if http_proxy or https_proxy:
- LOG.debug('Using proxy %s', http_proxy if http_proxy else https_proxy)
+ LOG.debug("Using proxy %s", http_proxy if http_proxy else https_proxy)
proxy_config = {
- 'https_proxy': https_proxy,
- 'http_proxy': http_proxy,
- 'proxy_ca_bundle_path': proxy_ca_bundle_path,
- 'no_proxy': no_proxy
+ "https_proxy": https_proxy,
+ "http_proxy": http_proxy,
+ "proxy_ca_bundle_path": proxy_ca_bundle_path,
+ "no_proxy": no_proxy,
}
- if https_proxy and not os.environ.get('https_proxy', None):
- os.environ['https_proxy'] = https_proxy
+ if https_proxy and not os.environ.get("https_proxy", None):
+ os.environ["https_proxy"] = https_proxy
- if http_proxy and not os.environ.get('http_proxy', None):
- os.environ['http_proxy'] = http_proxy
+ if http_proxy and not os.environ.get("http_proxy", None):
+ os.environ["http_proxy"] = http_proxy
- if no_proxy and not os.environ.get('no_proxy', None):
- os.environ['no_proxy'] = no_proxy
+ if no_proxy and not os.environ.get("no_proxy", None):
+ os.environ["no_proxy"] = no_proxy
- if proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None):
- os.environ['no_proxy'] = no_proxy
+ if proxy_ca_bundle_path and not os.environ.get("proxy_ca_bundle_path", None):
+ os.environ["no_proxy"] = no_proxy
return proxy_config
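
A hedged usage sketch tying this module's helpers together (the pack name below is hypothetical):

    # Pick up proxy settings from the environment, then install a pack.
    proxy_config = get_and_set_proxy_config()
    result = download_pack("my_pack", proxy_config=proxy_config)
    # Per the code above, result is a 3-tuple: result[0] is the resolved pack
    # URL and result[2] is the (success, message) tuple returned by move_pack().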
diff --git a/st2common/st2common/util/param.py b/st2common/st2common/util/param.py
index 93507fcd87..270c90e424 100644
--- a/st2common/st2common/util/param.py
+++ b/st2common/st2common/util/param.py
@@ -26,7 +26,11 @@
from st2common.util.jinja import is_jinja_expression
from st2common.constants.action import ACTION_CONTEXT_KV_PREFIX
from st2common.constants.pack import PACK_CONFIG_CONTEXT_KV_PREFIX
-from st2common.constants.keyvalue import DATASTORE_PARENT_SCOPE, SYSTEM_SCOPE, FULL_SYSTEM_SCOPE
+from st2common.constants.keyvalue import (
+ DATASTORE_PARENT_SCOPE,
+ SYSTEM_SCOPE,
+ FULL_SYSTEM_SCOPE,
+)
from st2common.constants.keyvalue import USER_SCOPE, FULL_USER_SCOPE
from st2common.exceptions.param import ParamException
from st2common.services.keyvalues import KeyValueLookup, UserKeyValueLookup
@@ -39,23 +43,27 @@
ENV = jinja_utils.get_jinja_environment()
__all__ = [
- 'render_live_params',
- 'render_final_params',
+ "render_live_params",
+ "render_final_params",
]
def _split_params(runner_parameters, action_parameters, mixed_params):
def pf(params, skips):
- result = {k: v for k, v in six.iteritems(mixed_params)
- if k in params and k not in skips}
+ result = {
+ k: v
+ for k, v in six.iteritems(mixed_params)
+ if k in params and k not in skips
+ }
return result
+
return (pf(runner_parameters, {}), pf(action_parameters, runner_parameters))
def _cast_params(rendered, parameter_schemas):
- '''
+ """
It's just here to make tests happy
- '''
+ """
casted_params = {}
for k, v in six.iteritems(rendered):
casted_params[k] = _cast(v, parameter_schemas[k] or {})
@@ -66,7 +74,7 @@ def _cast(v, parameter_schema):
if v is None or not parameter_schema:
return v
- parameter_type = parameter_schema.get('type', None)
+ parameter_type = parameter_schema.get("type", None)
if not parameter_type:
return v
@@ -78,23 +86,27 @@ def _cast(v, parameter_schema):
def _create_graph(action_context, config):
- '''
+ """
Creates a generic directed graph for the dependency tree and fills it with basic context variables
- '''
+ """
G = nx.DiGraph()
system_keyvalue_context = {SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)}
# If both 'user' and 'api_user' are specified, this prioritizes 'api_user'
- user = action_context['user'] if 'user' in action_context else None
- user = action_context['api_user'] if 'api_user' in action_context else user
+ user = action_context["user"] if "user" in action_context else None
+ user = action_context["api_user"] if "api_user" in action_context else user
if not user:
# When no user is specified, this selects the system user's scope by default.
user = cfg.CONF.system_user.user
- LOG.info('Unable to retrieve user / api_user value from action_context. Falling back '
- 'to and using system_user (%s).' % (user))
+ LOG.info(
+ "Unable to retrieve user / api_user value from action_context. Falling back "
+ "to and using system_user (%s)." % (user)
+ )
- system_keyvalue_context[USER_SCOPE] = UserKeyValueLookup(scope=FULL_USER_SCOPE, user=user)
+ system_keyvalue_context[USER_SCOPE] = UserKeyValueLookup(
+ scope=FULL_USER_SCOPE, user=user
+ )
G.add_node(DATASTORE_PARENT_SCOPE, value=system_keyvalue_context)
G.add_node(ACTION_CONTEXT_KV_PREFIX, value=action_context)
G.add_node(PACK_CONFIG_CONTEXT_KV_PREFIX, value=config)
@@ -102,9 +114,9 @@ def _create_graph(action_context, config):
def _process(G, name, value):
- '''
+ """
Determines whether a parameter is a template or a value. Adds graph nodes and edges accordingly.
- '''
+ """
# Jinja defaults to the ascii parser in python 2.x unless you set utf-8 support on a per-module level
# Instead we just assume every string is a unicode string
if isinstance(value, str):
@@ -114,23 +126,21 @@ def _process(G, name, value):
if isinstance(value, list) or isinstance(value, dict):
complex_value_str = str(value)
- is_jinja_expr = (
- jinja_utils.is_jinja_expression(value) or jinja_utils.is_jinja_expression(
- complex_value_str
- )
- )
+ is_jinja_expr = jinja_utils.is_jinja_expression(
+ value
+ ) or jinja_utils.is_jinja_expression(complex_value_str)
if is_jinja_expr:
G.add_node(name, template=value)
template_ast = ENV.parse(value)
- LOG.debug('Template ast: %s', template_ast)
+ LOG.debug("Template ast: %s", template_ast)
# Dependencies of the node represent jinja variables used in the template
# We're connecting nodes with an edge for every dependency to traverse them
# in the right order and also make sure that we don't have missing or cyclic
# dependencies upfront.
dependencies = meta.find_undeclared_variables(template_ast)
- LOG.debug('Dependencies: %s', dependencies)
+ LOG.debug("Dependencies: %s", dependencies)
if dependencies:
for dependency in dependencies:
G.add_edge(dependency, name)
@@ -139,24 +149,24 @@ def _process(G, name, value):
def _process_defaults(G, schemas):
- '''
+ """
Process dependencies for parameters' default values in the order schemas are defined.
- '''
+ """
for schema in schemas:
for name, value in six.iteritems(schema):
absent = name not in G.node
- is_none = G.node.get(name, {}).get('value') is None
- immutable = value.get('immutable', False)
+ is_none = G.node.get(name, {}).get("value") is None
+ immutable = value.get("immutable", False)
if absent or is_none or immutable:
- _process(G, name, value.get('default'))
+ _process(G, name, value.get("default"))
def _validate(G):
- '''
+ """
Validates dependency graph to ensure it has no missing or cyclic dependencies
- '''
+ """
for name in G.nodes():
- if 'value' not in G.node[name] and 'template' not in G.node[name]:
+ if "value" not in G.node[name] and "template" not in G.node[name]:
msg = 'Dependency unsatisfied in variable "%s"' % name
raise ParamException(msg)
@@ -172,51 +182,52 @@ def _validate(G):
variable_names.append(variable_name)
- variable_names = ', '.join(sorted(variable_names))
- msg = ('Cyclic dependency found in the following variables: %s. Likely the variable is '
- 'referencing itself' % (variable_names))
+ variable_names = ", ".join(sorted(variable_names))
+ msg = (
+ "Cyclic dependency found in the following variables: %s. Likely the variable is "
+ "referencing itself" % (variable_names)
+ )
raise ParamException(msg)
def _render(node, render_context):
- '''
+ """
Render the node depending on its type
- '''
- if 'template' in node:
+ """
+ if "template" in node:
complex_type = False
- if isinstance(node['template'], list) or isinstance(node['template'], dict):
- node['template'] = json.dumps(node['template'])
+ if isinstance(node["template"], list) or isinstance(node["template"], dict):
+ node["template"] = json.dumps(node["template"])
# Finds occurrences of "{{variable}}" and adds `to_complex` filter
# so types are honored. If it doesn't follow that syntax then it's
# rendered as a string.
- node['template'] = re.sub(
- r'"{{([A-z0-9_-]+)}}"', r'{{\1 | to_complex}}',
- node['template']
+ node["template"] = re.sub(
+ r'"{{([A-z0-9_-]+)}}"', r"{{\1 | to_complex}}", node["template"]
)
- LOG.debug('Rendering complex type: %s', node['template'])
+ LOG.debug("Rendering complex type: %s", node["template"])
complex_type = True
- LOG.debug('Rendering node: %s with context: %s', node, render_context)
+ LOG.debug("Rendering node: %s with context: %s", node, render_context)
- result = ENV.from_string(str(node['template'])).render(render_context)
+ result = ENV.from_string(str(node["template"])).render(render_context)
- LOG.debug('Render complete: %s', result)
+ LOG.debug("Render complete: %s", result)
if complex_type:
result = json.loads(result)
- LOG.debug('Complex Type Rendered: %s', result)
+ LOG.debug("Complex Type Rendered: %s", result)
return result
- if 'value' in node:
- return node['value']
+ if "value" in node:
+ return node["value"]
def _resolve_dependencies(G):
- '''
+ """
Traverse the dependency graph starting from resolved nodes
- '''
+ """
context = {}
for name in nx.topological_sort(G):
node = G.node[name]
@@ -224,7 +235,7 @@ def _resolve_dependencies(G):
context[name] = _render(node, context)
except Exception as e:
- LOG.debug('Failed to render %s: %s', name, e, exc_info=True)
+ LOG.debug("Failed to render %s: %s", name, e, exc_info=True)
msg = 'Failed to render parameter "%s": %s' % (name, six.text_type(e))
raise ParamException(msg)
@@ -232,9 +243,9 @@ def _resolve_dependencies(G):
def _cast_params_from(params, context, schemas):
- '''
+ """
Pick a list of parameters from context and cast each of them according to the schemas provided
- '''
+ """
result = {}
# First, cast only explicitly provided live parameters
@@ -258,17 +269,19 @@ def _cast_params_from(params, context, schemas):
for param_name, param_details in schema.items():
# Skip if the parameter has immutable set to true in the schema
- if param_details.get('immutable'):
+ if param_details.get("immutable"):
continue
# Skip if the parameter doesn't have a default, or if the
# value in the context is identical to the default
- if 'default' not in param_details or \
- param_details.get('default') == context[param_name]:
+ if (
+ "default" not in param_details
+ or param_details.get("default") == context[param_name]
+ ):
continue
# Skip if the default value isn't a Jinja expression
- if not is_jinja_expression(param_details.get('default')):
+ if not is_jinja_expression(param_details.get("default")):
continue
# Skip if the parameter is being overridden
@@ -280,22 +293,29 @@ def _cast_params_from(params, context, schemas):
return result
-def render_live_params(runner_parameters, action_parameters, params, action_context,
- additional_contexts=None):
- '''
+def render_live_params(
+ runner_parameters,
+ action_parameters,
+ params,
+ action_context,
+ additional_contexts=None,
+):
+ """
Renders a list of parameters. Ensures that there are no cyclic or missing dependencies. Returns a
dict of plain rendered parameters.
- '''
+ """
additional_contexts = additional_contexts or {}
- pack = action_context.get('pack')
- user = action_context.get('user')
+ pack = action_context.get("pack")
+ user = action_context.get("user")
try:
config = get_config(pack, user)
except Exception as e:
- LOG.info('Failed to retrieve config for pack %s and user %s: %s' % (pack, user,
- six.text_type(e)))
+ LOG.info(
+ "Failed to retrieve config for pack %s and user %s: %s"
+ % (pack, user, six.text_type(e))
+ )
config = {}
G = _create_graph(action_context, config)
@@ -310,18 +330,20 @@ def render_live_params(runner_parameters, action_parameters, params, action_cont
_validate(G)
context = _resolve_dependencies(G)
- live_params = _cast_params_from(params, context, [action_parameters, runner_parameters])
+ live_params = _cast_params_from(
+ params, context, [action_parameters, runner_parameters]
+ )
return live_params
def render_final_params(runner_parameters, action_parameters, params, action_context):
- '''
+ """
Renders missing parameters required for the action to execute. Treats parameters from the dict as
plain values instead of trying to render them again. Returns dicts for action and runner
parameters.
- '''
- config = get_config(action_context.get('pack'), action_context.get('user'))
+ """
+ config = get_config(action_context.get("pack"), action_context.get("user"))
G = _create_graph(action_context, config)
@@ -331,18 +353,29 @@ def render_final_params(runner_parameters, action_parameters, params, action_con
_validate(G)
context = _resolve_dependencies(G)
- context = _cast_params_from(context, context, [action_parameters, runner_parameters])
+ context = _cast_params_from(
+ context, context, [action_parameters, runner_parameters]
+ )
return _split_params(runner_parameters, action_parameters, context)
-def get_finalized_params(runnertype_parameter_info, action_parameter_info, liveaction_parameters,
- action_context):
- '''
+def get_finalized_params(
+ runnertype_parameter_info,
+ action_parameter_info,
+ liveaction_parameters,
+ action_context,
+):
+ """
Left here to keep tests running. Later we will need to split the tests so they test each
function separately.
- '''
- params = render_live_params(runnertype_parameter_info, action_parameter_info,
- liveaction_parameters, action_context)
- return render_final_params(runnertype_parameter_info, action_parameter_info, params,
- action_context)
+ """
+ params = render_live_params(
+ runnertype_parameter_info,
+ action_parameter_info,
+ liveaction_parameters,
+ action_context,
+ )
+ return render_final_params(
+ runnertype_parameter_info, action_parameter_info, params, action_context
+ )
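
A hedged sketch of how render_live_params resolves a Jinja dependency between parameters (all names and schemas below are hypothetical):

    runner_parameters = {}
    action_parameters = {
        "host": {"type": "string"},
        "url": {"type": "string", "default": "https://{{ host }}/api"},
    }
    params = {"host": "example.com"}
    action_context = {"user": "stanley", "pack": "my_pack"}

    live_params = render_live_params(
        runner_parameters, action_parameters, params, action_context
    )
    # "url" is rendered against "host", so live_params would come out roughly as
    # {"host": "example.com", "url": "https://example.com/api"}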
diff --git a/st2common/st2common/util/payload.py b/st2common/st2common/util/payload.py
index 92b36d55c0..b2dc2a74af 100644
--- a/st2common/st2common/util/payload.py
+++ b/st2common/st2common/util/payload.py
@@ -22,11 +22,8 @@
class PayloadLookup(object):
-
def __init__(self, payload, prefix=TRIGGER_PAYLOAD_PREFIX):
- self.context = {
- prefix: payload
- }
+ self.context = {prefix: payload}
for system_scope in SYSTEM_SCOPES:
self.context[system_scope] = KeyValueLookup(scope=system_scope)
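
A small usage sketch (payload keys are hypothetical):

    # PayloadLookup exposes the trigger payload and the datastore scopes
    # through a single context object.
    lookup = PayloadLookup({"hostname": "web-1", "port": 443})
    trigger_payload = lookup.context[TRIGGER_PAYLOAD_PREFIX]
    # -> {"hostname": "web-1", "port": 443}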
diff --git a/st2common/st2common/util/queues.py b/st2common/st2common/util/queues.py
index 526692155f..9fce3b20a7 100644
--- a/st2common/st2common/util/queues.py
+++ b/st2common/st2common/util/queues.py
@@ -36,7 +36,7 @@ def get_queue_name(queue_name_base, queue_name_suffix, add_random_uuid_to_suffix
:rtype: ``str``
"""
if not queue_name_base:
- raise ValueError('Queue name base cannot be empty.')
+ raise ValueError("Queue name base cannot be empty.")
if not queue_name_suffix:
return queue_name_base
@@ -46,8 +46,8 @@ def get_queue_name(queue_name_base, queue_name_suffix, add_random_uuid_to_suffix
# Pick last 10 digits of uuid. Arbitrary but unique enough. Long queue names
# might cause issues in RabbitMQ.
u_hex = uuid.uuid4().hex
- uuid_suffix = uuid.uuid4().hex[len(u_hex) - 10:]
- queue_suffix = '%s-%s' % (queue_name_suffix, uuid_suffix)
+ uuid_suffix = u_hex[len(u_hex) - 10 :]
+ queue_suffix = "%s-%s" % (queue_name_suffix, uuid_suffix)
- queue_name = '%s.%s' % (queue_name_base, queue_suffix)
+ queue_name = "%s.%s" % (queue_name_base, queue_suffix)
return queue_name
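
An illustrative sketch of the resulting names (the random suffix is 10 hex characters):

    get_queue_name("st2.preinit", "watch")
    # -> "st2.preinit.watch" without the random suffix, or e.g.
    # -> "st2.preinit.watch-6ff1905398" when add_random_uuid_to_suffix is set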
diff --git a/st2common/st2common/util/reference.py b/st2common/st2common/util/reference.py
index 3262eb603f..137a014d73 100644
--- a/st2common/st2common/util/reference.py
+++ b/st2common/st2common/util/reference.py
@@ -20,24 +20,25 @@
def get_ref_from_model(model):
if model is None:
- raise ValueError('Model has None value.')
- model_id = getattr(model, 'id', None)
+ raise ValueError("Model has None value.")
+ model_id = getattr(model, "id", None)
if model_id is None:
- raise db.StackStormDBObjectMalformedError('model %s must contain id.' % str(model))
- reference = {'id': str(model_id),
- 'name': getattr(model, 'name', None)}
+ raise db.StackStormDBObjectMalformedError(
+ "model %s must contain id." % str(model)
+ )
+ reference = {"id": str(model_id), "name": getattr(model, "name", None)}
return reference
def get_model_from_ref(db_api, reference):
if reference is None:
- raise db.StackStormDBObjectNotFoundError('No reference supplied.')
- model_id = reference.get('id', None)
+ raise db.StackStormDBObjectNotFoundError("No reference supplied.")
+ model_id = reference.get("id", None)
if model_id is not None:
return db_api.get_by_id(model_id)
- model_name = reference.get('name', None)
+ model_name = reference.get("name", None)
if model_name is None:
- raise db.StackStormDBObjectNotFoundError('Both name and id are None.')
+ raise db.StackStormDBObjectNotFoundError("Both name and id are None.")
return db_api.get_by_name(model_name)
@@ -71,8 +72,10 @@ def get_resource_ref_from_model(model):
name = model.name
pack = model.pack
except AttributeError:
- raise Exception('Cannot build ResourceReference for model: %s. Name or pack missing.'
- % model)
+ raise Exception(
+ "Cannot build ResourceReference for model: %s. Name or pack missing."
+ % model
+ )
return ResourceReference(name=name, pack=pack)
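
A minimal sketch of get_ref_from_model (the model below is a stand-in for a real DB document):

    class FakeModel(object):
        id = "5f1ea2d2c4b8"  # hypothetical ObjectId
        name = "my_rule"

    get_ref_from_model(FakeModel())
    # -> {"id": "5f1ea2d2c4b8", "name": "my_rule"}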
diff --git a/st2common/st2common/util/sandboxing.py b/st2common/st2common/util/sandboxing.py
index 02871f7472..9801f7d112 100644
--- a/st2common/st2common/util/sandboxing.py
+++ b/st2common/st2common/util/sandboxing.py
@@ -31,11 +31,11 @@
from st2common.content.utils import get_pack_base_path
__all__ = [
- 'get_sandbox_python_binary_path',
- 'get_sandbox_python_path',
- 'get_sandbox_python_path_for_python_action',
- 'get_sandbox_path',
- 'get_sandbox_virtualenv_path',
+ "get_sandbox_python_binary_path",
+ "get_sandbox_python_path",
+ "get_sandbox_python_path_for_python_action",
+ "get_sandbox_path",
+ "get_sandbox_virtualenv_path",
]
@@ -47,13 +47,13 @@ def get_sandbox_python_binary_path(pack=None):
:type pack: ``str``
"""
system_base_path = cfg.CONF.system.base_path
- virtualenv_path = os.path.join(system_base_path, 'virtualenvs', pack)
+ virtualenv_path = os.path.join(system_base_path, "virtualenvs", pack)
if pack in SYSTEM_PACK_NAMES:
# Use system python for "packs" and "core" actions
python_path = sys.executable
else:
- python_path = os.path.join(virtualenv_path, 'bin/python')
+ python_path = os.path.join(virtualenv_path, "bin/python")
return python_path
@@ -70,19 +70,19 @@ def get_sandbox_path(virtualenv_path):
"""
sandbox_path = []
- parent_path = os.environ.get('PATH', '')
+ parent_path = os.environ.get("PATH", "")
if not virtualenv_path:
return parent_path
- parent_path = parent_path.split(':')
+ parent_path = parent_path.split(":")
parent_path = [path for path in parent_path if path]
# Add virtualenv bin directory
- virtualenv_bin_path = os.path.join(virtualenv_path, 'bin/')
+ virtualenv_bin_path = os.path.join(virtualenv_path, "bin/")
sandbox_path.append(virtualenv_bin_path)
sandbox_path.extend(parent_path)
- sandbox_path = ':'.join(sandbox_path)
+ sandbox_path = ":".join(sandbox_path)
return sandbox_path
@@ -104,9 +104,9 @@ def get_sandbox_python_path(inherit_from_parent=True, inherit_parent_virtualenv=
:type inherit_parent_virtualenv: ``bool``
"""
sandbox_python_path = []
- parent_python_path = os.environ.get('PYTHONPATH', '')
+ parent_python_path = os.environ.get("PYTHONPATH", "")
- parent_python_path = parent_python_path.split(':')
+ parent_python_path = parent_python_path.split(":")
parent_python_path = [path for path in parent_python_path if path]
if inherit_from_parent:
@@ -121,13 +121,14 @@ def get_sandbox_python_path(inherit_from_parent=True, inherit_parent_virtualenv=
sandbox_python_path.append(site_packages_dir)
- sandbox_python_path = ':'.join(sandbox_python_path)
- sandbox_python_path = ':' + sandbox_python_path
+ sandbox_python_path = ":".join(sandbox_python_path)
+ sandbox_python_path = ":" + sandbox_python_path
return sandbox_python_path
-def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True,
- inherit_parent_virtualenv=True):
+def get_sandbox_python_path_for_python_action(
+ pack, inherit_from_parent=True, inherit_parent_virtualenv=True
+):
"""
Return sandbox PYTHONPATH for a particular Python runner action.
@@ -136,30 +137,36 @@ def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True,
"""
sandbox_python_path = get_sandbox_python_path(
inherit_from_parent=inherit_from_parent,
- inherit_parent_virtualenv=inherit_parent_virtualenv)
+ inherit_parent_virtualenv=inherit_parent_virtualenv,
+ )
pack_base_path = get_pack_base_path(pack_name=pack)
virtualenv_path = get_sandbox_virtualenv_path(pack=pack)
if virtualenv_path and os.path.isdir(virtualenv_path):
- pack_virtualenv_lib_path = os.path.join(virtualenv_path, 'lib')
+ pack_virtualenv_lib_path = os.path.join(virtualenv_path, "lib")
virtualenv_directories = os.listdir(pack_virtualenv_lib_path)
- virtualenv_directories = [dir_name for dir_name in virtualenv_directories if
- fnmatch.fnmatch(dir_name, 'python*')]
+ virtualenv_directories = [
+ dir_name
+ for dir_name in virtualenv_directories
+ if fnmatch.fnmatch(dir_name, "python*")
+ ]
# Add the pack's lib directory (lib/python3.x) in front of the PYTHONPATH.
- pack_actions_lib_paths = os.path.join(pack_base_path, 'actions', 'lib')
- pack_virtualenv_lib_path = os.path.join(virtualenv_path, 'lib')
- pack_venv_lib_directory = os.path.join(pack_virtualenv_lib_path, virtualenv_directories[0])
+ pack_actions_lib_paths = os.path.join(pack_base_path, "actions", "lib")
+ pack_virtualenv_lib_path = os.path.join(virtualenv_path, "lib")
+ pack_venv_lib_directory = os.path.join(
+ pack_virtualenv_lib_path, virtualenv_directories[0]
+ )
# Add the pack's site-packages directory (lib/python3.x/site-packages)
# in front of the Python system site-packages. This is important because
# we want Python 3 compatible libraries to be used from the pack virtual
# environment and not system ones.
- pack_venv_site_packages_directory = os.path.join(pack_virtualenv_lib_path,
- virtualenv_directories[0],
- 'site-packages')
+ pack_venv_site_packages_directory = os.path.join(
+ pack_virtualenv_lib_path, virtualenv_directories[0], "site-packages"
+ )
full_sandbox_python_path = [
# NOTE: Order here is very important for imports to function correctly
@@ -169,7 +176,7 @@ def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True,
sandbox_python_path,
]
- sandbox_python_path = ':'.join(full_sandbox_python_path)
+ sandbox_python_path = ":".join(full_sandbox_python_path)
return sandbox_python_path
@@ -183,7 +190,7 @@ def get_sandbox_virtualenv_path(pack):
virtualenv_path = None
else:
system_base_path = cfg.CONF.system.base_path
- virtualenv_path = os.path.join(system_base_path, 'virtualenvs', pack)
+ virtualenv_path = os.path.join(system_base_path, "virtualenvs", pack)
return virtualenv_path
@@ -195,8 +202,9 @@ def is_in_virtualenv():
"""
# sys.real_prefix is for virtualenv
# sys.base_prefix != sys.prefix is for venv
- return (hasattr(sys, 'real_prefix') or
- (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix))
+ return hasattr(sys, "real_prefix") or (
+ hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix
+ )
def get_virtualenv_prefix():
@@ -205,10 +213,10 @@ def get_virtualenv_prefix():
where we retrieved the virtualenv prefix from. The second element is
the virtualenv prefix.
"""
- if hasattr(sys, 'real_prefix'):
- return ('sys.real_prefix', sys.real_prefix)
- elif hasattr(sys, 'base_prefix'):
- return ('sys.base_prefix', sys.base_prefix)
+ if hasattr(sys, "real_prefix"):
+ return ("sys.real_prefix", sys.real_prefix)
+ elif hasattr(sys, "base_prefix"):
+ return ("sys.base_prefix", sys.base_prefix)
return (None, None)
@@ -216,9 +224,9 @@ def set_virtualenv_prefix(prefix_tuple):
"""
:return: Sets the virtualenv prefix given a tuple returned from get_virtualenv_prefix()
"""
- if prefix_tuple[0] == 'sys.real_prefix' and hasattr(sys, 'real_prefix'):
+ if prefix_tuple[0] == "sys.real_prefix" and hasattr(sys, "real_prefix"):
sys.real_prefix = prefix_tuple[1]
- elif prefix_tuple[0] == 'sys.base_prefix' and hasattr(sys, 'base_prefix'):
+ elif prefix_tuple[0] == "sys.base_prefix" and hasattr(sys, "base_prefix"):
sys.base_prefix = prefix_tuple[1]
@@ -226,7 +234,7 @@ def clear_virtualenv_prefix():
"""
:return: Unsets / removes / resets the virtualenv prefix
"""
- if hasattr(sys, 'real_prefix'):
+ if hasattr(sys, "real_prefix"):
del sys.real_prefix
- if hasattr(sys, 'base_prefix'):
+ if hasattr(sys, "base_prefix"):
sys.base_prefix = sys.prefix
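
A hedged sketch of get_sandbox_path, assuming PATH="/usr/local/bin:/usr/bin":

    get_sandbox_path("/opt/stackstorm/virtualenvs/my_pack")
    # -> "/opt/stackstorm/virtualenvs/my_pack/bin/:/usr/local/bin:/usr/bin"
    # The pack virtualenv's bin/ directory is simply prepended to the parent PATH.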
diff --git a/st2common/st2common/util/schema/__init__.py b/st2common/st2common/util/schema/__init__.py
index a49f733a5a..8e18509cd1 100644
--- a/st2common/st2common/util/schema/__init__.py
+++ b/st2common/st2common/util/schema/__init__.py
@@ -27,19 +27,19 @@
from st2common.util.misc import deep_update
__all__ = [
- 'get_validator',
- 'get_draft_schema',
- 'get_action_parameters_schema',
- 'get_schema_for_action_parameters',
- 'get_schema_for_resource_parameters',
- 'is_property_type_single',
- 'is_property_type_list',
- 'is_property_type_anyof',
- 'is_property_type_oneof',
- 'is_property_nullable',
- 'is_attribute_type_array',
- 'is_attribute_type_object',
- 'validate'
+ "get_validator",
+ "get_draft_schema",
+ "get_action_parameters_schema",
+ "get_schema_for_action_parameters",
+ "get_schema_for_resource_parameters",
+ "is_property_type_single",
+ "is_property_type_list",
+ "is_property_type_anyof",
+ "is_property_type_oneof",
+ "is_property_nullable",
+ "is_attribute_type_array",
+ "is_attribute_type_object",
+ "validate",
]
# https://github.com/json-schema/json-schema/blob/master/draft-04/schema
@@ -49,12 +49,13 @@
# and draft 3 version of required.
PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)))
SCHEMAS = {
- 'draft4': jsonify.load_file(os.path.join(PATH, 'draft4.json')),
- 'custom': jsonify.load_file(os.path.join(PATH, 'custom.json')),
-
+ "draft4": jsonify.load_file(os.path.join(PATH, "draft4.json")),
+ "custom": jsonify.load_file(os.path.join(PATH, "custom.json")),
# Custom schema for action params which doesn't allow the parameter "type" attribute to be an array
- 'action_params': jsonify.load_file(os.path.join(PATH, 'action_params.json')),
- 'action_output_schema': jsonify.load_file(os.path.join(PATH, 'action_output_schema.json'))
+ "action_params": jsonify.load_file(os.path.join(PATH, "action_params.json")),
+ "action_output_schema": jsonify.load_file(
+ os.path.join(PATH, "action_output_schema.json")
+ ),
}
SCHEMA_ANY_TYPE = {
@@ -64,23 +65,23 @@
{"type": "integer"},
{"type": "number"},
{"type": "object"},
- {"type": "string"}
+ {"type": "string"},
]
}
RUNNER_PARAM_OVERRIDABLE_ATTRS = [
- 'default',
- 'description',
- 'enum',
- 'immutable',
- 'required'
+ "default",
+ "description",
+ "enum",
+ "immutable",
+ "required",
]
-def get_draft_schema(version='custom', additional_properties=False):
+def get_draft_schema(version="custom", additional_properties=False):
schema = copy.deepcopy(SCHEMAS[version])
- if additional_properties and 'additionalProperties' in schema:
- del schema['additionalProperties']
+ if additional_properties and "additionalProperties" in schema:
+ del schema["additionalProperties"]
return schema
@@ -89,8 +90,7 @@ def get_action_output_schema(additional_properties=True):
Return a generic schema which is used for validating action output.
"""
return get_draft_schema(
- version='action_output_schema',
- additional_properties=additional_properties
+ version="action_output_schema", additional_properties=additional_properties
)
@@ -98,81 +98,100 @@ def get_action_parameters_schema(additional_properties=False):
"""
Return a generic schema which is used for validating action parameters definition.
"""
- return get_draft_schema(version='action_params', additional_properties=additional_properties)
+ return get_draft_schema(
+ version="action_params", additional_properties=additional_properties
+ )
CustomValidator = create(
- meta_schema=get_draft_schema(version='custom', additional_properties=True),
+ meta_schema=get_draft_schema(version="custom", additional_properties=True),
validators={
- u"$ref": _validators.ref,
- u"additionalItems": _validators.additionalItems,
- u"additionalProperties": _validators.additionalProperties,
- u"allOf": _validators.allOf_draft4,
- u"anyOf": _validators.anyOf_draft4,
- u"dependencies": _validators.dependencies,
- u"enum": _validators.enum,
- u"format": _validators.format,
- u"items": _validators.items,
- u"maxItems": _validators.maxItems,
- u"maxLength": _validators.maxLength,
- u"maxProperties": _validators.maxProperties_draft4,
- u"maximum": _validators.maximum,
- u"minItems": _validators.minItems,
- u"minLength": _validators.minLength,
- u"minProperties": _validators.minProperties_draft4,
- u"minimum": _validators.minimum,
- u"multipleOf": _validators.multipleOf,
- u"not": _validators.not_draft4,
- u"oneOf": _validators.oneOf_draft4,
- u"pattern": _validators.pattern,
- u"patternProperties": _validators.patternProperties,
- u"properties": _validators.properties_draft3,
- u"type": _validators.type_draft4,
- u"uniqueItems": _validators.uniqueItems,
+ "$ref": _validators.ref,
+ "additionalItems": _validators.additionalItems,
+ "additionalProperties": _validators.additionalProperties,
+ "allOf": _validators.allOf_draft4,
+ "anyOf": _validators.anyOf_draft4,
+ "dependencies": _validators.dependencies,
+ "enum": _validators.enum,
+ "format": _validators.format,
+ "items": _validators.items,
+ "maxItems": _validators.maxItems,
+ "maxLength": _validators.maxLength,
+ "maxProperties": _validators.maxProperties_draft4,
+ "maximum": _validators.maximum,
+ "minItems": _validators.minItems,
+ "minLength": _validators.minLength,
+ "minProperties": _validators.minProperties_draft4,
+ "minimum": _validators.minimum,
+ "multipleOf": _validators.multipleOf,
+ "not": _validators.not_draft4,
+ "oneOf": _validators.oneOf_draft4,
+ "pattern": _validators.pattern,
+ "patternProperties": _validators.patternProperties,
+ "properties": _validators.properties_draft3,
+ "type": _validators.type_draft4,
+ "uniqueItems": _validators.uniqueItems,
},
version="custom_validator",
)
def is_property_type_single(property_schema):
- return (isinstance(property_schema, dict) and
- 'anyOf' not in list(property_schema.keys()) and
- 'oneOf' not in list(property_schema.keys()) and
- not isinstance(property_schema.get('type', 'string'), list))
+ return (
+ isinstance(property_schema, dict)
+ and "anyOf" not in list(property_schema.keys())
+ and "oneOf" not in list(property_schema.keys())
+ and not isinstance(property_schema.get("type", "string"), list)
+ )
def is_property_type_list(property_schema):
- return (isinstance(property_schema, dict) and
- isinstance(property_schema.get('type', 'string'), list))
+ return isinstance(property_schema, dict) and isinstance(
+ property_schema.get("type", "string"), list
+ )
def is_property_type_anyof(property_schema):
- return isinstance(property_schema, dict) and 'anyOf' in list(property_schema.keys())
+ return isinstance(property_schema, dict) and "anyOf" in list(property_schema.keys())
def is_property_type_oneof(property_schema):
- return isinstance(property_schema, dict) and 'oneOf' in list(property_schema.keys())
+ return isinstance(property_schema, dict) and "oneOf" in list(property_schema.keys())
def is_property_nullable(property_type_schema):
# For anyOf and oneOf, the property_schema is a list of types.
if isinstance(property_type_schema, list):
- return len([t for t in property_type_schema
- if ((isinstance(t, six.string_types) and t == 'null') or
- (isinstance(t, dict) and t.get('type', 'string') == 'null'))]) > 0
-
- return (isinstance(property_type_schema, dict) and
- property_type_schema.get('type', 'string') == 'null')
+ return (
+ len(
+ [
+ t
+ for t in property_type_schema
+ if (
+ (isinstance(t, six.string_types) and t == "null")
+ or (isinstance(t, dict) and t.get("type", "string") == "null")
+ )
+ ]
+ )
+ > 0
+ )
+
+ return (
+ isinstance(property_type_schema, dict)
+ and property_type_schema.get("type", "string") == "null"
+ )
def is_attribute_type_array(attribute_type):
- return (attribute_type == 'array' or
- (isinstance(attribute_type, list) and 'array' in attribute_type))
+ return attribute_type == "array" or (
+ isinstance(attribute_type, list) and "array" in attribute_type
+ )
def is_attribute_type_object(attribute_type):
- return (attribute_type == 'object' or
- (isinstance(attribute_type, list) and 'object' in attribute_type))
+ return attribute_type == "object" or (
+ isinstance(attribute_type, list) and "object" in attribute_type
+ )
def assign_default_values(instance, schema):
@@ -186,11 +205,11 @@ def assign_default_values(instance, schema):
if not instance_is_dict and not instance_is_array:
return instance
- properties = schema.get('properties', {})
+ properties = schema.get("properties", {})
for property_name, property_data in six.iteritems(properties):
- has_default_value = 'default' in property_data
- default_value = property_data.get('default', None)
+ has_default_value = "default" in property_data
+ default_value = property_data.get("default", None)
# Assign default value on the instance so the validation doesn't fail if required is true
# but the value is not provided
@@ -203,29 +222,36 @@ def assign_default_values(instance, schema):
instance[index][property_name] = default_value
# Support for nested properties (array and object)
- attribute_type = property_data.get('type', None)
- schema_items = property_data.get('items', {})
+ attribute_type = property_data.get("type", None)
+ schema_items = property_data.get("items", {})
# Array
- if (is_attribute_type_array(attribute_type) and
- schema_items and schema_items.get('properties', {})):
+ if (
+ is_attribute_type_array(attribute_type)
+ and schema_items
+ and schema_items.get("properties", {})
+ ):
array_instance = instance.get(property_name, None)
- array_schema = schema['properties'][property_name]['items']
+ array_schema = schema["properties"][property_name]["items"]
if array_instance is not None:
# Note: We don't perform subschema assignment if no value is provided
- instance[property_name] = assign_default_values(instance=array_instance,
- schema=array_schema)
+ instance[property_name] = assign_default_values(
+ instance=array_instance, schema=array_schema
+ )
# Object
- if is_attribute_type_object(attribute_type) and property_data.get('properties', {}):
+ if is_attribute_type_object(attribute_type) and property_data.get(
+ "properties", {}
+ ):
object_instance = instance.get(property_name, None)
- object_schema = schema['properties'][property_name]
+ object_schema = schema["properties"][property_name]
if object_instance is not None:
# Note: We don't perform subschema assignment if no value is provided
- instance[property_name] = assign_default_values(instance=object_instance,
- schema=object_schema)
+ instance[property_name] = assign_default_values(
+ instance=object_instance, schema=object_schema
+ )
return instance
@@ -236,51 +262,70 @@ def modify_schema_allow_default_none(schema):
defines a default value of None.
"""
schema = copy.deepcopy(schema)
- properties = schema.get('properties', {})
+ properties = schema.get("properties", {})
for property_name, property_data in six.iteritems(properties):
- is_optional = not property_data.get('required', False)
- has_default_value = 'default' in property_data
- default_value = property_data.get('default', None)
- property_schema = schema['properties'][property_name]
+ is_optional = not property_data.get("required", False)
+ has_default_value = "default" in property_data
+ default_value = property_data.get("default", None)
+ property_schema = schema["properties"][property_name]
if (has_default_value or is_optional) and default_value is None:
# If the property is anyOf or oneOf then it has to be processed differently.
- if (is_property_type_anyof(property_schema) and
- not is_property_nullable(property_schema['anyOf'])):
- property_schema['anyOf'].append({'type': 'null'})
- elif (is_property_type_oneof(property_schema) and
- not is_property_nullable(property_schema['oneOf'])):
- property_schema['oneOf'].append({'type': 'null'})
- elif (is_property_type_list(property_schema) and
- not is_property_nullable(property_schema.get('type'))):
- property_schema['type'].append('null')
- elif (is_property_type_single(property_schema) and
- not is_property_nullable(property_schema.get('type'))):
- property_schema['type'] = [property_schema.get('type', 'string'), 'null']
+ if is_property_type_anyof(property_schema) and not is_property_nullable(
+ property_schema["anyOf"]
+ ):
+ property_schema["anyOf"].append({"type": "null"})
+ elif is_property_type_oneof(property_schema) and not is_property_nullable(
+ property_schema["oneOf"]
+ ):
+ property_schema["oneOf"].append({"type": "null"})
+ elif is_property_type_list(property_schema) and not is_property_nullable(
+ property_schema.get("type")
+ ):
+ property_schema["type"].append("null")
+ elif is_property_type_single(property_schema) and not is_property_nullable(
+ property_schema.get("type")
+ ):
+ property_schema["type"] = [
+ property_schema.get("type", "string"),
+ "null",
+ ]
# Support for nested properties (array and object)
- attribute_type = property_data.get('type', None)
- schema_items = property_data.get('items', {})
+ attribute_type = property_data.get("type", None)
+ schema_items = property_data.get("items", {})
# Array
- if (is_attribute_type_array(attribute_type) and
- schema_items and schema_items.get('properties', {})):
+ if (
+ is_attribute_type_array(attribute_type)
+ and schema_items
+ and schema_items.get("properties", {})
+ ):
array_schema = schema_items
array_schema = modify_schema_allow_default_none(schema=array_schema)
- schema['properties'][property_name]['items'] = array_schema
+ schema["properties"][property_name]["items"] = array_schema
# Object
- if is_attribute_type_object(attribute_type) and property_data.get('properties', {}):
+ if is_attribute_type_object(attribute_type) and property_data.get(
+ "properties", {}
+ ):
object_schema = property_data
object_schema = modify_schema_allow_default_none(schema=object_schema)
- schema['properties'][property_name] = object_schema
+ schema["properties"][property_name] = object_schema
return schema
-def validate(instance, schema, cls=None, use_default=True, allow_default_none=False, *args,
- **kwargs):
+def validate(
+ instance,
+ schema,
+ cls=None,
+ use_default=True,
+ allow_default_none=False,
+ *args,
+ **kwargs,
+):
"""
Custom validate function which supports default arguments combined with the "required"
property.
@@ -292,13 +337,13 @@ def validate(instance, schema, cls=None, use_default=True, allow_default_none=Fa
"""
instance = copy.deepcopy(instance)
- schema_type = schema.get('type', None)
+ schema_type = schema.get("type", None)
instance_is_dict = isinstance(instance, dict)
if use_default and allow_default_none:
schema = modify_schema_allow_default_none(schema=schema)
- if use_default and schema_type == 'object' and instance_is_dict:
+ if use_default and schema_type == "object" and instance_is_dict:
instance = assign_default_values(instance=instance, schema=schema)
# pylint: disable=assignment-from-no-return
@@ -307,28 +352,30 @@ def validate(instance, schema, cls=None, use_default=True, allow_default_none=Fa
return instance
-VALIDATORS = {
- 'draft4': jsonschema.Draft4Validator,
- 'custom': CustomValidator
-}
+VALIDATORS = {"draft4": jsonschema.Draft4Validator, "custom": CustomValidator}
-def get_validator(version='custom'):
+def get_validator(version="custom"):
validator = VALIDATORS[version]
return validator
-def validate_runner_parameter_attribute_override(action_ref, param_name, attr_name,
- runner_param_attr_value, action_param_attr_value):
+def validate_runner_parameter_attribute_override(
+ action_ref, param_name, attr_name, runner_param_attr_value, action_param_attr_value
+):
"""
Validate that the provided parameter from the action schema can override the
runner parameter.
"""
param_values_are_the_same = action_param_attr_value == runner_param_attr_value
- if (attr_name not in RUNNER_PARAM_OVERRIDABLE_ATTRS and not param_values_are_the_same):
+ if (
+ attr_name not in RUNNER_PARAM_OVERRIDABLE_ATTRS
+ and not param_values_are_the_same
+ ):
raise InvalidActionParameterException(
'The attribute "%s" for the runner parameter "%s" in action "%s" '
- 'cannot be overridden.' % (attr_name, param_name, action_ref))
+ "cannot be overridden." % (attr_name, param_name, action_ref)
+ )
return True
@@ -341,7 +388,8 @@ def get_schema_for_action_parameters(action_db, runnertype_db=None):
"""
if not runnertype_db:
from st2common.util.action_db import get_runnertype_by_name
- runnertype_db = get_runnertype_by_name(action_db.runner_type['name'])
+
+ runnertype_db = get_runnertype_by_name(action_db.runner_type["name"])
# Note: We need to perform a deep merge because a user can only specify a single parameter
# attribute when overriding it in the action metadata.
@@ -359,26 +407,31 @@ def get_schema_for_action_parameters(action_db, runnertype_db=None):
for attribute, value in six.iteritems(schema):
runner_param_value = runnertype_db.runner_parameters[name].get(attribute)
- validate_runner_parameter_attribute_override(action_ref=action_db.ref,
- param_name=name,
- attr_name=attribute,
- runner_param_attr_value=runner_param_value,
- action_param_attr_value=value)
+ validate_runner_parameter_attribute_override(
+ action_ref=action_db.ref,
+ param_name=name,
+ attr_name=attribute,
+ runner_param_attr_value=runner_param_value,
+ action_param_attr_value=value,
+ )
schema = get_schema_for_resource_parameters(parameters_schema=parameters_schema)
if parameters_schema:
- schema['title'] = action_db.name
+ schema["title"] = action_db.name
if action_db.description:
- schema['description'] = action_db.description
+ schema["description"] = action_db.description
return schema
-def get_schema_for_resource_parameters(parameters_schema, allow_additional_properties=False):
+def get_schema_for_resource_parameters(
+ parameters_schema, allow_additional_properties=False
+):
"""
Dynamically construct JSON schema for the provided resource from the parameters metadata.
"""
+
def normalize(x):
return {k: v if v else SCHEMA_ANY_TYPE for k, v in six.iteritems(x)}
@@ -386,8 +439,8 @@ def normalize(x):
properties = {}
properties.update(normalize(parameters_schema))
if properties:
- schema['type'] = 'object'
- schema['properties'] = properties
- schema['additionalProperties'] = allow_additional_properties
+ schema["type"] = "object"
+ schema["properties"] = properties
+ schema["additionalProperties"] = allow_additional_properties
return schema
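
To illustrate the dynamic schema construction (the parameter metadata below is made up):

    from st2common.util import schema as util_schema

    parameters = {
        "hosts": {"type": "string", "required": True},
        "port": {"type": "integer", "default": 22},
    }

    schema = util_schema.get_schema_for_resource_parameters(
        parameters_schema=parameters
    )
    # Expected shape: a JSON schema with type="object", the normalized
    # parameters under "properties", and additionalProperties=False.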
diff --git a/st2common/st2common/util/secrets.py b/st2common/st2common/util/secrets.py
index 2945ef0594..b863a93a61 100644
--- a/st2common/st2common/util/secrets.py
+++ b/st2common/st2common/util/secrets.py
@@ -65,7 +65,7 @@ def get_secret_parameters(parameters):
"""
secret_parameters = {}
- parameters_type = parameters.get('type')
+ parameters_type = parameters.get("type")
# If the parameter itself is secret, then skip all processing below it
# and return the type of this parameter.
#
@@ -74,22 +74,22 @@ def get_secret_parameters(parameters):
# **Important**: we do this check first so that if this parameter
# is an `object` or `array` and the user wants the full thing
# to be secret, it is marked as secret.
- if parameters.get('secret', False):
+ if parameters.get("secret", False):
return parameters_type
iterator = None
- if parameters_type == 'object':
+ if parameters_type == "object":
# if this is an object, then iterate over the properties within
# the object
# result = dict
- iterator = six.iteritems(parameters.get('properties', {}))
- elif parameters_type == 'array':
+ iterator = six.iteritems(parameters.get("properties", {}))
+ elif parameters_type == "array":
# if this is an array, then iterate over the items definition as a single
# property
# result = list
- iterator = enumerate([parameters.get('items', {})])
+ iterator = enumerate([parameters.get("items", {})])
secret_parameters = []
- elif parameters_type in ['integer', 'number', 'boolean', 'null', 'string']:
+ elif parameters_type in ["integer", "number", "boolean", "null", "string"]:
# if this is a "plain old datatype", then iterate over the properties set
# of the data type
# result = string (property type)
@@ -105,8 +105,8 @@ def get_secret_parameters(parameters):
if not isinstance(options, dict):
continue
- parameter_type = options.get('type')
- if options.get('secret', False):
+ parameter_type = options.get("type")
+ if options.get("secret", False):
# If this parameter is secret, then add it to our secret parameters
#
# **This causes the _full_ object / array tree to be secret
@@ -121,7 +121,7 @@ def get_secret_parameters(parameters):
secret_parameters[parameter] = parameter_type
else:
return parameter_type
- elif parameter_type in ['object', 'array']:
+ elif parameter_type in ["object", "array"]:
# otherwise recursively dive into the `object`/`array` and
# find individual parameters marked as secret
sub_params = get_secret_parameters(options)
@@ -176,15 +176,17 @@ def mask_secret_parameters(parameters, secret_parameters, result=None):
for secret_param, secret_sub_params in iterator:
if is_dict:
if secret_param in result:
- result[secret_param] = mask_secret_parameters(parameters[secret_param],
- secret_sub_params,
- result=result[secret_param])
+ result[secret_param] = mask_secret_parameters(
+ parameters[secret_param],
+ secret_sub_params,
+ result=result[secret_param],
+ )
elif is_list:
# we're assuming lists contain the same data type for every element
for idx, value in enumerate(result):
- result[idx] = mask_secret_parameters(parameters[idx],
- secret_sub_params,
- result=result[idx])
+ result[idx] = mask_secret_parameters(
+ parameters[idx], secret_sub_params, result=result[idx]
+ )
else:
result[secret_param] = MASKED_ATTRIBUTE_VALUE
@@ -204,8 +206,8 @@ def mask_inquiry_response(response, schema):
"""
result = fast_deepcopy(response)
- for prop_name, prop_attrs in schema['properties'].items():
- if prop_attrs.get('secret') is True:
+ for prop_name, prop_attrs in schema["properties"].items():
+ if prop_attrs.get("secret") is True:
if prop_name in response:
result[prop_name] = MASKED_ATTRIBUTE_VALUE
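
A usage sketch tying the two helpers together (schema and values are illustrative; the mask placeholder is whatever MASKED_ATTRIBUTE_VALUE is defined as):

    from st2common.util.secrets import get_secret_parameters
    from st2common.util.secrets import mask_secret_parameters

    parameters = {
        "type": "object",
        "properties": {
            "user": {"type": "string"},
            "password": {"type": "string", "secret": True},
        },
    }

    secret_params = get_secret_parameters(parameters)
    # Expected: {"password": "string"}

    masked = mask_secret_parameters(
        parameters={"user": "admin", "password": "hunter2"},
        secret_parameters=secret_params,
    )
    # Expected: {"user": "admin", "password": MASKED_ATTRIBUTE_VALUE}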
diff --git a/st2common/st2common/util/service.py b/st2common/st2common/util/service.py
index 6691e50268..e3c2dcb9f9 100644
--- a/st2common/st2common/util/service.py
+++ b/st2common/st2common/util/service.py
@@ -24,13 +24,13 @@
def retry_on_exceptions(exc):
- LOG.warning('Evaluating retry on exception %s. %s', type(exc), str(exc))
+ LOG.warning("Evaluating retry on exception %s. %s", type(exc), str(exc))
is_mongo_connection_error = isinstance(exc, pymongo.errors.ConnectionFailure)
retrying = is_mongo_connection_error
if retrying:
- LOG.warning('Retrying on exception %s.', type(exc))
+ LOG.warning("Retrying on exception %s.", type(exc))
return retrying
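
This predicate is shaped for a retry decorator's exception hook; a sketch using the `retrying` package (an assumption about how callers wire it up, not shown in this diff):

    import retrying

    from st2common.util.service import retry_on_exceptions

    @retrying.retry(
        retry_on_exception=retry_on_exceptions, stop_max_attempt_number=3
    )
    def connect_to_db():
        # Any operation which may raise pymongo.errors.ConnectionFailure;
        # other exception types propagate immediately.
        pass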
diff --git a/st2common/st2common/util/shell.py b/st2common/st2common/util/shell.py
index 5c4217594a..945ec39a5a 100644
--- a/st2common/st2common/util/shell.py
+++ b/st2common/st2common/util/shell.py
@@ -30,13 +30,7 @@
# subprocess functionality and run_command
subprocess = concurrency.get_subprocess_module()
-__all__ = [
- 'run_command',
- 'kill_process',
-
- 'quote_unix',
- 'quote_windows'
-]
+__all__ = ["run_command", "kill_process", "quote_unix", "quote_windows"]
LOG = logging.getLogger(__name__)
@@ -45,8 +39,15 @@
# pylint: disable=too-many-function-args
-def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False,
- cwd=None, env=None):
+def run_command(
+ cmd,
+ stdin=None,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=False,
+ cwd=None,
+ env=None,
+):
"""
Run the provided command in a subprocess and wait until it completes.
@@ -79,8 +80,15 @@ def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
if not env:
env = os.environ.copy()
- process = concurrency.subprocess_popen(args=cmd, stdin=stdin, stdout=stdout, stderr=stderr,
- env=env, cwd=cwd, shell=shell)
+ process = concurrency.subprocess_popen(
+ args=cmd,
+ stdin=stdin,
+ stdout=stdout,
+ stderr=stderr,
+ env=env,
+ cwd=cwd,
+ shell=shell,
+ )
stdout, stderr = process.communicate()
exit_code = process.returncode
@@ -100,15 +108,17 @@ def kill_process(process):
:param process: Process object as returned by subprocess.Popen.
:type process: ``object``
"""
- kill_command = shlex.split('sudo pkill -TERM -s %s' % (process.pid))
+ kill_command = shlex.split("sudo pkill -TERM -s %s" % (process.pid))
try:
if six.PY3:
- status = subprocess.call(kill_command, timeout=100) # pylint: disable=not-callable
+ status = subprocess.call(
+ kill_command, timeout=100
+ ) # pylint: disable=not-callable
else:
status = subprocess.call(kill_command) # pylint: disable=not-callable
except Exception:
- LOG.exception('Unable to pkill process.')
+ LOG.exception("Unable to pkill process.")
return status
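
The (exit_code, stdout, stderr) return shape used throughout this patch makes the helper easy to call directly; a minimal sketch:

    from st2common.util.shell import run_command

    # Runs the command in a subprocess and blocks until it completes.
    exit_code, stdout, stderr = run_command(cmd=["echo", "hello"])

    if exit_code != 0:
        raise Exception("Command failed: %s" % (stderr))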
@@ -151,11 +161,12 @@ def on_parent_exit(signame):
Based on https://gist.github.com/evansd/2346614
"""
+
def noop():
pass
try:
- libc = cdll['libc.so.6']
+ libc = cdll["libc.so.6"]
except OSError:
# libc can't be found (e.g. running on non-Unix system), we can't ensure signal will be
# triggered
@@ -173,5 +184,6 @@ def set_parent_exit_signal():
# http://linux.die.net/man/2/prctl
result = prctl(PR_SET_PDEATHSIG, signum)
if result != 0:
- raise Exception('prctl failed with error code %s' % result)
+ raise Exception("prctl failed with error code %s" % result)
+
return set_parent_exit_signal
diff --git a/st2common/st2common/util/spec_loader.py b/st2common/st2common/util/spec_loader.py
index 8ab926330f..07889fa2d2 100644
--- a/st2common/st2common/util/spec_loader.py
+++ b/st2common/st2common/util/spec_loader.py
@@ -33,16 +33,13 @@
from st2common.rbac.types import PermissionType
from st2common.util import isotime
-__all__ = [
- 'load_spec',
- 'generate_spec'
-]
+__all__ = ["load_spec", "generate_spec"]
ARGUMENTS = {
- 'DEFAULT_PACK_NAME': st2common.constants.pack.DEFAULT_PACK_NAME,
- 'LIVEACTION_STATUSES': st2common.constants.action.LIVEACTION_STATUSES,
- 'PERMISSION_TYPE': PermissionType,
- 'ISO8601_UTC_REGEX': isotime.ISO8601_UTC_REGEX
+ "DEFAULT_PACK_NAME": st2common.constants.pack.DEFAULT_PACK_NAME,
+ "LIVEACTION_STATUSES": st2common.constants.action.LIVEACTION_STATUSES,
+ "PERMISSION_TYPE": PermissionType,
+ "ISO8601_UTC_REGEX": isotime.ISO8601_UTC_REGEX,
}
@@ -50,23 +47,35 @@ class UniqueKeyLoader(Loader):
"""
YAML loader which throws on a duplicate key.
"""
+
def construct_mapping(self, node, deep=False):
if not isinstance(node, MappingNode):
- raise ConstructorError(None, None,
- "expected a mapping node, but found %s" % node.id,
- node.start_mark)
+ raise ConstructorError(
+ None,
+ None,
+ "expected a mapping node, but found %s" % node.id,
+ node.start_mark,
+ )
mapping = {}
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep)
try:
hash(key)
except TypeError as exc:
- raise ConstructorError("while constructing a mapping", node.start_mark,
- "found unacceptable key (%s)" % exc, key_node.start_mark)
+ raise ConstructorError(
+ "while constructing a mapping",
+ node.start_mark,
+ "found unacceptable key (%s)" % exc,
+ key_node.start_mark,
+ )
# check for duplicate keys
if key in mapping:
- raise ConstructorError("while constructing a mapping", node.start_mark,
- "found duplicate key", key_node.start_mark)
+ raise ConstructorError(
+ "while constructing a mapping",
+ node.start_mark,
+ "found duplicate key",
+ key_node.start_mark,
+ )
value = self.construct_object(value_node, deep=deep)
mapping[key] = value
return mapping
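
A sketch of the duplicate-key behavior (the YAML snippet is illustrative):

    import yaml

    from st2common.util.spec_loader import UniqueKeyLoader

    DOCUMENT = """
    name: first
    name: second
    """

    try:
        yaml.load(DOCUMENT, Loader=UniqueKeyLoader)
    except yaml.constructor.ConstructorError as e:
        # The loader rejects the duplicate "name" key instead of silently
        # keeping the last value like the default loader does.
        print(e)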
diff --git a/st2common/st2common/util/system_info.py b/st2common/st2common/util/system_info.py
index a83bf5169f..b81d205907 100644
--- a/st2common/st2common/util/system_info.py
+++ b/st2common/st2common/util/system_info.py
@@ -17,22 +17,14 @@
import os
import socket
-__all__ = [
- 'get_host_info',
- 'get_process_info'
-]
+__all__ = ["get_host_info", "get_process_info"]
def get_host_info():
- host_info = {
- 'hostname': socket.gethostname()
- }
+ host_info = {"hostname": socket.gethostname()}
return host_info
def get_process_info():
- process_info = {
- 'hostname': socket.gethostname(),
- 'pid': os.getpid()
- }
+ process_info = {"hostname": socket.gethostname(), "pid": os.getpid()}
return process_info
diff --git a/st2common/st2common/util/templating.py b/st2common/st2common/util/templating.py
index 9dc25d917c..82e8e1c246 100644
--- a/st2common/st2common/util/templating.py
+++ b/st2common/st2common/util/templating.py
@@ -24,9 +24,9 @@
from st2common.services.keyvalues import UserKeyValueLookup
__all__ = [
- 'render_template',
- 'render_template_with_system_context',
- 'render_template_with_system_and_user_context'
+ "render_template",
+ "render_template_with_system_context",
+ "render_template_with_system_and_user_context",
]
@@ -74,7 +74,9 @@ def render_template_with_system_context(value, context=None, prefix=None):
return rendered
-def render_template_with_system_and_user_context(value, user, context=None, prefix=None):
+def render_template_with_system_and_user_context(
+ value, user, context=None, prefix=None
+):
"""
Render provided template with a default system context and user context for the provided user.
@@ -95,7 +97,7 @@ def render_template_with_system_and_user_context(value, user, context=None, pref
context = context or {}
context[DATASTORE_PARENT_SCOPE] = {
SYSTEM_SCOPE: KeyValueLookup(prefix=prefix, scope=FULL_SYSTEM_SCOPE),
- USER_SCOPE: UserKeyValueLookup(prefix=prefix, user=user, scope=FULL_USER_SCOPE)
+ USER_SCOPE: UserKeyValueLookup(prefix=prefix, user=user, scope=FULL_USER_SCOPE),
}
rendered = render_template(value=value, context=context)
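
A usage sketch (the `st2kv.user` lookup syntax is an assumption based on the DATASTORE_PARENT_SCOPE wiring above; the user and key are illustrative):

    from st2common.util.templating import (
        render_template_with_system_and_user_context,
    )

    # Renders a Jinja template with both the system-scoped and the
    # user-scoped datastore available in the context.
    rendered = render_template_with_system_and_user_context(
        value="{{ st2kv.user.api_token }}",
        user="stanley",
    )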
diff --git a/st2common/st2common/util/types.py b/st2common/st2common/util/types.py
index 5c25990a6e..ad70f078b9 100644
--- a/st2common/st2common/util/types.py
+++ b/st2common/st2common/util/types.py
@@ -20,17 +20,14 @@
from __future__ import absolute_import
import collections
-__all__ = [
- 'OrderedSet'
-]
+__all__ = ["OrderedSet"]
class OrderedSet(collections.MutableSet):
-
def __init__(self, iterable=None):
self.end = end = []
- end += [None, end, end] # sentinel node for doubly linked list
- self.map = {} # key --> [key, prev, next]
+ end += [None, end, end] # sentinel node for doubly linked list
+ self.map = {} # key --> [key, prev, next]
if iterable is not None:
self |= iterable
@@ -68,15 +65,15 @@ def __reversed__(self):
def pop(self, last=True):
if not self:
- raise KeyError('set is empty')
+ raise KeyError("set is empty")
key = self.end[1][0] if last else self.end[2][0]
self.discard(key)
return key
def __repr__(self):
if not self:
- return '%s()' % (self.__class__.__name__,)
- return '%s(%r)' % (self.__class__.__name__, list(self))
+ return "%s()" % (self.__class__.__name__,)
+ return "%s(%r)" % (self.__class__.__name__, list(self))
def __eq__(self, other):
if isinstance(other, OrderedSet):
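
Behavioral sketch of the recipe above:

    from st2common.util.types import OrderedSet

    # Insertion order is preserved by the doubly linked list; duplicates
    # are ignored.
    s = OrderedSet(["b", "a", "b", "c"])
    print(list(s))  # ['b', 'a', 'c']

    s.discard("a")
    print(s.pop())  # 'c' - pop(last=True) removes from the end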
diff --git a/st2common/st2common/util/uid.py b/st2common/st2common/util/uid.py
index 07d04d7511..289184d59e 100644
--- a/st2common/st2common/util/uid.py
+++ b/st2common/st2common/util/uid.py
@@ -20,9 +20,7 @@
from __future__ import absolute_import
from st2common.models.db.stormbase import UIDFieldMixin
-__all__ = [
- 'parse_uid'
-]
+__all__ = ["parse_uid"]
def parse_uid(uid):
@@ -33,12 +31,12 @@ def parse_uid(uid):
:rtype: ``tuple``
"""
if UIDFieldMixin.UID_SEPARATOR not in uid:
- raise ValueError('Invalid uid: %s' % (uid))
+ raise ValueError("Invalid uid: %s" % (uid))
parsed = uid.split(UIDFieldMixin.UID_SEPARATOR)
if len(parsed) < 2:
- raise ValueError('Invalid or malformed uid: %s' % (uid))
+ raise ValueError("Invalid or malformed uid: %s" % (uid))
resource_type = parsed[0]
uid_remainder = parsed[1:]
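
A usage sketch (the example UID is illustrative and assumes the ":" separator used by UIDFieldMixin):

    from st2common.util.uid import parse_uid

    resource_type, uid_remainder = parse_uid("action:core:local")
    # resource_type == "action", uid_remainder == ["core", "local"]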
diff --git a/st2common/st2common/util/ujson.py b/st2common/st2common/util/ujson.py
index cace243448..6c533fb30a 100644
--- a/st2common/st2common/util/ujson.py
+++ b/st2common/st2common/util/ujson.py
@@ -19,9 +19,7 @@
import ujson
-__all__ = [
- 'fast_deepcopy'
-]
+__all__ = ["fast_deepcopy"]
def fast_deepcopy(value, fall_back_to_deepcopy=True):
diff --git a/st2common/st2common/util/url.py b/st2common/st2common/util/url.py
index 9c3196f835..b4dd8fc137 100644
--- a/st2common/st2common/util/url.py
+++ b/st2common/st2common/util/url.py
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__all__ = [
- 'get_url_without_trailing_slash'
-]
+__all__ = ["get_url_without_trailing_slash"]
def get_url_without_trailing_slash(value):
@@ -27,5 +25,5 @@ def get_url_without_trailing_slash(value):
:rtype: ``str``
"""
- result = value[:-1] if value.endswith('/') else value
+ result = value[:-1] if value.endswith("/") else value
return result
diff --git a/st2common/st2common/util/versioning.py b/st2common/st2common/util/versioning.py
index 121a93312a..89da24f174 100644
--- a/st2common/st2common/util/versioning.py
+++ b/st2common/st2common/util/versioning.py
@@ -25,12 +25,7 @@
from st2common import __version__ as stackstorm_version
-__all__ = [
- 'get_stackstorm_version',
- 'get_python_version',
-
- 'complex_semver_match'
-]
+__all__ = ["get_stackstorm_version", "get_python_version", "complex_semver_match"]
def get_stackstorm_version():
@@ -38,8 +33,8 @@ def get_stackstorm_version():
Return a valid semver version string for the currently running StackStorm version.
"""
# Special handling for dev versions which are not valid semver identifiers
- if 'dev' in stackstorm_version and stackstorm_version.count('.') == 1:
- version = stackstorm_version.replace('dev', '.0')
+ if "dev" in stackstorm_version and stackstorm_version.count(".") == 1:
+ version = stackstorm_version.replace("dev", ".0")
return version
return stackstorm_version
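
A worked example of the dev-version special case (standalone, mirroring the logic above):

    stackstorm_version = "3.4dev"  # illustrative dev version string

    if "dev" in stackstorm_version and stackstorm_version.count(".") == 1:
        version = stackstorm_version.replace("dev", ".0")
    else:
        version = stackstorm_version

    print(version)  # "3.4.0" - now a valid semver identifier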
@@ -50,7 +45,7 @@ def get_python_version():
Return Python version used by this installation.
"""
version_info = sys.version_info
- return '%s.%s.%s' % (version_info.major, version_info.minor, version_info.micro)
+ return "%s.%s.%s" % (version_info.major, version_info.minor, version_info.micro)
def complex_semver_match(version, version_specifier):
@@ -63,10 +58,10 @@ def complex_semver_match(version, version_specifier):
:rtype: ``bool``
"""
- if version_specifier == 'all':
+ if version_specifier == "all":
return True
- split_version_specifier = version_specifier.split(',')
+ split_version_specifier = version_specifier.split(",")
if len(split_version_specifier) == 1:
# No comma, we can do a simple comparison
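
A usage sketch (the multi-condition semantics, where every comma-separated specifier must match, is inferred from the splitting logic above):

    from st2common.util.versioning import complex_semver_match

    print(complex_semver_match("3.4.0", ">=3.3.0,<3.5.0"))  # True
    print(complex_semver_match("3.4.0", "all"))  # True, special-cased above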
diff --git a/st2common/st2common/util/virtualenvs.py b/st2common/st2common/util/virtualenvs.py
index 52e39974fd..62cfc99b52 100644
--- a/st2common/st2common/util/virtualenvs.py
+++ b/st2common/st2common/util/virtualenvs.py
@@ -36,16 +36,22 @@
from st2common.content.utils import get_packs_base_paths
from st2common.content.utils import get_pack_directory
-__all__ = [
- 'setup_pack_virtualenv'
-]
+__all__ = ["setup_pack_virtualenv"]
LOG = logging.getLogger(__name__)
-def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True,
- include_setuptools=True, include_wheel=True, proxy_config=None,
- no_download=True, force_owner_group=True):
+def setup_pack_virtualenv(
+ pack_name,
+ update=False,
+ logger=None,
+ include_pip=True,
+ include_setuptools=True,
+ include_wheel=True,
+ proxy_config=None,
+ no_download=True,
+ force_owner_group=True,
+):
"""
Setup virtual environment for the provided pack.
@@ -68,7 +74,7 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True
if not re.match(PACK_REF_WHITELIST_REGEX, pack_name):
raise ValueError('Invalid pack name "%s"' % (pack_name))
- base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path, 'virtualenvs/')
+ base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path, "virtualenvs/")
virtualenv_path = os.path.join(base_virtualenvs_path, quote_unix(pack_name))
# Ensure pack directory exists in one of the search paths
@@ -78,7 +84,7 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True
if not pack_path:
packs_base_paths = get_packs_base_paths()
- search_paths = ', '.join(packs_base_paths)
+ search_paths = ", ".join(packs_base_paths)
msg = 'Pack "%s" is not installed. Looked in: %s' % (pack_name, search_paths)
raise Exception(msg)
@@ -88,42 +94,64 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True
remove_virtualenv(virtualenv_path=virtualenv_path, logger=logger)
# 1. Create virtual environment
- logger.debug('Creating virtualenv for pack "%s" in "%s"' % (pack_name, virtualenv_path))
- create_virtualenv(virtualenv_path=virtualenv_path, logger=logger, include_pip=include_pip,
- include_setuptools=include_setuptools, include_wheel=include_wheel,
- no_download=no_download)
+ logger.debug(
+ 'Creating virtualenv for pack "%s" in "%s"' % (pack_name, virtualenv_path)
+ )
+ create_virtualenv(
+ virtualenv_path=virtualenv_path,
+ logger=logger,
+ include_pip=include_pip,
+ include_setuptools=include_setuptools,
+ include_wheel=include_wheel,
+ no_download=no_download,
+ )
# 2. Install base requirements which are common to all the packs
- logger.debug('Installing base requirements')
+ logger.debug("Installing base requirements")
for requirement in BASE_PACK_REQUIREMENTS:
- install_requirement(virtualenv_path=virtualenv_path, requirement=requirement,
- proxy_config=proxy_config, logger=logger)
+ install_requirement(
+ virtualenv_path=virtualenv_path,
+ requirement=requirement,
+ proxy_config=proxy_config,
+ logger=logger,
+ )
# 3. Install pack-specific requirements
- requirements_file_path = os.path.join(pack_path, 'requirements.txt')
+ requirements_file_path = os.path.join(pack_path, "requirements.txt")
has_requirements = os.path.isfile(requirements_file_path)
if has_requirements:
- logger.debug('Installing pack specific requirements from "%s"' %
- (requirements_file_path))
- install_requirements(virtualenv_path=virtualenv_path,
- requirements_file_path=requirements_file_path,
- proxy_config=proxy_config,
- logger=logger)
+ logger.debug(
+ 'Installing pack specific requirements from "%s"' % (requirements_file_path)
+ )
+ install_requirements(
+ virtualenv_path=virtualenv_path,
+ requirements_file_path=requirements_file_path,
+ proxy_config=proxy_config,
+ logger=logger,
+ )
else:
- logger.debug('No pack specific requirements found')
+ logger.debug("No pack specific requirements found")
# 4. Set the owner group
if force_owner_group:
apply_pack_owner_group(pack_path=virtualenv_path)
- action = 'updated' if update else 'created'
- logger.debug('Virtualenv for pack "%s" successfully %s in "%s"' %
- (pack_name, action, virtualenv_path))
-
-
-def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_setuptools=True,
- include_wheel=True, no_download=True):
+ action = "updated" if update else "created"
+ logger.debug(
+ 'Virtualenv for pack "%s" successfully %s in "%s"'
+ % (pack_name, action, virtualenv_path)
+ )
+
+
+def create_virtualenv(
+ virtualenv_path,
+ logger=None,
+ include_pip=True,
+ include_setuptools=True,
+ include_wheel=True,
+ no_download=True,
+):
"""
:param include_pip: Include pip binary and package in the newly created virtual environment.
:type include_pip: ``bool``
@@ -145,7 +173,7 @@ def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_se
python_binary = cfg.CONF.actionrunner.python_binary
virtualenv_binary = cfg.CONF.actionrunner.virtualenv_binary
virtualenv_opts = cfg.CONF.actionrunner.virtualenv_opts or []
- virtualenv_opts += ['--verbose']
+ virtualenv_opts += ["--verbose"]
if not os.path.isfile(python_binary):
raise Exception('Python binary "%s" doesn\'t exist' % (python_binary))
@@ -153,39 +181,44 @@ def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_se
if not os.path.isfile(virtualenv_binary):
raise Exception('Virtualenv binary "%s" doesn\'t exist.' % (virtualenv_binary))
- logger.debug('Creating virtualenv in "%s" using Python binary "%s"' %
- (virtualenv_path, python_binary))
+ logger.debug(
+ 'Creating virtualenv in "%s" using Python binary "%s"'
+ % (virtualenv_path, python_binary)
+ )
cmd = [virtualenv_binary]
- cmd.extend(['-p', python_binary])
+ cmd.extend(["-p", python_binary])
cmd.extend(virtualenv_opts)
if not include_pip:
- cmd.append('--no-pip')
+ cmd.append("--no-pip")
if not include_setuptools:
- cmd.append('--no-setuptools')
+ cmd.append("--no-setuptools")
if not include_wheel:
- cmd.append('--no-wheel')
+ cmd.append("--no-wheel")
if no_download:
- cmd.append('--no-download')
+ cmd.append("--no-download")
cmd.extend([virtualenv_path])
- logger.debug('Running command "%s" to create virtualenv.', ' '.join(cmd))
+ logger.debug('Running command "%s" to create virtualenv.', " ".join(cmd))
try:
exit_code, stdout, stderr = run_command(cmd=cmd)
except OSError as e:
- raise Exception('Error executing command %s. %s.' % (' '.join(cmd),
- six.text_type(e)))
+ raise Exception(
+ "Error executing command %s. %s." % (" ".join(cmd), six.text_type(e))
+ )
if exit_code != 0:
- raise Exception('Failed to create virtualenv in "%s":\n stdout=%s\n stderr=%s' %
- (virtualenv_path, stdout, stderr))
+ raise Exception(
+ 'Failed to create virtualenv in "%s":\n stdout=%s\n stderr=%s'
+ % (virtualenv_path, stdout, stderr)
+ )
return True
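
A sketch of a typical caller (assumes a configured st2 environment; the pack name and proxy values are illustrative, and the proxy_config keys mirror the ones read below):

    from st2common.util.virtualenvs import setup_pack_virtualenv

    setup_pack_virtualenv(
        pack_name="libcloud",
        update=False,
        proxy_config={"https_proxy": "http://proxy.example.com:3128"},
        no_download=True,
    )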
@@ -203,53 +236,62 @@ def remove_virtualenv(virtualenv_path, logger=None):
logger.debug('Removing virtualenv in "%s"' % virtualenv_path)
try:
shutil.rmtree(virtualenv_path)
- logger.debug('Virtualenv successfull removed.')
+ logger.debug("Virtualenv successfully removed.")
except Exception as e:
- logger.error('Error while removing virtualenv at "%s": "%s"' % (virtualenv_path, e))
+ logger.error(
+ 'Error while removing virtualenv at "%s": "%s"' % (virtualenv_path, e)
+ )
raise e
return True
-def install_requirements(virtualenv_path, requirements_file_path, proxy_config=None, logger=None):
+def install_requirements(
+ virtualenv_path, requirements_file_path, proxy_config=None, logger=None
+):
"""
Install requirements from a file.
"""
logger = logger or LOG
- pip_path = os.path.join(virtualenv_path, 'bin/pip')
+ pip_path = os.path.join(virtualenv_path, "bin/pip")
pip_opts = cfg.CONF.actionrunner.pip_opts or []
cmd = [pip_path]
if proxy_config:
- cert = proxy_config.get('proxy_ca_bundle_path', None)
- https_proxy = proxy_config.get('https_proxy', None)
- http_proxy = proxy_config.get('http_proxy', None)
+ cert = proxy_config.get("proxy_ca_bundle_path", None)
+ https_proxy = proxy_config.get("https_proxy", None)
+ http_proxy = proxy_config.get("http_proxy", None)
if http_proxy:
- cmd.extend(['--proxy', http_proxy])
+ cmd.extend(["--proxy", http_proxy])
if https_proxy:
- cmd.extend(['--proxy', https_proxy])
+ cmd.extend(["--proxy", https_proxy])
if cert:
- cmd.extend(['--cert', cert])
+ cmd.extend(["--cert", cert])
- cmd.append('install')
+ cmd.append("install")
cmd.extend(pip_opts)
- cmd.extend(['-U', '-r', requirements_file_path])
+ cmd.extend(["-U", "-r", requirements_file_path])
env = get_env_for_subprocess_command()
- logger.debug('Installing requirements from file %s with command %s.',
- requirements_file_path, ' '.join(cmd))
+ logger.debug(
+ "Installing requirements from file %s with command %s.",
+ requirements_file_path,
+ " ".join(cmd),
+ )
exit_code, stdout, stderr = run_command(cmd=cmd, env=env)
if exit_code != 0:
stdout = to_ascii(stdout)
stderr = to_ascii(stderr)
- raise Exception('Failed to install requirements from "%s": %s (stderr: %s)' %
- (requirements_file_path, stdout, stderr))
+ raise Exception(
+ 'Failed to install requirements from "%s": %s (stderr: %s)'
+ % (requirements_file_path, stdout, stderr)
+ )
return True
@@ -261,35 +303,37 @@ def install_requirement(virtualenv_path, requirement, proxy_config=None, logger=
:param requirement: Requirement specifier.
"""
logger = logger or LOG
- pip_path = os.path.join(virtualenv_path, 'bin/pip')
+ pip_path = os.path.join(virtualenv_path, "bin/pip")
pip_opts = cfg.CONF.actionrunner.pip_opts or []
cmd = [pip_path]
if proxy_config:
- cert = proxy_config.get('proxy_ca_bundle_path', None)
- https_proxy = proxy_config.get('https_proxy', None)
- http_proxy = proxy_config.get('http_proxy', None)
+ cert = proxy_config.get("proxy_ca_bundle_path", None)
+ https_proxy = proxy_config.get("https_proxy", None)
+ http_proxy = proxy_config.get("http_proxy", None)
if http_proxy:
- cmd.extend(['--proxy', http_proxy])
+ cmd.extend(["--proxy", http_proxy])
if https_proxy:
- cmd.extend(['--proxy', https_proxy])
+ cmd.extend(["--proxy", https_proxy])
if cert:
- cmd.extend(['--cert', cert])
+ cmd.extend(["--cert", cert])
- cmd.append('install')
+ cmd.append("install")
cmd.extend(pip_opts)
cmd.extend([requirement])
env = get_env_for_subprocess_command()
- logger.debug('Installing requirement %s with command %s.',
- requirement, ' '.join(cmd))
+ logger.debug(
+ "Installing requirement %s with command %s.", requirement, " ".join(cmd)
+ )
exit_code, stdout, stderr = run_command(cmd=cmd, env=env)
if exit_code != 0:
- raise Exception('Failed to install requirement "%s": %s' %
- (requirement, stdout))
+ raise Exception(
+ 'Failed to install requirement "%s": %s' % (requirement, stdout)
+ )
return True
@@ -303,7 +347,7 @@ def get_env_for_subprocess_command():
"""
env = os.environ.copy()
- if 'PYTHONPATH' in env:
- del env['PYTHONPATH']
+ if "PYTHONPATH" in env:
+ del env["PYTHONPATH"]
return env
diff --git a/st2common/st2common/util/wsgi.py b/st2common/st2common/util/wsgi.py
index a3441e4bda..63ec6c6253 100644
--- a/st2common/st2common/util/wsgi.py
+++ b/st2common/st2common/util/wsgi.py
@@ -24,9 +24,7 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'shutdown_server_kill_pending_requests'
-]
+__all__ = ["shutdown_server_kill_pending_requests"]
def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2):
@@ -46,7 +44,7 @@ def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2):
sock.close()
active_requests = worker_pool.running()
- LOG.info('Shutting down. Requests left: %s', active_requests)
+ LOG.info("Shutting down. Requests left: %s", active_requests)
# Give active requests some time to finish
if active_requests > 0:
@@ -57,5 +55,5 @@ def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2):
for coro in running_corutines:
eventlet.greenthread.kill(coro)
- LOG.info('Exiting...')
+ LOG.info("Exiting...")
raise SystemExit()
diff --git a/st2common/st2common/validators/api/action.py b/st2common/st2common/validators/api/action.py
index 1eb5dbfeb9..973e999fa6 100644
--- a/st2common/st2common/validators/api/action.py
+++ b/st2common/st2common/validators/api/action.py
@@ -26,10 +26,7 @@
from st2common.models.system.common import ResourceReference
from six.moves import range
-__all__ = [
- 'validate_action',
- 'get_runner_model'
-]
+__all__ = ["validate_action", "get_runner_model"]
LOG = logging.getLogger(__name__)
@@ -49,14 +46,17 @@ def validate_action(action_api, runner_type_db=None):
# Check if pack is valid.
if not _is_valid_pack(action_api.pack):
packs_base_paths = get_packs_base_paths()
- packs_base_paths = ','.join(packs_base_paths)
- msg = ('Content pack "%s" is not found or doesn\'t contain actions directory. '
- 'Searched in: %s' %
- (action_api.pack, packs_base_paths))
+ packs_base_paths = ",".join(packs_base_paths)
+ msg = (
+ 'Content pack "%s" is not found or doesn\'t contain actions directory. '
+ "Searched in: %s" % (action_api.pack, packs_base_paths)
+ )
raise ValueValidationException(msg)
# Check if parameters defined are valid.
- action_ref = ResourceReference.to_string_reference(pack=action_api.pack, name=action_api.name)
+ action_ref = ResourceReference.to_string_reference(
+ pack=action_api.pack, name=action_api.name
+ )
_validate_parameters(action_ref, action_api.parameters, runner_db.runner_parameters)
@@ -66,15 +66,18 @@ def get_runner_model(action_api):
try:
runner_db = get_runnertype_by_name(action_api.runner_type)
except StackStormDBObjectNotFoundError:
- msg = ('RunnerType %s is not found. If you are using old and deprecated runner name, you '
- 'need to switch to a new one. For more information, please see '
- 'https://docs.stackstorm.com/upgrade_notes.html#st2-v0-9' % (action_api.runner_type))
+ msg = (
+ "RunnerType %s is not found. If you are using old and deprecated runner name, you "
+ "need to switch to a new one. For more information, please see "
+ "https://docs.stackstorm.com/upgrade_notes.html#st2-v0-9"
+ % (action_api.runner_type)
+ )
raise ValueValidationException(msg)
return runner_db
def _is_valid_pack(pack):
- return check_pack_content_directory_exists(pack=pack, content_type='actions')
+ return check_pack_content_directory_exists(pack=pack, content_type="actions")
def _validate_parameters(action_ref, action_params=None, runner_params=None):
@@ -84,32 +87,44 @@ def _validate_parameters(action_ref, action_params=None, runner_params=None):
if action_param in runner_params:
for action_param_attr, value in six.iteritems(action_param_meta):
util_schema.validate_runner_parameter_attribute_override(
- action_ref, action_param, action_param_attr,
- value, runner_params[action_param].get(action_param_attr))
-
- if 'position' in action_param_meta:
- pos = action_param_meta['position']
+ action_ref,
+ action_param,
+ action_param_attr,
+ value,
+ runner_params[action_param].get(action_param_attr),
+ )
+
+ if "position" in action_param_meta:
+ pos = action_param_meta["position"]
param = position_params.get(pos, None)
if param:
- msg = ('Parameters %s and %s have same position %d.' % (action_param, param, pos) +
- ' Position values have to be unique.')
+ msg = (
+ "Parameters %s and %s have same position %d."
+ % (action_param, param, pos)
+ + " Position values have to be unique."
+ )
raise ValueValidationException(msg)
else:
position_params[pos] = action_param
- if 'immutable' in action_param_meta:
+ if "immutable" in action_param_meta:
if action_param in runner_params:
runner_param_meta = runner_params[action_param]
- if 'immutable' in runner_param_meta:
- msg = 'Param %s is declared immutable in runner. ' % action_param + \
- 'Cannot override in action.'
+ if "immutable" in runner_param_meta:
+ msg = (
+ "Param %s is declared immutable in runner. " % action_param
+ + "Cannot override in action."
+ )
raise ValueValidationException(msg)
- if 'default' not in action_param_meta and 'default' not in runner_param_meta:
- msg = 'Immutable param %s requires a default value.' % action_param
+ if (
+ "default" not in action_param_meta
+ and "default" not in runner_param_meta
+ ):
+ msg = "Immutable param %s requires a default value." % action_param
raise ValueValidationException(msg)
else:
- if 'default' not in action_param_meta:
- msg = 'Immutable param %s requires a default value.' % action_param
+ if "default" not in action_param_meta:
+ msg = "Immutable param %s requires a default value." % action_param
raise ValueValidationException(msg)
return _validate_position_values_contiguous(position_params)
@@ -120,10 +135,10 @@ def _validate_position_values_contiguous(position_params):
return True
positions = sorted(position_params.keys())
- contiguous = (positions == list(range(min(positions), max(positions) + 1)))
+ contiguous = positions == list(range(min(positions), max(positions) + 1))
if not contiguous:
- msg = 'Positions supplied %s for parameters are not contiguous.' % positions
+ msg = "Positions supplied %s for parameters are not contiguous." % positions
raise ValueValidationException(msg)
return True
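
A standalone sketch of the contiguity check above (parameter names are illustrative):

    # Positions [0, 1, 2] pass; a gap such as [0, 2] fails validation.
    position_params = {0: "cmd", 2: "cwd"}

    positions = sorted(position_params.keys())
    contiguous = positions == list(range(min(positions), max(positions) + 1))
    print(contiguous)  # False -> ValueValidationException in the real code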
diff --git a/st2common/st2common/validators/api/misc.py b/st2common/st2common/validators/api/misc.py
index b18ff05d21..215afc5501 100644
--- a/st2common/st2common/validators/api/misc.py
+++ b/st2common/st2common/validators/api/misc.py
@@ -17,9 +17,7 @@
from st2common.constants.pack import SYSTEM_PACK_NAME
from st2common.exceptions.apivalidation import ValueValidationException
-__all__ = [
- 'validate_not_part_of_system_pack'
-]
+__all__ = ["validate_not_part_of_system_pack"]
def validate_not_part_of_system_pack(resource_db):
@@ -32,10 +30,10 @@ def validate_not_part_of_system_pack(resource_db):
:param resource_db: Resource database object to check.
:type resource_db: ``object``
"""
- pack = getattr(resource_db, 'pack', None)
+ pack = getattr(resource_db, "pack", None)
if pack == SYSTEM_PACK_NAME:
- msg = 'Resources belonging to system level packs can\'t be manipulated'
+ msg = "Resources belonging to system level packs can't be manipulated"
raise ValueValidationException(msg)
return resource_db
diff --git a/st2common/st2common/validators/api/reactor.py b/st2common/st2common/validators/api/reactor.py
index eb2cf1c814..0d84a66a99 100644
--- a/st2common/st2common/validators/api/reactor.py
+++ b/st2common/st2common/validators/api/reactor.py
@@ -29,10 +29,9 @@
from st2common.services import triggers
__all__ = [
- 'validate_criteria',
-
- 'validate_trigger_parameters',
- 'validate_trigger_payload'
+ "validate_criteria",
+ "validate_trigger_parameters",
+ "validate_trigger_payload",
]
@@ -43,20 +42,30 @@
def validate_criteria(criteria):
if not isinstance(criteria, dict):
- raise ValueValidationException('Criteria should be a dict.')
+ raise ValueValidationException("Criteria should be a dict.")
for key, value in six.iteritems(criteria):
- operator = value.get('type', None)
+ operator = value.get("type", None)
if operator is None:
- raise ValueValidationException('Operator not specified for field: ' + key)
+ raise ValueValidationException("Operator not specified for field: " + key)
if operator not in allowed_operators:
- raise ValueValidationException('For field: ' + key + ', operator ' + operator +
- ' not in list of allowed operators: ' +
- str(list(allowed_operators.keys())))
- pattern = value.get('pattern', None)
+ raise ValueValidationException(
+ "For field: "
+ + key
+ + ", operator "
+ + operator
+ + " not in list of allowed operators: "
+ + str(list(allowed_operators.keys()))
+ )
+ pattern = value.get("pattern", None)
if pattern is None:
- raise ValueValidationException('For field: ' + key + ', no pattern specified ' +
- 'for operator ' + operator)
+ raise ValueValidationException(
+ "For field: "
+ + key
+ + ", no pattern specified "
+ + "for operator "
+ + operator
+ )
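
A usage sketch (the criteria dict is illustrative and assumes "equals" is in allowed_operators):

    from st2common.validators.api.reactor import validate_criteria

    criteria = {
        "trigger.payload.status": {
            "type": "equals",  # the operator
            "pattern": "failed",  # the value to compare against
        }
    }

    # Raises ValueValidationException on a missing operator/pattern or an
    # unknown operator; returns None when the criteria pass.
    validate_criteria(criteria)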
def validate_trigger_parameters(trigger_type_ref, parameters):
@@ -77,27 +86,33 @@ def validate_trigger_parameters(trigger_type_ref, parameters):
is_system_trigger = trigger_type_ref in SYSTEM_TRIGGER_TYPES
if is_system_trigger:
# System trigger
- parameters_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]['parameters_schema']
+ parameters_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]["parameters_schema"]
else:
trigger_type_db = triggers.get_trigger_type_db(trigger_type_ref)
if not trigger_type_db:
# Trigger doesn't exist in the database
return None
- parameters_schema = getattr(trigger_type_db, 'parameters_schema', {})
+ parameters_schema = getattr(trigger_type_db, "parameters_schema", {})
if not parameters_schema:
# Parameters schema not defined for this trigger
return None
# We only validate non-system triggers if config option is set (enabled)
if not is_system_trigger and not cfg.CONF.system.validate_trigger_parameters:
- LOG.debug('Got non-system trigger "%s", but trigger parameter validation for non-system'
- 'triggers is disabled, skipping validation.' % (trigger_type_ref))
+ LOG.debug(
+ 'Got non-system trigger "%s", but trigger parameter validation for non-system'
+ "triggers is disabled, skipping validation." % (trigger_type_ref)
+ )
return None
- cleaned = util_schema.validate(instance=parameters, schema=parameters_schema,
- cls=util_schema.CustomValidator, use_default=True,
- allow_default_none=True)
+ cleaned = util_schema.validate(
+ instance=parameters,
+ schema=parameters_schema,
+ cls=util_schema.CustomValidator,
+ use_default=True,
+ allow_default_none=True,
+ )
# Additional validation for CronTimer trigger
# TODO: If we need to add more checks like this we should consider abstracting this out.
@@ -110,7 +125,9 @@ def validate_trigger_parameters(trigger_type_ref, parameters):
return cleaned
-def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trigger=False):
+def validate_trigger_payload(
+ trigger_type_ref, payload, throw_on_inexistent_trigger=False
+):
"""
This function validates trigger payload parameters for system and user-defined triggers.
@@ -128,8 +145,8 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig
# NOTE: Due to the awful code in some other places we also need to support a scenario where
# this variable is a dictionary and contains various TriggerDB object attributes.
if isinstance(trigger_type_ref, dict):
- if trigger_type_ref.get('type', None):
- trigger_type_ref = trigger_type_ref['type']
+ if trigger_type_ref.get("type", None):
+ trigger_type_ref = trigger_type_ref["type"]
else:
trigger_db = triggers.get_trigger_db_by_ref_or_dict(trigger_type_ref)
@@ -143,16 +160,16 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig
is_system_trigger = trigger_type_ref in SYSTEM_TRIGGER_TYPES
if is_system_trigger:
# System trigger
- payload_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]['payload_schema']
+ payload_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]["payload_schema"]
else:
# We assume Trigger ref and not TriggerType ref is passed in if second
# part (trigger name) is a valid UUID version 4
try:
- trigger_uuid = uuid.UUID(trigger_type_ref.split('.')[-1])
+ trigger_uuid = uuid.UUID(trigger_type_ref.split(".")[-1])
except ValueError:
is_trigger_db = False
else:
- is_trigger_db = (trigger_uuid.version == 4)
+ is_trigger_db = trigger_uuid.version == 4
if is_trigger_db:
trigger_db = triggers.get_trigger_db_by_ref(trigger_type_ref)
@@ -165,25 +182,33 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig
if not trigger_type_db:
# Trigger doesn't exist in the database
if throw_on_inexistent_trigger:
- msg = ('Trigger type with reference "%s" doesn\'t exist in the database' %
- (trigger_type_ref))
+ msg = (
+ 'Trigger type with reference "%s" doesn\'t exist in the database'
+ % (trigger_type_ref)
+ )
raise ValueError(msg)
return None
- payload_schema = getattr(trigger_type_db, 'payload_schema', {})
+ payload_schema = getattr(trigger_type_db, "payload_schema", {})
if not payload_schema:
# Payload schema not defined for this trigger
return None
# We only validate non-system triggers if config option is set (enabled)
if not is_system_trigger and not cfg.CONF.system.validate_trigger_payload:
- LOG.debug('Got non-system trigger "%s", but trigger payload validation for non-system'
- 'triggers is disabled, skipping validation.' % (trigger_type_ref))
+ LOG.debug(
+ 'Got non-system trigger "%s", but trigger payload validation for non-system'
+ "triggers is disabled, skipping validation." % (trigger_type_ref)
+ )
return None
- cleaned = util_schema.validate(instance=payload, schema=payload_schema,
- cls=util_schema.CustomValidator, use_default=True,
- allow_default_none=True)
+ cleaned = util_schema.validate(
+ instance=payload,
+ schema=payload_schema,
+ cls=util_schema.CustomValidator,
+ use_default=True,
+ allow_default_none=True,
+ )
return cleaned
diff --git a/st2common/st2common/validators/workflow/base.py b/st2common/st2common/validators/workflow/base.py
index 226a4668fb..3bf8e9fbd5 100644
--- a/st2common/st2common/validators/workflow/base.py
+++ b/st2common/st2common/validators/workflow/base.py
@@ -20,7 +20,6 @@
@six.add_metaclass(abc.ABCMeta)
class WorkflowValidator(object):
-
@abc.abstractmethod
def validate(self, definition):
raise NotImplementedError
diff --git a/st2common/tests/fixtures/mock_runner/mock_runner.py b/st2common/tests/fixtures/mock_runner/mock_runner.py
index 9110e740f4..66295e8421 100644
--- a/st2common/tests/fixtures/mock_runner/mock_runner.py
+++ b/st2common/tests/fixtures/mock_runner/mock_runner.py
@@ -23,9 +23,7 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'get_runner'
-]
+__all__ = ["get_runner"]
def get_runner():
@@ -36,6 +34,7 @@ class MockRunner(ActionRunner):
"""
Runner which does absolutely nothing.
"""
+
KEYS_TO_TRANSFORM = []
def __init__(self, runner_id):
@@ -47,9 +46,9 @@ def pre_run(self):
def run(self, action_parameters):
result = {
- 'failed': False,
- 'succeeded': True,
- 'return_code': 0,
+ "failed": False,
+ "succeeded": True,
+ "return_code": 0,
}
status = LIVEACTION_STATUS_SUCCEEDED
diff --git a/st2common/tests/fixtures/version_file.py b/st2common/tests/fixtures/version_file.py
index 882f420538..b52f01d75c 100644
--- a/st2common/tests/fixtures/version_file.py
+++ b/st2common/tests/fixtures/version_file.py
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '1.2.3'
+__version__ = "1.2.3"
diff --git a/st2common/tests/integration/test_rabbitmq_ssl_listener.py b/st2common/tests/integration/test_rabbitmq_ssl_listener.py
index 9c1ddeef06..e64a22995d 100644
--- a/st2common/tests/integration/test_rabbitmq_ssl_listener.py
+++ b/st2common/tests/integration/test_rabbitmq_ssl_listener.py
@@ -27,12 +27,10 @@
from st2tests.fixturesloader import get_fixtures_base_path
-__all__ = [
- 'RabbitMQTLSListenerTestCase'
-]
+__all__ = ["RabbitMQTLSListenerTestCase"]
-CERTS_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), 'ssl_certs/')
-ST2_CI = (os.environ.get('ST2_CI', 'false').lower() == 'true')
+CERTS_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), "ssl_certs/")
+ST2_CI = os.environ.get("ST2_CI", "false").lower() == "true"
NON_SSL_LISTENER_PORT = 5672
SSL_LISTENER_PORT = 5671
@@ -40,42 +38,49 @@
# NOTE: We only run those tests on the CI provider because at the moment, local
# vagrant dev VM doesn't expose RabbitMQ SSL listener by default
-@unittest2.skipIf(not ST2_CI,
- 'Skipping tests because ST2_CI environment variable is not set to "true"')
+@unittest2.skipIf(
+ not ST2_CI,
+ 'Skipping tests because ST2_CI environment variable is not set to "true"',
+)
class RabbitMQTLSListenerTestCase(unittest2.TestCase):
-
def setUp(self):
# Set default values
- cfg.CONF.set_override(name='ssl', override=False, group='messaging')
- cfg.CONF.set_override(name='ssl_keyfile', override=None, group='messaging')
- cfg.CONF.set_override(name='ssl_certfile', override=None, group='messaging')
- cfg.CONF.set_override(name='ssl_ca_certs', override=None, group='messaging')
- cfg.CONF.set_override(name='ssl_cert_reqs', override=None, group='messaging')
+ cfg.CONF.set_override(name="ssl", override=False, group="messaging")
+ cfg.CONF.set_override(name="ssl_keyfile", override=None, group="messaging")
+ cfg.CONF.set_override(name="ssl_certfile", override=None, group="messaging")
+ cfg.CONF.set_override(name="ssl_ca_certs", override=None, group="messaging")
+ cfg.CONF.set_override(name="ssl_cert_reqs", override=None, group="messaging")
def test_non_ssl_connection_on_ssl_listener_port_failure(self):
- connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/')
+ connection = transport_utils.get_connection(
+ urls="amqp://guest:guest@127.0.0.1:5671/"
+ )
- expected_msg_1 = '[Errno 104]' # followed by: ' Connection reset by peer' or ' ECONNRESET'
- expected_msg_2 = 'Socket closed'
- expected_msg_3 = 'Server unexpectedly closed connection'
+ expected_msg_1 = (
+ "[Errno 104]" # followed by: ' Connection reset by peer' or ' ECONNRESET'
+ )
+ expected_msg_2 = "Socket closed"
+ expected_msg_3 = "Server unexpectedly closed connection"
try:
connection.connect()
except Exception as e:
self.assertFalse(connection.connected)
self.assertIsInstance(e, (IOError, socket.error))
- self.assertTrue(expected_msg_1 in six.text_type(e) or
- expected_msg_2 in six.text_type(e) or
- expected_msg_3 in six.text_type(e))
+ self.assertTrue(
+ expected_msg_1 in six.text_type(e)
+ or expected_msg_2 in six.text_type(e)
+ or expected_msg_3 in six.text_type(e)
+ )
else:
- self.fail('Exception was not thrown')
+ self.fail("Exception was not thrown")
if connection:
connection.release()
def test_ssl_connection_on_ssl_listener_success(self):
# Using query param notation
- urls = 'amqp://guest:guest@127.0.0.1:5671/?ssl=true'
+ urls = "amqp://guest:guest@127.0.0.1:5671/?ssl=true"
connection = transport_utils.get_connection(urls=urls)
try:
@@ -86,9 +91,11 @@ def test_ssl_connection_on_ssl_listener_success(self):
connection.release()
# Using messaging.ssl config option
- cfg.CONF.set_override(name='ssl', override=True, group='messaging')
+ cfg.CONF.set_override(name="ssl", override=True, group="messaging")
- connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/')
+ connection = transport_utils.get_connection(
+ urls="amqp://guest:guest@127.0.0.1:5671/"
+ )
try:
self.assertTrue(connection.connect())
@@ -98,15 +105,21 @@ def test_ssl_connection_on_ssl_listener_success(self):
connection.release()
def test_ssl_connection_ca_certs_provided(self):
- ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem')
+ ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem")
- cfg.CONF.set_override(name='ssl', override=True, group='messaging')
- cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging')
+ cfg.CONF.set_override(name="ssl", override=True, group="messaging")
+ cfg.CONF.set_override(
+ name="ssl_ca_certs", override=ca_cert_path, group="messaging"
+ )
# 1. Validate server cert against a valid CA bundle (success) - cert required
- cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging')
+ cfg.CONF.set_override(
+ name="ssl_cert_reqs", override="required", group="messaging"
+ )
- connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/')
+ connection = transport_utils.get_connection(
+ urls="amqp://guest:guest@127.0.0.1:5671/"
+ )
try:
self.assertTrue(connection.connect())
@@ -117,35 +130,51 @@ def test_ssl_connection_ca_certs_provided(self):
# 2. Validate server cert against other CA bundle (failure)
# CA bundle which was not used to sign the server cert
- ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem')
+ ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem")
- cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging')
- cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging')
+ cfg.CONF.set_override(
+ name="ssl_cert_reqs", override="required", group="messaging"
+ )
+ cfg.CONF.set_override(
+ name="ssl_ca_certs", override=ca_cert_path, group="messaging"
+ )
- connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/')
+ connection = transport_utils.get_connection(
+ urls="amqp://guest:guest@127.0.0.1:5671/"
+ )
- expected_msg = r'\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed'
+ expected_msg = r"\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed"
self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect)
# 3. Validate server cert against other CA bundle (failure)
- ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem')
+ ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem")
- cfg.CONF.set_override(name='ssl_cert_reqs', override='optional', group='messaging')
- cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging')
+ cfg.CONF.set_override(
+ name="ssl_cert_reqs", override="optional", group="messaging"
+ )
+ cfg.CONF.set_override(
+ name="ssl_ca_certs", override=ca_cert_path, group="messaging"
+ )
- connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/')
+ connection = transport_utils.get_connection(
+ urls="amqp://guest:guest@127.0.0.1:5671/"
+ )
- expected_msg = r'\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed'
+ expected_msg = r"\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed"
self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect)
# 4. Validate server cert against other CA bundle (failure)
# We use invalid bundle but cert_reqs is none
- ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem')
+ ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem")
- cfg.CONF.set_override(name='ssl_cert_reqs', override='none', group='messaging')
- cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging')
+ cfg.CONF.set_override(name="ssl_cert_reqs", override="none", group="messaging")
+ cfg.CONF.set_override(
+ name="ssl_ca_certs", override=ca_cert_path, group="messaging"
+ )
- connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/')
+ connection = transport_utils.get_connection(
+ urls="amqp://guest:guest@127.0.0.1:5671/"
+ )
try:
self.assertTrue(connection.connect())
@@ -156,16 +185,28 @@ def test_ssl_connection_ca_certs_provided(self):
def test_ssl_connect_client_side_cert_authentication(self):
# 1. Success, valid client side cert provided
- ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, 'client/private_key.pem')
- ssl_certfile = os.path.join(CERTS_FIXTURES_PATH, 'client/client_certificate.pem')
- ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem')
-
- cfg.CONF.set_override(name='ssl_keyfile', override=ssl_keyfile, group='messaging')
- cfg.CONF.set_override(name='ssl_certfile', override=ssl_certfile, group='messaging')
- cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging')
- cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging')
-
- connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/')
+ ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, "client/private_key.pem")
+ ssl_certfile = os.path.join(
+ CERTS_FIXTURES_PATH, "client/client_certificate.pem"
+ )
+ ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem")
+
+ cfg.CONF.set_override(
+ name="ssl_keyfile", override=ssl_keyfile, group="messaging"
+ )
+ cfg.CONF.set_override(
+ name="ssl_certfile", override=ssl_certfile, group="messaging"
+ )
+ cfg.CONF.set_override(
+ name="ssl_cert_reqs", override="required", group="messaging"
+ )
+ cfg.CONF.set_override(
+ name="ssl_ca_certs", override=ca_cert_path, group="messaging"
+ )
+
+ connection = transport_utils.get_connection(
+ urls="amqp://guest:guest@127.0.0.1:5671/"
+ )
try:
self.assertTrue(connection.connect())
@@ -175,16 +216,28 @@ def test_ssl_connect_client_side_cert_authentication(self):
connection.release()
# 2. Invalid client side cert provided - failure
- ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, 'client/private_key.pem')
- ssl_certfile = os.path.join(CERTS_FIXTURES_PATH, 'server/server_certificate.pem')
- ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem')
-
- cfg.CONF.set_override(name='ssl_keyfile', override=ssl_keyfile, group='messaging')
- cfg.CONF.set_override(name='ssl_certfile', override=ssl_certfile, group='messaging')
- cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging')
- cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging')
-
- connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/')
-
- expected_msg = r'\[X509: KEY_VALUES_MISMATCH\] key values mismatch'
+ ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, "client/private_key.pem")
+ ssl_certfile = os.path.join(
+ CERTS_FIXTURES_PATH, "server/server_certificate.pem"
+ )
+ ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem")
+
+ cfg.CONF.set_override(
+ name="ssl_keyfile", override=ssl_keyfile, group="messaging"
+ )
+ cfg.CONF.set_override(
+ name="ssl_certfile", override=ssl_certfile, group="messaging"
+ )
+ cfg.CONF.set_override(
+ name="ssl_cert_reqs", override="required", group="messaging"
+ )
+ cfg.CONF.set_override(
+ name="ssl_ca_certs", override=ca_cert_path, group="messaging"
+ )
+
+ connection = transport_utils.get_connection(
+ urls="amqp://guest:guest@127.0.0.1:5671/"
+ )
+
+ expected_msg = r"\[X509: KEY_VALUES_MISMATCH\] key values mismatch"
self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect)
diff --git a/st2common/tests/integration/test_register_content_script.py b/st2common/tests/integration/test_register_content_script.py
index 57568df1c2..0a0dc8e7f1 100644
--- a/st2common/tests/integration/test_register_content_script.py
+++ b/st2common/tests/integration/test_register_content_script.py
@@ -26,15 +26,15 @@
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-SCRIPT_PATH = os.path.join(BASE_DIR, '../../bin/st2-register-content')
+SCRIPT_PATH = os.path.join(BASE_DIR, "../../bin/st2-register-content")
SCRIPT_PATH = os.path.abspath(SCRIPT_PATH)
-BASE_CMD_ARGS = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests.conf', '-v']
-BASE_REGISTER_ACTIONS_CMD_ARGS = BASE_CMD_ARGS + ['--register-actions']
+BASE_CMD_ARGS = [sys.executable, SCRIPT_PATH, "--config-file=conf/st2.tests.conf", "-v"]
+BASE_REGISTER_ACTIONS_CMD_ARGS = BASE_CMD_ARGS + ["--register-actions"]
PACKS_PATH = get_fixtures_packs_base_path()
-PACKS_COUNT = len(glob.glob('%s/*/pack.yaml' % (PACKS_PATH)))
-assert(PACKS_COUNT >= 2)
+PACKS_COUNT = len(glob.glob("%s/*/pack.yaml" % (PACKS_PATH)))
+assert PACKS_COUNT >= 2
class ContentRegisterScriptTestCase(IntegrationTestCase):
@@ -43,27 +43,27 @@ def setUp(self):
test_config.parse_args()
def test_register_from_pack_success(self):
- pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1')
- runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners')
+ pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1")
+ runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners")
opts = [
- '--register-pack=%s' % (pack_dir),
- '--register-runner-dir=%s' % (runner_dirs),
+ "--register-pack=%s" % (pack_dir),
+ "--register-runner-dir=%s" % (runner_dirs),
]
cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts
exit_code, _, stderr = run_command(cmd=cmd)
- self.assertIn('Registered 1 actions.', stderr)
+ self.assertIn("Registered 1 actions.", stderr)
self.assertEqual(exit_code, 0)
def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self):
# No fail on failure flag, should succeed
- pack_dir = 'doesntexistblah'
- runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners')
+ pack_dir = "doesntexistblah"
+ runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners")
opts = [
- '--register-pack=%s' % (pack_dir),
- '--register-runner-dir=%s' % (runner_dirs),
- '--register-no-fail-on-failure'
+ "--register-pack=%s" % (pack_dir),
+ "--register-runner-dir=%s" % (runner_dirs),
+ "--register-no-fail-on-failure",
]
cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts
exit_code, _, _ = run_command(cmd=cmd)
@@ -71,9 +71,9 @@ def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self):
# Fail on failure, should fail
opts = [
- '--register-pack=%s' % (pack_dir),
- '--register-runner-dir=%s' % (runner_dirs),
- '--register-fail-on-failure'
+ "--register-pack=%s" % (pack_dir),
+ "--register-runner-dir=%s" % (runner_dirs),
+ "--register-fail-on-failure",
]
cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts
exit_code, _, stderr = run_command(cmd=cmd)
@@ -82,30 +82,30 @@ def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self):
def test_register_from_pack_action_metadata_fails_validation(self):
# No fail on failure flag, should succeed
- pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_4')
- runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners')
+ pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_4")
+ runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners")
opts = [
- '--register-pack=%s' % (pack_dir),
- '--register-no-fail-on-failure',
- '--register-runner-dir=%s' % (runner_dirs),
+ "--register-pack=%s" % (pack_dir),
+ "--register-no-fail-on-failure",
+ "--register-runner-dir=%s" % (runner_dirs),
]
cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts
exit_code, _, stderr = run_command(cmd=cmd)
- self.assertIn('Registered 0 actions.', stderr)
+ self.assertIn("Registered 0 actions.", stderr)
self.assertEqual(exit_code, 0)
# Fail on failure, should fail
- pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_4')
+ pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_4")
opts = [
- '--register-pack=%s' % (pack_dir),
- '--register-fail-on-failure',
- '--register-runner-dir=%s' % (runner_dirs),
+ "--register-pack=%s" % (pack_dir),
+ "--register-fail-on-failure",
+ "--register-runner-dir=%s" % (runner_dirs),
]
cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts
exit_code, _, stderr = run_command(cmd=cmd)
- self.assertIn('object has no attribute \'get\'', stderr)
+ self.assertIn("object has no attribute 'get'", stderr)
self.assertEqual(exit_code, 1)
def test_register_from_packs_doesnt_throw_on_missing_pack_resource_folder(self):
@@ -114,57 +114,74 @@ def test_register_from_packs_doesnt_throw_on_missing_pack_resource_folder(self):
# Note: We want to use a different config which sets fixtures/packs_1/
# dir as packs_base_paths
- cmd = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests1.conf', '-v',
- '--register-sensors']
+ cmd = [
+ sys.executable,
+ SCRIPT_PATH,
+ "--config-file=conf/st2.tests1.conf",
+ "-v",
+ "--register-sensors",
+ ]
exit_code, _, stderr = run_command(cmd=cmd)
- self.assertIn('Registered 0 sensors.', stderr, 'Actual stderr: %s' % (stderr))
+ self.assertIn("Registered 0 sensors.", stderr, "Actual stderr: %s" % (stderr))
self.assertEqual(exit_code, 0)
- cmd = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests1.conf', '-v',
- '--register-all', '--register-no-fail-on-failure']
+ cmd = [
+ sys.executable,
+ SCRIPT_PATH,
+ "--config-file=conf/st2.tests1.conf",
+ "-v",
+ "--register-all",
+ "--register-no-fail-on-failure",
+ ]
exit_code, _, stderr = run_command(cmd=cmd)
- self.assertIn('Registered 0 actions.', stderr)
- self.assertIn('Registered 0 sensors.', stderr)
- self.assertIn('Registered 0 rules.', stderr)
+ self.assertIn("Registered 0 actions.", stderr)
+ self.assertIn("Registered 0 sensors.", stderr)
+ self.assertIn("Registered 0 rules.", stderr)
self.assertEqual(exit_code, 0)
def test_register_all_and_register_setup_virtualenvs(self):
        # Verify that --register-all works in combination with --register-setup-virtualenvs
# Single pack
- pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1')
+ pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1")
cmd = BASE_CMD_ARGS + [
- '--register-pack=%s' % (pack_dir),
- '--register-all',
- '--register-setup-virtualenvs',
- '--register-no-fail-on-failure'
+ "--register-pack=%s" % (pack_dir),
+ "--register-all",
+ "--register-setup-virtualenvs",
+ "--register-no-fail-on-failure",
]
exit_code, stdout, stderr = run_command(cmd=cmd)
- self.assertIn('Registering actions', stderr, 'Actual stderr: %s' % (stderr))
- self.assertIn('Registering rules', stderr)
- self.assertIn('Setup virtualenv for %s pack(s)' % ('1'), stderr)
+ self.assertIn("Registering actions", stderr, "Actual stderr: %s" % (stderr))
+ self.assertIn("Registering rules", stderr)
+ self.assertIn("Setup virtualenv for %s pack(s)" % ("1"), stderr)
self.assertEqual(exit_code, 0)
def test_register_setup_virtualenvs(self):
# Single pack
- pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1')
+ pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1")
- cmd = BASE_CMD_ARGS + ['--register-pack=%s' % (pack_dir), '--register-setup-virtualenvs',
- '--register-no-fail-on-failure']
+ cmd = BASE_CMD_ARGS + [
+ "--register-pack=%s" % (pack_dir),
+ "--register-setup-virtualenvs",
+ "--register-no-fail-on-failure",
+ ]
exit_code, stdout, stderr = run_command(cmd=cmd)
self.assertIn('Setting up virtualenv for pack "dummy_pack_1"', stderr)
- self.assertIn('Setup virtualenv for 1 pack(s)', stderr)
+ self.assertIn("Setup virtualenv for 1 pack(s)", stderr)
self.assertEqual(exit_code, 0)
def test_register_recreate_virtualenvs(self):
# Single pack
- pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1')
+ pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1")
- cmd = BASE_CMD_ARGS + ['--register-pack=%s' % (pack_dir), '--register-recreate-virtualenvs',
- '--register-no-fail-on-failure']
+ cmd = BASE_CMD_ARGS + [
+ "--register-pack=%s" % (pack_dir),
+ "--register-recreate-virtualenvs",
+ "--register-no-fail-on-failure",
+ ]
exit_code, stdout, stderr = run_command(cmd=cmd)
self.assertIn('Setting up virtualenv for pack "dummy_pack_1"', stderr)
- self.assertIn('Setup virtualenv for 1 pack(s)', stderr)
- self.assertIn('Virtualenv successfull removed.', stderr)
+ self.assertIn("Setup virtualenv for 1 pack(s)", stderr)
+ self.assertIn("Virtualenv successfully removed.", stderr)
self.assertEqual(exit_code, 0)
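Editor's note: the tests in this file lean on a run_command helper imported from the st2
test utilities, which returns an (exit_code, stdout, stderr) triple. A plausible
reimplementation of just that contract, assuming only the behaviour the assertions above
rely on (registration messages are written to stderr); this is a sketch, not StackStorm's
actual helper.

    import subprocess
    import sys

    def run_command(cmd):
        # Run the command and capture both streams so callers can assert on them.
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        stdout, stderr = proc.communicate()
        return proc.returncode, stdout.decode("utf-8"), stderr.decode("utf-8")

    code = "import sys; sys.stderr.write('Registered 1 actions.\\n')"
    exit_code, _, stderr = run_command([sys.executable, "-c", code])
    assert exit_code == 0
    assert "Registered 1 actions." in stderr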
diff --git a/st2common/tests/integration/test_service_setup_log_level_filtering.py b/st2common/tests/integration/test_service_setup_log_level_filtering.py
index ac3f90deaf..a03e90688a 100644
--- a/st2common/tests/integration/test_service_setup_log_level_filtering.py
+++ b/st2common/tests/integration/test_service_setup_log_level_filtering.py
@@ -25,36 +25,42 @@
from st2tests.base import IntegrationTestCase
from st2tests.fixturesloader import get_fixtures_base_path
-__all__ = [
- 'ServiceSetupLogLevelFilteringTestCase'
-]
+__all__ = ["ServiceSetupLogLevelFilteringTestCase"]
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
FIXTURES_DIR = get_fixtures_base_path()
-ST2_CONFIG_INFO_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.info_log_level.conf')
+ST2_CONFIG_INFO_LL_PATH = os.path.join(
+ FIXTURES_DIR, "conf/st2.tests.api.info_log_level.conf"
+)
ST2_CONFIG_INFO_LL_PATH = os.path.abspath(ST2_CONFIG_INFO_LL_PATH)
-ST2_CONFIG_DEBUG_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.debug_log_level.conf')
+ST2_CONFIG_DEBUG_LL_PATH = os.path.join(
+ FIXTURES_DIR, "conf/st2.tests.api.debug_log_level.conf"
+)
ST2_CONFIG_DEBUG_LL_PATH = os.path.abspath(ST2_CONFIG_DEBUG_LL_PATH)
-ST2_CONFIG_AUDIT_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.audit_log_level.conf')
+ST2_CONFIG_AUDIT_LL_PATH = os.path.join(
+ FIXTURES_DIR, "conf/st2.tests.api.audit_log_level.conf"
+)
ST2_CONFIG_AUDIT_LL_PATH = os.path.abspath(ST2_CONFIG_AUDIT_LL_PATH)
-ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.join(FIXTURES_DIR,
- 'conf/st2.tests.api.system_debug_true.conf')
+ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.join(
+ FIXTURES_DIR, "conf/st2.tests.api.system_debug_true.conf"
+)
ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.abspath(ST2_CONFIG_SYSTEM_DEBUG_PATH)
-ST2_CONFIG_SYSTEM_LL_DEBUG_PATH = os.path.join(FIXTURES_DIR,
- 'conf/st2.tests.api.system_debug_true_logging_debug.conf')
+ST2_CONFIG_SYSTEM_LL_DEBUG_PATH = os.path.join(
+ FIXTURES_DIR, "conf/st2.tests.api.system_debug_true_logging_debug.conf"
+)
PYTHON_BINARY = sys.executable
-ST2API_BINARY = os.path.join(BASE_DIR, '../../../st2api/bin/st2api')
+ST2API_BINARY = os.path.join(BASE_DIR, "../../../st2api/bin/st2api")
ST2API_BINARY = os.path.abspath(ST2API_BINARY)
-CMD = [PYTHON_BINARY, ST2API_BINARY, '--config-file']
+CMD = [PYTHON_BINARY, ST2API_BINARY, "--config-file"]
class ServiceSetupLogLevelFilteringTestCase(IntegrationTestCase):
@@ -68,11 +74,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self):
process.send_signal(signal.SIGKILL)
# First 3 log lines are debug messages about the environment which are always logged
- stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:])
+ stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:])
- self.assertIn('INFO [-]', stdout)
- self.assertNotIn('DEBUG [-]', stdout)
- self.assertNotIn('AUDIT [-]', stdout)
+ self.assertIn("INFO [-]", stdout)
+ self.assertNotIn("DEBUG [-]", stdout)
+ self.assertNotIn("AUDIT [-]", stdout)
# 2. DEBUG log level - audit messages should be included
process = self._start_process(config_path=ST2_CONFIG_DEBUG_LL_PATH)
@@ -83,11 +89,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self):
process.send_signal(signal.SIGKILL)
# First 3 log lines are debug messages about the environment which are always logged
- stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:])
+ stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:])
- self.assertIn('INFO [-]', stdout)
- self.assertIn('DEBUG [-]', stdout)
- self.assertIn('AUDIT [-]', stdout)
+ self.assertIn("INFO [-]", stdout)
+ self.assertIn("DEBUG [-]", stdout)
+ self.assertIn("AUDIT [-]", stdout)
# 3. AUDIT log level - audit messages should be included
process = self._start_process(config_path=ST2_CONFIG_AUDIT_LL_PATH)
@@ -98,11 +104,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self):
process.send_signal(signal.SIGKILL)
# First 3 log lines are debug messages about the environment which are always logged
- stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:])
+ stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:])
- self.assertNotIn('INFO [-]', stdout)
- self.assertNotIn('DEBUG [-]', stdout)
- self.assertIn('AUDIT [-]', stdout)
+ self.assertNotIn("INFO [-]", stdout)
+ self.assertNotIn("DEBUG [-]", stdout)
+ self.assertIn("AUDIT [-]", stdout)
        # 4. INFO log level but system.debug set to True
process = self._start_process(config_path=ST2_CONFIG_SYSTEM_DEBUG_PATH)
@@ -113,11 +119,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self):
process.send_signal(signal.SIGKILL)
# First 3 log lines are debug messages about the environment which are always logged
- stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:])
+ stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:])
- self.assertIn('INFO [-]', stdout)
- self.assertIn('DEBUG [-]', stdout)
- self.assertIn('AUDIT [-]', stdout)
+ self.assertIn("INFO [-]", stdout)
+ self.assertIn("DEBUG [-]", stdout)
+ self.assertIn("AUDIT [-]", stdout)
def test_kombu_heartbeat_tick_log_messages_are_excluded(self):
# 1. system.debug = True config option is set, verify heartbeat_tick message is not logged
@@ -128,8 +134,8 @@ def test_kombu_heartbeat_tick_log_messages_are_excluded(self):
eventlet.sleep(5)
process.send_signal(signal.SIGKILL)
- stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n'))
- self.assertNotIn('heartbeat_tick', stdout)
+ stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n"))
+ self.assertNotIn("heartbeat_tick", stdout)
# 2. system.debug = False, log level is set to debug
process = self._start_process(config_path=ST2_CONFIG_DEBUG_LL_PATH)
@@ -139,14 +145,19 @@ def test_kombu_heartbeat_tick_log_messages_are_excluded(self):
eventlet.sleep(5)
process.send_signal(signal.SIGKILL)
- stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n'))
- self.assertNotIn('heartbeat_tick', stdout)
+ stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n"))
+ self.assertNotIn("heartbeat_tick", stdout)
def _start_process(self, config_path):
cmd = CMD + [config_path]
- cwd = os.path.abspath(os.path.join(BASE_DIR, '../../../'))
+ cwd = os.path.abspath(os.path.join(BASE_DIR, "../../../"))
cwd = os.path.abspath(cwd)
- process = subprocess.Popen(cmd, cwd=cwd,
- stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- shell=False, preexec_fn=os.setsid)
+ process = subprocess.Popen(
+ cmd,
+ cwd=cwd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=False,
+ preexec_fn=os.setsid,
+ )
return process
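Editor's note on _start_process above: preexec_fn=os.setsid starts the service in its own
session and process group. An illustrative, POSIX-only sketch of why that matters, using a
stand-in child process in place of st2api:

    import os
    import signal
    import subprocess

    proc = subprocess.Popen(
        ["sleep", "30"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        shell=False,
        preexec_fn=os.setsid,  # child becomes the leader of a new process group
    )

    # Signalling the whole group (not just the immediate child) also reaches any
    # workers the service forked, which is what the SIGKILL in the tests relies on.
    os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
    proc.wait()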
diff --git a/st2common/tests/unit/base.py b/st2common/tests/unit/base.py
index 6a22b139db..65948d1d11 100644
--- a/st2common/tests/unit/base.py
+++ b/st2common/tests/unit/base.py
@@ -24,13 +24,11 @@
from st2common.exceptions.db import StackStormDBObjectNotFoundError
__all__ = [
- 'BaseDBModelCRUDTestCase',
-
- 'FakeModel',
- 'FakeModelDB',
-
- 'ChangeRevFakeModel',
- 'ChangeRevFakeModelDB'
+ "BaseDBModelCRUDTestCase",
+ "FakeModel",
+ "FakeModelDB",
+ "ChangeRevFakeModel",
+ "ChangeRevFakeModelDB",
]
@@ -57,19 +55,26 @@ def test_crud_operations(self):
self.assertEqual(getattr(retrieved_db, attribute_name), attribute_value)
# 2. Test update
- updated_attribute_value = 'updated-%s' % (str(time.time()))
+ updated_attribute_value = "updated-%s" % (str(time.time()))
setattr(model_db, self.update_attribute_name, updated_attribute_value)
saved_db = self.persistance_class.add_or_update(model_db)
- self.assertEqual(getattr(saved_db, self.update_attribute_name), updated_attribute_value)
+ self.assertEqual(
+ getattr(saved_db, self.update_attribute_name), updated_attribute_value
+ )
retrieved_db = self.persistance_class.get_by_id(saved_db.id)
self.assertEqual(saved_db.id, retrieved_db.id)
- self.assertEqual(getattr(retrieved_db, self.update_attribute_name), updated_attribute_value)
+ self.assertEqual(
+ getattr(retrieved_db, self.update_attribute_name), updated_attribute_value
+ )
# 3. Test delete
self.persistance_class.delete(model_db)
- self.assertRaises(StackStormDBObjectNotFoundError, self.persistance_class.get_by_id,
- model_db.id)
+ self.assertRaises(
+ StackStormDBObjectNotFoundError,
+ self.persistance_class.get_by_id,
+ model_db.id,
+ )
class FakeModelDB(stormbase.StormBaseDB):
@@ -79,11 +84,11 @@ class FakeModelDB(stormbase.StormBaseDB):
timestamp = mongoengine.DateTimeField()
meta = {
- 'indexes': [
- {'fields': ['index']},
- {'fields': ['category']},
- {'fields': ['timestamp']},
- {'fields': ['context.user']},
+ "indexes": [
+ {"fields": ["index"]},
+ {"fields": ["category"]},
+ {"fields": ["timestamp"]},
+ {"fields": ["context.user"]},
]
}
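Editor's note: FakeModelDB above declares its indexes through mongoengine's meta dict. A
minimal standalone model using the same declaration style; the field names are
illustrative, and defining the class does not require a live MongoDB connection.

    import mongoengine

    class ExampleDB(mongoengine.Document):
        index = mongoengine.IntField()
        category = mongoengine.StringField()
        timestamp = mongoengine.DateTimeField()

        meta = {
            "indexes": [
                {"fields": ["index"]},
                {"fields": ["category", "timestamp"]},  # compound index
            ]
        }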
diff --git a/st2common/tests/unit/services/test_access.py b/st2common/tests/unit/services/test_access.py
index 79e680b30d..4f7d8169b4 100644
--- a/st2common/tests/unit/services/test_access.py
+++ b/st2common/tests/unit/services/test_access.py
@@ -28,11 +28,10 @@
import st2tests.config as tests_config
-USERNAME = 'manas'
+USERNAME = "manas"
class AccessServiceTest(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(AccessServiceTest, cls).setUpClass()
@@ -47,7 +46,7 @@ def test_create_token(self):
def test_create_token_fail(self):
try:
access.create_token(None)
- self.assertTrue(False, 'Create succeeded was expected to fail.')
+ self.assertTrue(False, "Create succeeded was expected to fail.")
except ValueError:
self.assertTrue(True)
@@ -56,7 +55,7 @@ def test_delete_token(self):
access.delete_token(token.token)
try:
token = Token.get(token.token)
- self.assertTrue(False, 'Delete failed was expected to pass.')
+ self.assertTrue(False, "Delete failed was expected to pass.")
except TokenNotFoundError:
self.assertTrue(True)
@@ -71,13 +70,17 @@ def test_create_token_ttl_ok(self):
self.assertIsNotNone(token)
self.assertIsNotNone(token.token)
self.assertEqual(token.user, USERNAME)
- expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl)
+ expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(
+ seconds=ttl
+ )
expected_expiry = date_utils.add_utc_tz(expected_expiry)
self.assertLess(isotime.parse(token.expiry), expected_expiry)
def test_create_token_ttl_capped(self):
ttl = cfg.CONF.auth.token_ttl + 10
- expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl)
+ expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(
+ seconds=ttl
+ )
expected_expiry = date_utils.add_utc_tz(expected_expiry)
token = access.create_token(USERNAME, 10)
self.assertIsNotNone(token)
@@ -86,11 +89,13 @@ def test_create_token_ttl_capped(self):
self.assertLess(isotime.parse(token.expiry), expected_expiry)
def test_create_token_service_token_can_use_arbitrary_ttl(self):
- ttl = (10000 * 24 * 24)
+ ttl = 10000 * 24 * 24
# Service token should support arbitrary TTL
token = access.create_token(USERNAME, ttl=ttl, service=True)
- expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl)
+ expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(
+ seconds=ttl
+ )
expected_expiry = date_utils.add_utc_tz(expected_expiry)
self.assertIsNotNone(token)
@@ -98,5 +103,6 @@ def test_create_token_service_token_can_use_arbitrary_ttl(self):
self.assertLess(isotime.parse(token.expiry), expected_expiry)
        # A non-service token should throw on a TTL which is too large
- self.assertRaises(TTLTooLargeException, access.create_token, USERNAME, ttl=ttl,
- service=False)
+ self.assertRaises(
+ TTLTooLargeException, access.create_token, USERNAME, ttl=ttl, service=False
+ )
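Editor's note: the reformatted assertions above use unittest's callable-style
assertRaises, where the callable and its arguments are passed through rather than
invoked at the call site. A small self-contained sketch of the two equivalent idioms,
with a stand-in exception and function in place of st2's real ones:

    import unittest

    class TooLarge(ValueError):
        pass

    def create_token(user, ttl, service=False):
        if not service and ttl > 100:
            raise TooLarge("ttl too large")
        return (user, ttl)

    class Example(unittest.TestCase):
        def test_callable_style(self):
            # assertRaises calls create_token("stanley", ttl=101, service=False) itself.
            self.assertRaises(TooLarge, create_token, "stanley", ttl=101, service=False)

        def test_context_manager_style(self):
            with self.assertRaises(TooLarge):
                create_token("stanley", ttl=101, service=False)

    if __name__ == "__main__":
        unittest.main()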
diff --git a/st2common/tests/unit/services/test_action.py b/st2common/tests/unit/services/test_action.py
index 7bda929cc0..ab8db72329 100644
--- a/st2common/tests/unit/services/test_action.py
+++ b/st2common/tests/unit/services/test_action.py
@@ -39,145 +39,126 @@
RUNNER = {
- 'name': 'local-shell-script',
- 'description': 'A runner to execute local command.',
- 'enabled': True,
- 'runner_parameters': {
- 'hosts': {'type': 'string'},
- 'cmd': {'type': 'string'},
- 'sudo': {'type': 'boolean', 'default': False}
+ "name": "local-shell-script",
+ "description": "A runner to execute local command.",
+ "enabled": True,
+ "runner_parameters": {
+ "hosts": {"type": "string"},
+ "cmd": {"type": "string"},
+ "sudo": {"type": "boolean", "default": False},
},
- 'runner_module': 'remoterunner'
+ "runner_module": "remoterunner",
}
RUNNER_ACTION_CHAIN = {
- 'name': 'action-chain',
- 'description': 'AC runner.',
- 'enabled': True,
- 'runner_parameters': {
- },
- 'runner_module': 'remoterunner'
+ "name": "action-chain",
+ "description": "AC runner.",
+ "enabled": True,
+ "runner_parameters": {},
+ "runner_module": "remoterunner",
}
ACTION = {
- 'name': 'my.action',
- 'description': 'my test',
- 'enabled': True,
- 'entry_point': '/tmp/test/action.sh',
- 'pack': 'default',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'arg_default_value': {
- 'type': 'string',
- 'default': 'abc'
- },
- 'arg_default_type': {
- }
+ "name": "my.action",
+ "description": "my test",
+ "enabled": True,
+ "entry_point": "/tmp/test/action.sh",
+ "pack": "default",
+ "runner_type": "local-shell-script",
+ "parameters": {
+ "arg_default_value": {"type": "string", "default": "abc"},
+ "arg_default_type": {},
},
- 'notify': {
- 'on-complete': {
- 'message': 'My awesome action is complete. Party time!!!',
- 'routes': ['notify.slack']
+ "notify": {
+ "on-complete": {
+ "message": "My awesome action is complete. Party time!!!",
+ "routes": ["notify.slack"],
}
- }
+ },
}
ACTION_WORKFLOW = {
- 'name': 'my.wf_action',
- 'description': 'my test',
- 'enabled': True,
- 'entry_point': '/tmp/test/action.sh',
- 'pack': 'default',
- 'runner_type': 'action-chain'
+ "name": "my.wf_action",
+ "description": "my test",
+ "enabled": True,
+ "entry_point": "/tmp/test/action.sh",
+ "pack": "default",
+ "runner_type": "action-chain",
}
ACTION_OVR_PARAM = {
- 'name': 'my.sudo.default.action',
- 'description': 'my test',
- 'enabled': True,
- 'entry_point': '/tmp/test/action.sh',
- 'pack': 'default',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'sudo': {
- 'default': True
- }
- }
+ "name": "my.sudo.default.action",
+ "description": "my test",
+ "enabled": True,
+ "entry_point": "/tmp/test/action.sh",
+ "pack": "default",
+ "runner_type": "local-shell-script",
+ "parameters": {"sudo": {"default": True}},
}
ACTION_OVR_PARAM_MUTABLE = {
- 'name': 'my.sudo.mutable.action',
- 'description': 'my test',
- 'enabled': True,
- 'entry_point': '/tmp/test/action.sh',
- 'pack': 'default',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'sudo': {
- 'immutable': False
- }
- }
+ "name": "my.sudo.mutable.action",
+ "description": "my test",
+ "enabled": True,
+ "entry_point": "/tmp/test/action.sh",
+ "pack": "default",
+ "runner_type": "local-shell-script",
+ "parameters": {"sudo": {"immutable": False}},
}
ACTION_OVR_PARAM_IMMUTABLE = {
- 'name': 'my.sudo.immutable.action',
- 'description': 'my test',
- 'enabled': True,
- 'entry_point': '/tmp/test/action.sh',
- 'pack': 'default',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'sudo': {
- 'immutable': True
- }
- }
+ "name": "my.sudo.immutable.action",
+ "description": "my test",
+ "enabled": True,
+ "entry_point": "/tmp/test/action.sh",
+ "pack": "default",
+ "runner_type": "local-shell-script",
+ "parameters": {"sudo": {"immutable": True}},
}
ACTION_OVR_PARAM_BAD_ATTR = {
- 'name': 'my.sudo.invalid.action',
- 'description': 'my test',
- 'enabled': True,
- 'entry_point': '/tmp/test/action.sh',
- 'pack': 'default',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'sudo': {
- 'type': 'number'
- }
- }
+ "name": "my.sudo.invalid.action",
+ "description": "my test",
+ "enabled": True,
+ "entry_point": "/tmp/test/action.sh",
+ "pack": "default",
+ "runner_type": "local-shell-script",
+ "parameters": {"sudo": {"type": "number"}},
}
ACTION_OVR_PARAM_BAD_ATTR_NOOP = {
- 'name': 'my.sudo.invalid.noop.action',
- 'description': 'my test',
- 'enabled': True,
- 'entry_point': '/tmp/test/action.sh',
- 'pack': 'default',
- 'runner_type': 'local-shell-script',
- 'parameters': {
- 'sudo': {
- 'type': 'boolean'
- }
- }
+ "name": "my.sudo.invalid.noop.action",
+ "description": "my test",
+ "enabled": True,
+ "entry_point": "/tmp/test/action.sh",
+ "pack": "default",
+ "runner_type": "local-shell-script",
+ "parameters": {"sudo": {"type": "boolean"}},
}
-PACK = 'default'
-ACTION_REF = ResourceReference(name='my.action', pack=PACK).ref
-ACTION_WORKFLOW_REF = ResourceReference(name='my.wf_action', pack=PACK).ref
-ACTION_OVR_PARAM_REF = ResourceReference(name='my.sudo.default.action', pack=PACK).ref
-ACTION_OVR_PARAM_MUTABLE_REF = ResourceReference(name='my.sudo.mutable.action', pack=PACK).ref
-ACTION_OVR_PARAM_IMMUTABLE_REF = ResourceReference(name='my.sudo.immutable.action', pack=PACK).ref
-ACTION_OVR_PARAM_BAD_ATTR_REF = ResourceReference(name='my.sudo.invalid.action', pack=PACK).ref
+PACK = "default"
+ACTION_REF = ResourceReference(name="my.action", pack=PACK).ref
+ACTION_WORKFLOW_REF = ResourceReference(name="my.wf_action", pack=PACK).ref
+ACTION_OVR_PARAM_REF = ResourceReference(name="my.sudo.default.action", pack=PACK).ref
+ACTION_OVR_PARAM_MUTABLE_REF = ResourceReference(
+ name="my.sudo.mutable.action", pack=PACK
+).ref
+ACTION_OVR_PARAM_IMMUTABLE_REF = ResourceReference(
+ name="my.sudo.immutable.action", pack=PACK
+).ref
+ACTION_OVR_PARAM_BAD_ATTR_REF = ResourceReference(
+ name="my.sudo.invalid.action", pack=PACK
+).ref
ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF = ResourceReference(
- name='my.sudo.invalid.noop.action', pack=PACK).ref
+ name="my.sudo.invalid.noop.action", pack=PACK
+).ref
-USERNAME = 'stanley'
+USERNAME = "stanley"
-@mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None))
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(runners_utils, "invoke_post_run", mock.MagicMock(return_value=None))
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class TestActionExecutionService(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(TestActionExecutionService, cls).setUpClass()
@@ -188,17 +169,21 @@ def setUpClass(cls):
RunnerType.add_or_update(RunnerTypeAPI.to_model(runner_api))
cls.actions = {
- ACTION['name']: ActionAPI(**ACTION),
- ACTION_WORKFLOW['name']: ActionAPI(**ACTION_WORKFLOW),
- ACTION_OVR_PARAM['name']: ActionAPI(**ACTION_OVR_PARAM),
- ACTION_OVR_PARAM_MUTABLE['name']: ActionAPI(**ACTION_OVR_PARAM_MUTABLE),
- ACTION_OVR_PARAM_IMMUTABLE['name']: ActionAPI(**ACTION_OVR_PARAM_IMMUTABLE),
- ACTION_OVR_PARAM_BAD_ATTR['name']: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR),
- ACTION_OVR_PARAM_BAD_ATTR_NOOP['name']: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR_NOOP)
+ ACTION["name"]: ActionAPI(**ACTION),
+ ACTION_WORKFLOW["name"]: ActionAPI(**ACTION_WORKFLOW),
+ ACTION_OVR_PARAM["name"]: ActionAPI(**ACTION_OVR_PARAM),
+ ACTION_OVR_PARAM_MUTABLE["name"]: ActionAPI(**ACTION_OVR_PARAM_MUTABLE),
+ ACTION_OVR_PARAM_IMMUTABLE["name"]: ActionAPI(**ACTION_OVR_PARAM_IMMUTABLE),
+ ACTION_OVR_PARAM_BAD_ATTR["name"]: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR),
+ ACTION_OVR_PARAM_BAD_ATTR_NOOP["name"]: ActionAPI(
+ **ACTION_OVR_PARAM_BAD_ATTR_NOOP
+ ),
}
- cls.actiondbs = {name: Action.add_or_update(ActionAPI.to_model(action))
- for name, action in six.iteritems(cls.actions)}
+ cls.actiondbs = {
+ name: Action.add_or_update(ActionAPI.to_model(action))
+ for name, action in six.iteritems(cls.actions)
+ }
cls.container = RunnerContainer()
@@ -212,8 +197,8 @@ def tearDownClass(cls):
super(TestActionExecutionService, cls).tearDownClass()
def _submit_request(self, action_ref=ACTION_REF):
- context = {'user': USERNAME}
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'}
+ context = {"user": USERNAME}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"}
req = LiveActionDB(action=action_ref, context=context, parameters=parameters)
req, _ = action_service.request(req)
ex = action_db.get_liveaction_by_id(str(req.id))
@@ -249,7 +234,7 @@ def _create_nested_executions(self, depth=2):
root_liveaction_db = LiveAction.add_or_update(root_liveaction_db)
root_ex = executions.create_execution_object(root_liveaction_db)
- last_id = root_ex['id']
+ last_id = root_ex["id"]
# Create children to the specified depth
for i in range(depth):
@@ -264,11 +249,7 @@ def _create_nested_executions(self, depth=2):
child_liveaction_db = LiveActionDB()
child_liveaction_db.status = action_constants.LIVEACTION_STATUS_PAUSED
child_liveaction_db.action = action
- child_liveaction_db.context = {
- "parent": {
- "execution_id": last_id
- }
- }
+ child_liveaction_db.context = {"parent": {"execution_id": last_id}}
child_liveaction_db = LiveAction.add_or_update(child_liveaction_db)
parent_ex = executions.create_execution_object(child_liveaction_db)
last_id = parent_ex.id
@@ -277,104 +258,116 @@ def _create_nested_executions(self, depth=2):
return (child_liveaction_db, root_liveaction_db)
def test_req_non_workflow_action(self):
- actiondb = self.actiondbs[ACTION['name']]
+ actiondb = self.actiondbs[ACTION["name"]]
req, ex = self._submit_request(action_ref=ACTION_REF)
self.assertIsNotNone(ex)
self.assertEqual(ex.action_is_workflow, False)
self.assertEqual(ex.id, req.id)
- self.assertEqual(ex.action, '.'.join([actiondb.pack, actiondb.name]))
- self.assertEqual(ex.context['user'], req.context['user'])
+ self.assertEqual(ex.action, ".".join([actiondb.pack, actiondb.name]))
+ self.assertEqual(ex.context["user"], req.context["user"])
self.assertDictEqual(ex.parameters, req.parameters)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
self.assertIsNotNone(ex.notify)
# mongoengine DateTimeField stores datetime only up to milliseconds
- self.assertEqual(isotime.format(ex.start_timestamp, usec=False),
- isotime.format(req.start_timestamp, usec=False))
+ self.assertEqual(
+ isotime.format(ex.start_timestamp, usec=False),
+ isotime.format(req.start_timestamp, usec=False),
+ )
def test_req_workflow_action(self):
- actiondb = self.actiondbs[ACTION_WORKFLOW['name']]
+ actiondb = self.actiondbs[ACTION_WORKFLOW["name"]]
req, ex = self._submit_request(action_ref=ACTION_WORKFLOW_REF)
self.assertIsNotNone(ex)
self.assertEqual(ex.action_is_workflow, True)
self.assertEqual(ex.id, req.id)
- self.assertEqual(ex.action, '.'.join([actiondb.pack, actiondb.name]))
- self.assertEqual(ex.context['user'], req.context['user'])
+ self.assertEqual(ex.action, ".".join([actiondb.pack, actiondb.name]))
+ self.assertEqual(ex.context["user"], req.context["user"])
self.assertDictEqual(ex.parameters, req.parameters)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
def test_req_invalid_parameters(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_value': 123}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "arg_default_value": 123}
liveaction = LiveActionDB(action=ACTION_REF, parameters=parameters)
- self.assertRaises(jsonschema.ValidationError, action_service.request, liveaction)
+ self.assertRaises(
+ jsonschema.ValidationError, action_service.request, liveaction
+ )
def test_req_optional_parameter_none_value(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_value': None}
+ parameters = {
+ "hosts": "127.0.0.1",
+ "cmd": "uname -a",
+ "arg_default_value": None,
+ }
req = LiveActionDB(action=ACTION_REF, parameters=parameters)
req, _ = action_service.request(req)
def test_req_optional_parameter_none_value_no_default(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_type': None}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "arg_default_type": None}
req = LiveActionDB(action=ACTION_REF, parameters=parameters)
req, _ = action_service.request(req)
def test_req_override_runner_parameter(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"}
req = LiveActionDB(action=ACTION_OVR_PARAM_REF, parameters=parameters)
req, _ = action_service.request(req)
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': False}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": False}
req = LiveActionDB(action=ACTION_OVR_PARAM_REF, parameters=parameters)
req, _ = action_service.request(req)
def test_req_override_runner_parameter_type_attribute_value_changed(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"}
req = LiveActionDB(action=ACTION_OVR_PARAM_BAD_ATTR_REF, parameters=parameters)
with self.assertRaises(action_exc.InvalidActionParameterException) as ex_ctx:
req, _ = action_service.request(req)
- expected = ('The attribute "type" for the runner parameter "sudo" in '
- 'action "default.my.sudo.invalid.action" cannot be overridden.')
+ expected = (
+ 'The attribute "type" for the runner parameter "sudo" in '
+ 'action "default.my.sudo.invalid.action" cannot be overridden.'
+ )
self.assertEqual(str(ex_ctx.exception), expected)
def test_req_override_runner_parameter_type_attribute_no_value_changed(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'}
- req = LiveActionDB(action=ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF, parameters=parameters)
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"}
+ req = LiveActionDB(
+ action=ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF, parameters=parameters
+ )
req, _ = action_service.request(req)
def test_req_override_runner_parameter_mutable(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"}
req = LiveActionDB(action=ACTION_OVR_PARAM_MUTABLE_REF, parameters=parameters)
req, _ = action_service.request(req)
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': True}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": True}
req = LiveActionDB(action=ACTION_OVR_PARAM_MUTABLE_REF, parameters=parameters)
req, _ = action_service.request(req)
def test_req_override_runner_parameter_immutable(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"}
req = LiveActionDB(action=ACTION_OVR_PARAM_IMMUTABLE_REF, parameters=parameters)
req, _ = action_service.request(req)
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': True}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": True}
req = LiveActionDB(action=ACTION_OVR_PARAM_IMMUTABLE_REF, parameters=parameters)
self.assertRaises(ValueError, action_service.request, req)
def test_req_nonexistent_action(self):
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'}
- action_ref = ResourceReference(name='i.action', pack='default').ref
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"}
+ action_ref = ResourceReference(name="i.action", pack="default").ref
ex = LiveActionDB(action=action_ref, parameters=parameters)
self.assertRaises(ValueError, action_service.request, ex)
def test_req_disabled_action(self):
- actiondb = self.actiondbs[ACTION['name']]
+ actiondb = self.actiondbs[ACTION["name"]]
actiondb.enabled = False
Action.add_or_update(actiondb)
try:
- parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'}
+ parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"}
ex = LiveActionDB(action=ACTION_REF, parameters=parameters)
self.assertRaises(ValueError, action_service.request, ex)
except Exception as e:
@@ -390,7 +383,9 @@ def test_req_cancellation(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to RUNNING.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_RUNNING, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
@@ -405,7 +400,9 @@ def test_req_cancellation_uncancelable_state(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to FAILED.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_FAILED, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_FAILED, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_FAILED)
@@ -429,20 +426,20 @@ def test_req_pause_unsupported(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to RUNNING.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_RUNNING, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request pause.
self.assertRaises(
- runner_exc.InvalidActionRunnerOperationError,
- self._submit_pause,
- ex
+ runner_exc.InvalidActionRunnerOperationError, self._submit_pause, ex
)
def test_req_pause(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"])
try:
req, ex = self._submit_request()
@@ -451,7 +448,9 @@ def test_req_pause(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to RUNNING.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_RUNNING, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
@@ -459,11 +458,11 @@ def test_req_pause(self):
ex = self._submit_pause(ex)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING)
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"])
def test_req_pause_not_running(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"])
try:
req, ex = self._submit_request()
@@ -473,16 +472,14 @@ def test_req_pause_not_running(self):
# Request pause.
self.assertRaises(
- runner_exc.UnexpectedActionExecutionStatusError,
- self._submit_pause,
- ex
+ runner_exc.UnexpectedActionExecutionStatusError, self._submit_pause, ex
)
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"])
def test_req_pause_already_pausing(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"])
try:
req, ex = self._submit_request()
@@ -491,7 +488,9 @@ def test_req_pause_already_pausing(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to RUNNING.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_RUNNING, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
@@ -500,12 +499,14 @@ def test_req_pause_already_pausing(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING)
# Request pause again.
- with mock.patch.object(action_service, 'update_status', return_value=None) as mocked:
+ with mock.patch.object(
+ action_service, "update_status", return_value=None
+ ) as mocked:
ex = self._submit_pause(ex)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING)
mocked.assert_not_called()
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"])
def test_req_resume_unsupported(self):
req, ex = self._submit_request()
@@ -514,20 +515,20 @@ def test_req_resume_unsupported(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to RUNNING.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_RUNNING, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request resume.
self.assertRaises(
- runner_exc.InvalidActionRunnerOperationError,
- self._submit_resume,
- ex
+ runner_exc.InvalidActionRunnerOperationError, self._submit_resume, ex
)
def test_req_resume(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"])
try:
req, ex = self._submit_request()
@@ -536,7 +537,9 @@ def test_req_resume(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to RUNNING.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_RUNNING, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
@@ -545,7 +548,9 @@ def test_req_resume(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING)
# Update ex status to PAUSED.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_PAUSED, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_PAUSED, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSED)
@@ -553,11 +558,11 @@ def test_req_resume(self):
ex = self._submit_resume(ex)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RESUMING)
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"])
def test_req_resume_not_paused(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"])
try:
req, ex = self._submit_request()
@@ -566,7 +571,9 @@ def test_req_resume_not_paused(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to RUNNING.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_RUNNING, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
@@ -576,16 +583,14 @@ def test_req_resume_not_paused(self):
# Request resume.
self.assertRaises(
- runner_exc.UnexpectedActionExecutionStatusError,
- self._submit_resume,
- ex
+ runner_exc.UnexpectedActionExecutionStatusError, self._submit_resume, ex
)
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"])
def test_req_resume_already_running(self):
# Add the runner type to the list of runners that support pause and resume.
- action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"])
try:
req, ex = self._submit_request()
@@ -594,25 +599,28 @@ def test_req_resume_already_running(self):
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED)
# Update ex status to RUNNING.
- action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False)
+ action_service.update_status(
+ ex, action_constants.LIVEACTION_STATUS_RUNNING, False
+ )
ex = action_db.get_liveaction_by_id(ex.id)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Request resume.
- with mock.patch.object(action_service, 'update_status', return_value=None) as mocked:
+ with mock.patch.object(
+ action_service, "update_status", return_value=None
+ ) as mocked:
ex = self._submit_resume(ex)
self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING)
mocked.assert_not_called()
finally:
- action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type'])
+ action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"])
def test_root_liveaction(self):
- """Test that get_root_liveaction correctly retrieves the root liveaction
- """
+ """Test that get_root_liveaction correctly retrieves the root liveaction"""
# Test a variety of depths
for i in range(1, 7):
child, expected_root = self._create_nested_executions(depth=i)
actual_root = action_service.get_root_liveaction(child)
- self.assertEqual(expected_root['id'], actual_root['id'])
+ self.assertEqual(expected_root["id"], actual_root["id"])
diff --git a/st2common/tests/unit/services/test_keyvalue.py b/st2common/tests/unit/services/test_keyvalue.py
index a11a3bb11b..bd080719bb 100644
--- a/st2common/tests/unit/services/test_keyvalue.py
+++ b/st2common/tests/unit/services/test_keyvalue.py
@@ -22,17 +22,22 @@
class KeyValueServicesTest(unittest2.TestCase):
-
def test_get_key_reference_system_scope(self):
- ref = get_key_reference(scope=SYSTEM_SCOPE, name='foo')
- self.assertEqual(ref, 'foo')
+ ref = get_key_reference(scope=SYSTEM_SCOPE, name="foo")
+ self.assertEqual(ref, "foo")
def test_get_key_reference_user_scope(self):
- ref = get_key_reference(scope=USER_SCOPE, name='foo', user='stanley')
- self.assertEqual(ref, 'stanley:foo')
- self.assertRaises(InvalidUserException, get_key_reference,
- scope=USER_SCOPE, name='foo', user='')
+ ref = get_key_reference(scope=USER_SCOPE, name="foo", user="stanley")
+ self.assertEqual(ref, "stanley:foo")
+ self.assertRaises(
+ InvalidUserException,
+ get_key_reference,
+ scope=USER_SCOPE,
+ name="foo",
+ user="",
+ )
def test_get_key_reference_invalid_scope_raises_exception(self):
- self.assertRaises(InvalidScopeException, get_key_reference,
- scope='sketchy', name='foo')
+ self.assertRaises(
+ InvalidScopeException, get_key_reference, scope="sketchy", name="foo"
+ )
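Editor's note: the behaviour these assertions pin down is a simple prefixing scheme,
where system-scoped keys are stored bare and user-scoped keys are prefixed with
"<user>:". A plausible reimplementation under that assumption; the exception types and
scope constants are stand-ins, not st2common's.

    SYSTEM_SCOPE = "system"
    USER_SCOPE = "user"

    def get_key_reference(scope, name, user=None):
        if scope == SYSTEM_SCOPE:
            return name
        if scope == USER_SCOPE:
            if not user:
                raise ValueError("user is required for user-scoped keys")
            return "%s:%s" % (user, name)
        raise ValueError("invalid scope: %s" % scope)

    assert get_key_reference(scope=SYSTEM_SCOPE, name="foo") == "foo"
    assert get_key_reference(scope=USER_SCOPE, name="foo", user="stanley") == "stanley:foo"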
diff --git a/st2common/tests/unit/services/test_policy.py b/st2common/tests/unit/services/test_policy.py
index 69fb0624e6..128ce1defe 100644
--- a/st2common/tests/unit/services/test_policy.py
+++ b/st2common/tests/unit/services/test_policy.py
@@ -16,6 +16,7 @@
from __future__ import absolute_import
import st2tests.config as tests_config
+
tests_config.parse_args()
import st2common
@@ -32,23 +33,22 @@
from st2tests import fixturesloader as fixtures
-PACK = 'generic'
+PACK = "generic"
TEST_FIXTURES = {
- 'actions': [
- 'action1.yaml', # wolfpack.action-1
- 'action2.yaml', # wolfpack.action-2
- 'local.yaml' # core.local
+ "actions": [
+ "action1.yaml", # wolfpack.action-1
+ "action2.yaml", # wolfpack.action-2
+ "local.yaml", # core.local
+ ],
+ "policies": [
+ "policy_2.yaml", # mock policy on wolfpack.action-1
+ "policy_5.yaml", # concurrency policy on wolfpack.action-2
],
- 'policies': [
- 'policy_2.yaml', # mock policy on wolfpack.action-1
- 'policy_5.yaml' # concurrency policy on wolfpack.action-2
- ]
}
class PolicyServiceTestCase(st2tests.DbTestCase):
-
@classmethod
def setUpClass(cls):
super(PolicyServiceTestCase, cls).setUpClass()
@@ -60,28 +60,39 @@ def setUpClass(cls):
policies_registrar.register_policy_types(st2common)
loader = fixtures.FixturesLoader()
- loader.save_fixtures_to_db(fixtures_pack=PACK,
- fixtures_dict=TEST_FIXTURES)
+ loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
def setUp(self):
super(PolicyServiceTestCase, self).setUp()
- params = {'action': 'wolfpack.action-1', 'parameters': {'actionstr': 'foo-last'}}
+ params = {
+ "action": "wolfpack.action-1",
+ "parameters": {"actionstr": "foo-last"},
+ }
self.lv_ac_db_1 = action_db_models.LiveActionDB(**params)
self.lv_ac_db_1, _ = action_service.request(self.lv_ac_db_1)
- params = {'action': 'wolfpack.action-2', 'parameters': {'actionstr': 'foo-last'}}
+ params = {
+ "action": "wolfpack.action-2",
+ "parameters": {"actionstr": "foo-last"},
+ }
self.lv_ac_db_2 = action_db_models.LiveActionDB(**params)
self.lv_ac_db_2, _ = action_service.request(self.lv_ac_db_2)
- params = {'action': 'core.local', 'parameters': {'cmd': 'date'}}
+ params = {"action": "core.local", "parameters": {"cmd": "date"}}
self.lv_ac_db_3 = action_db_models.LiveActionDB(**params)
self.lv_ac_db_3, _ = action_service.request(self.lv_ac_db_3)
def tearDown(self):
- action_service.update_status(self.lv_ac_db_1, action_constants.LIVEACTION_STATUS_CANCELED)
- action_service.update_status(self.lv_ac_db_2, action_constants.LIVEACTION_STATUS_CANCELED)
- action_service.update_status(self.lv_ac_db_3, action_constants.LIVEACTION_STATUS_CANCELED)
+ action_service.update_status(
+ self.lv_ac_db_1, action_constants.LIVEACTION_STATUS_CANCELED
+ )
+ action_service.update_status(
+ self.lv_ac_db_2, action_constants.LIVEACTION_STATUS_CANCELED
+ )
+ action_service.update_status(
+ self.lv_ac_db_3, action_constants.LIVEACTION_STATUS_CANCELED
+ )
def test_action_has_policies(self):
self.assertTrue(policy_service.has_policies(self.lv_ac_db_1))
@@ -93,7 +104,7 @@ def test_action_has_specific_policies(self):
self.assertTrue(
policy_service.has_policies(
self.lv_ac_db_2,
- policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK
+ policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK,
)
)
@@ -101,6 +112,6 @@ def test_action_does_not_have_specific_policies(self):
self.assertFalse(
policy_service.has_policies(
self.lv_ac_db_1,
- policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK
+ policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK,
)
)
diff --git a/st2common/tests/unit/services/test_synchronization.py b/st2common/tests/unit/services/test_synchronization.py
index 86cf36042f..991e6b9036 100644
--- a/st2common/tests/unit/services/test_synchronization.py
+++ b/st2common/tests/unit/services/test_synchronization.py
@@ -39,13 +39,15 @@ def tearDownClass(cls):
super(SynchronizationTest, cls).tearDownClass()
def test_service_configured(self):
- cfg.CONF.set_override(name='url', override='kazoo://127.0.0.1:2181', group='coordination')
+ cfg.CONF.set_override(
+ name="url", override="kazoo://127.0.0.1:2181", group="coordination"
+ )
self.assertTrue(coordination.configured())
- cfg.CONF.set_override(name='url', override='file:///tmp', group='coordination')
+ cfg.CONF.set_override(name="url", override="file:///tmp", group="coordination")
self.assertFalse(coordination.configured())
- cfg.CONF.set_override(name='url', override='zake://', group='coordination')
+ cfg.CONF.set_override(name="url", override="zake://", group="coordination")
self.assertFalse(coordination.configured())
def test_lock(self):
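Editor's note: the three overrides in test_service_configured above encode the rule that
only a real distributed backend counts as "configured"; file:// and zake:// are
in-process test doubles. A sketch of that check under the same assumption (illustrative
logic, not st2common's coordination service):

    from urllib.parse import urlparse

    def configured(url):
        if not url:
            return False
        scheme = urlparse(url).scheme
        # file:// and zake:// are single-process fakes, not real coordination.
        return scheme not in ("file", "zake")

    assert configured("kazoo://127.0.0.1:2181") is True
    assert configured("file:///tmp") is False
    assert configured("zake://") is False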
diff --git a/st2common/tests/unit/services/test_trace.py b/st2common/tests/unit/services/test_trace.py
index 06c9260586..807dc4251d 100644
--- a/st2common/tests/unit/services/test_trace.py
+++ b/st2common/tests/unit/services/test_trace.py
@@ -30,33 +30,37 @@
from st2tests import DbTestCase
-FIXTURES_PACK = 'traces'
-
-TEST_MODELS = OrderedDict((
- ('executions', [
- 'traceable_execution.yaml',
- 'rule_fired_execution.yaml',
- 'execution_with_parent.yaml'
- ]),
- ('liveactions', [
- 'traceable_liveaction.yaml',
- 'liveaction_with_parent.yaml'
- ]),
- ('traces', [
- 'trace_empty.yaml',
- 'trace_multiple_components.yaml',
- 'trace_one_each.yaml',
- 'trace_one_each_dup.yaml',
- 'trace_execution.yaml'
- ]),
- ('triggers', ['trigger1.yaml']),
- ('triggerinstances', [
- 'action_trigger.yaml',
- 'notify_trigger.yaml',
- 'non_internal_trigger.yaml'
- ]),
- ('rules', ['rule1.yaml']),
-))
+FIXTURES_PACK = "traces"
+
+TEST_MODELS = OrderedDict(
+ (
+ (
+ "executions",
+ [
+ "traceable_execution.yaml",
+ "rule_fired_execution.yaml",
+ "execution_with_parent.yaml",
+ ],
+ ),
+ ("liveactions", ["traceable_liveaction.yaml", "liveaction_with_parent.yaml"]),
+ (
+ "traces",
+ [
+ "trace_empty.yaml",
+ "trace_multiple_components.yaml",
+ "trace_one_each.yaml",
+ "trace_one_each_dup.yaml",
+ "trace_execution.yaml",
+ ],
+ ),
+ ("triggers", ["trigger1.yaml"]),
+ (
+ "triggerinstances",
+ ["action_trigger.yaml", "notify_trigger.yaml", "non_internal_trigger.yaml"],
+ ),
+ ("rules", ["rule1.yaml"]),
+ )
+)
class DummyComponent(object):
@@ -78,139 +82,184 @@ class TestTraceService(DbTestCase):
@classmethod
def setUpClass(cls):
super(TestTraceService, cls).setUpClass()
- cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
- cls.trace1 = cls.models['traces']['trace_multiple_components.yaml']
- cls.trace2 = cls.models['traces']['trace_one_each.yaml']
- cls.trace3 = cls.models['traces']['trace_one_each_dup.yaml']
- cls.trace_empty = cls.models['traces']['trace_empty.yaml']
- cls.trace_execution = cls.models['traces']['trace_execution.yaml']
-
- cls.action_trigger = cls.models['triggerinstances']['action_trigger.yaml']
- cls.notify_trigger = cls.models['triggerinstances']['notify_trigger.yaml']
- cls.non_internal_trigger = cls.models['triggerinstances']['non_internal_trigger.yaml']
-
- cls.rule1 = cls.models['rules']['rule1.yaml']
-
- cls.traceable_liveaction = cls.models['liveactions']['traceable_liveaction.yaml']
- cls.liveaction_with_parent = cls.models['liveactions']['liveaction_with_parent.yaml']
- cls.traceable_execution = cls.models['executions']['traceable_execution.yaml']
- cls.rule_fired_execution = cls.models['executions']['rule_fired_execution.yaml']
- cls.execution_with_parent = cls.models['executions']['execution_with_parent.yaml']
+ cls.models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+ )
+ cls.trace1 = cls.models["traces"]["trace_multiple_components.yaml"]
+ cls.trace2 = cls.models["traces"]["trace_one_each.yaml"]
+ cls.trace3 = cls.models["traces"]["trace_one_each_dup.yaml"]
+ cls.trace_empty = cls.models["traces"]["trace_empty.yaml"]
+ cls.trace_execution = cls.models["traces"]["trace_execution.yaml"]
+
+ cls.action_trigger = cls.models["triggerinstances"]["action_trigger.yaml"]
+ cls.notify_trigger = cls.models["triggerinstances"]["notify_trigger.yaml"]
+ cls.non_internal_trigger = cls.models["triggerinstances"][
+ "non_internal_trigger.yaml"
+ ]
+
+ cls.rule1 = cls.models["rules"]["rule1.yaml"]
+
+ cls.traceable_liveaction = cls.models["liveactions"][
+ "traceable_liveaction.yaml"
+ ]
+ cls.liveaction_with_parent = cls.models["liveactions"][
+ "liveaction_with_parent.yaml"
+ ]
+ cls.traceable_execution = cls.models["executions"]["traceable_execution.yaml"]
+ cls.rule_fired_execution = cls.models["executions"]["rule_fired_execution.yaml"]
+ cls.execution_with_parent = cls.models["executions"][
+ "execution_with_parent.yaml"
+ ]
def test_get_trace_db_by_action_execution(self):
- action_execution = DummyComponent(id_=self.trace1.action_executions[0].object_id)
- trace_db = trace_service.get_trace_db_by_action_execution(action_execution=action_execution)
- self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.')
+ action_execution = DummyComponent(
+ id_=self.trace1.action_executions[0].object_id
+ )
+ trace_db = trace_service.get_trace_db_by_action_execution(
+ action_execution=action_execution
+ )
+ self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.")
def test_get_trace_db_by_action_execution_fail(self):
- action_execution = DummyComponent(id_=self.trace2.action_executions[0].object_id)
- self.assertRaises(UniqueTraceNotFoundException,
- trace_service.get_trace_db_by_action_execution,
- **{'action_execution': action_execution})
+ action_execution = DummyComponent(
+ id_=self.trace2.action_executions[0].object_id
+ )
+ self.assertRaises(
+ UniqueTraceNotFoundException,
+ trace_service.get_trace_db_by_action_execution,
+ **{"action_execution": action_execution},
+ )
def test_get_trace_db_by_rule(self):
rule = DummyComponent(id_=self.trace1.rules[0].object_id)
trace_dbs = trace_service.get_trace_db_by_rule(rule=rule)
- self.assertEqual(len(trace_dbs), 1, 'Expected 1 trace_db.')
- self.assertEqual(trace_dbs[0].id, self.trace1.id, 'Incorrect trace_db returned.')
+ self.assertEqual(len(trace_dbs), 1, "Expected 1 trace_db.")
+ self.assertEqual(
+ trace_dbs[0].id, self.trace1.id, "Incorrect trace_db returned."
+ )
def test_get_multiple_trace_db_by_rule(self):
rule = DummyComponent(id_=self.trace2.rules[0].object_id)
trace_dbs = trace_service.get_trace_db_by_rule(rule=rule)
- self.assertEqual(len(trace_dbs), 2, 'Expected 2 trace_db.')
+ self.assertEqual(len(trace_dbs), 2, "Expected 2 trace_db.")
result = [trace_db.id for trace_db in trace_dbs]
- self.assertEqual(result, [self.trace2.id, self.trace3.id], 'Incorrect trace_dbs returned.')
+ self.assertEqual(
+ result, [self.trace2.id, self.trace3.id], "Incorrect trace_dbs returned."
+ )
def test_get_trace_db_by_trigger_instance(self):
- trigger_instance = DummyComponent(id_=self.trace1.trigger_instances[0].object_id)
- trace_db = trace_service.get_trace_db_by_trigger_instance(trigger_instance=trigger_instance)
- self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.')
+ trigger_instance = DummyComponent(
+ id_=self.trace1.trigger_instances[0].object_id
+ )
+ trace_db = trace_service.get_trace_db_by_trigger_instance(
+ trigger_instance=trigger_instance
+ )
+ self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.")
def test_get_trace_db_by_trigger_instance_fail(self):
- trigger_instance = DummyComponent(id_=self.trace2.trigger_instances[0].object_id)
- self.assertRaises(UniqueTraceNotFoundException,
- trace_service.get_trace_db_by_trigger_instance,
- **{'trigger_instance': trigger_instance})
+ trigger_instance = DummyComponent(
+ id_=self.trace2.trigger_instances[0].object_id
+ )
+ self.assertRaises(
+ UniqueTraceNotFoundException,
+ trace_service.get_trace_db_by_trigger_instance,
+ **{"trigger_instance": trigger_instance},
+ )
def test_get_trace_by_dict(self):
- trace_context = {'id_': str(self.trace1.id)}
+ trace_context = {"id_": str(self.trace1.id)}
trace_db = trace_service.get_trace(trace_context)
- self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.')
+ self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.")
- trace_context = {'id_': str(bson.ObjectId())}
- self.assertRaises(StackStormDBObjectNotFoundError, trace_service.get_trace, trace_context)
+ trace_context = {"id_": str(bson.ObjectId())}
+ self.assertRaises(
+ StackStormDBObjectNotFoundError, trace_service.get_trace, trace_context
+ )
- trace_context = {'trace_tag': self.trace1.trace_tag}
+ trace_context = {"trace_tag": self.trace1.trace_tag}
trace_db = trace_service.get_trace(trace_context)
- self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.')
+ self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.")
def test_get_trace_by_trace_context(self):
- trace_context = TraceContext(**{'id_': str(self.trace1.id)})
+ trace_context = TraceContext(**{"id_": str(self.trace1.id)})
trace_db = trace_service.get_trace(trace_context)
- self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.')
+ self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.")
- trace_context = TraceContext(**{'trace_tag': self.trace1.trace_tag})
+ trace_context = TraceContext(**{"trace_tag": self.trace1.trace_tag})
trace_db = trace_service.get_trace(trace_context)
- self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.')
+ self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.")
def test_get_trace_ignore_trace_tag(self):
- trace_context = {'trace_tag': self.trace1.trace_tag}
+ trace_context = {"trace_tag": self.trace1.trace_tag}
trace_db = trace_service.get_trace(trace_context)
- self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.')
+ self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.")
- trace_context = {'trace_tag': self.trace1.trace_tag}
+ trace_context = {"trace_tag": self.trace1.trace_tag}
trace_db = trace_service.get_trace(trace_context, ignore_trace_tag=True)
- self.assertEqual(trace_db, None, 'Should be None.')
+ self.assertEqual(trace_db, None, "Should be None.")
def test_get_trace_fail_empty_context(self):
trace_context = {}
self.assertRaises(ValueError, trace_service.get_trace, trace_context)
def test_get_trace_fail_multi_match(self):
- trace_context = {'trace_tag': self.trace2.trace_tag}
- self.assertRaises(UniqueTraceNotFoundException, trace_service.get_trace, trace_context)
+ trace_context = {"trace_tag": self.trace2.trace_tag}
+ self.assertRaises(
+ UniqueTraceNotFoundException, trace_service.get_trace, trace_context
+ )
def test_get_trace_db_by_live_action_valid_id_context(self):
traceable_liveaction = copy.copy(self.traceable_liveaction)
- traceable_liveaction.context['trace_context'] = {'id_': str(self.trace_execution.id)}
- created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction)
+ traceable_liveaction.context["trace_context"] = {
+ "id_": str(self.trace_execution.id)
+ }
+ created, trace_db = trace_service.get_trace_db_by_live_action(
+ traceable_liveaction
+ )
self.assertFalse(created)
self.assertEqual(trace_db.id, self.trace_execution.id)
def test_get_trace_db_by_live_action_trace_tag_context(self):
traceable_liveaction = copy.copy(self.traceable_liveaction)
- traceable_liveaction.context['trace_context'] = {
- 'trace_tag': str(self.trace_execution.trace_tag)
+ traceable_liveaction.context["trace_context"] = {
+ "trace_tag": str(self.trace_execution.trace_tag)
}
- created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction)
+ created, trace_db = trace_service.get_trace_db_by_live_action(
+ traceable_liveaction
+ )
self.assertTrue(created)
- self.assertEqual(trace_db.id, None, 'Expected to be None')
+ self.assertEqual(trace_db.id, None, "Expected to be None")
self.assertEqual(trace_db.trace_tag, str(self.trace_execution.trace_tag))
def test_get_trace_db_by_live_action_parent(self):
traceable_liveaction = copy.copy(self.traceable_liveaction)
- traceable_liveaction.context['parent'] = {
- 'execution_id': str(self.trace1.action_executions[0].object_id)
+ traceable_liveaction.context["parent"] = {
+ "execution_id": str(self.trace1.action_executions[0].object_id)
}
- created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction)
+ created, trace_db = trace_service.get_trace_db_by_live_action(
+ traceable_liveaction
+ )
self.assertFalse(created)
self.assertEqual(trace_db.id, self.trace1.id)
def test_get_trace_db_by_live_action_parent_fail(self):
traceable_liveaction = copy.copy(self.traceable_liveaction)
- traceable_liveaction.context['parent'] = {
- 'execution_id': str(bson.ObjectId())
- }
- self.assertRaises(StackStormDBObjectNotFoundError,
- trace_service.get_trace_db_by_live_action,
- traceable_liveaction)
+ traceable_liveaction.context["parent"] = {"execution_id": str(bson.ObjectId())}
+ self.assertRaises(
+ StackStormDBObjectNotFoundError,
+ trace_service.get_trace_db_by_live_action,
+ traceable_liveaction,
+ )
def test_get_trace_db_by_live_action_from_execution(self):
traceable_liveaction = copy.copy(self.traceable_liveaction)
        # The fixture's id value in the liveaction is not persisted in the DB.
- traceable_liveaction.id = bson.ObjectId(self.traceable_execution.liveaction['id'])
- created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction)
+ traceable_liveaction.id = bson.ObjectId(
+ self.traceable_execution.liveaction["id"]
+ )
+ created, trace_db = trace_service.get_trace_db_by_live_action(
+ traceable_liveaction
+ )
self.assertFalse(created)
self.assertEqual(trace_db.id, self.trace_execution.id)
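
Every hunk above exercises the same contract: get_trace_db_by_live_action inspects the live action's context and answers with a (created, trace_db) pair. A minimal standalone sketch of that dispatch order, with plain dicts standing in for the Mongo-backed models (the helper shape and lookup tables are illustrative assumptions, not the st2common implementation):

def get_trace_db_by_live_action(liveaction, traces_by_id, traces_by_execution):
    # Hypothetical stand-in mirroring the behavior the tests assert.
    context = liveaction.get("context", {})
    trace_context = context.get("trace_context")
    if trace_context and "id_" in trace_context:
        # Explicit trace id: reuse the existing trace; nothing is created.
        return False, traces_by_id[trace_context["id_"]]
    if trace_context and "trace_tag" in trace_context:
        # A bare tag yields a fresh, not-yet-saved trace (id is None).
        return True, {"id": None, "trace_tag": trace_context["trace_tag"]}
    parent = context.get("parent")
    if parent:
        # Child executions attach to the parent execution's trace.
        return False, traces_by_execution[parent["execution_id"]]
    # No usable context at all: start a brand new trace.
    return True, {"id": None, "trace_tag": "auto-generated"}

trace = {"id": "t1", "trace_tag": "demo"}
created, db = get_trace_db_by_live_action(
    {"context": {"trace_context": {"id_": "t1"}}}, {"t1": trace}, {}
)
assert (created, db) == (False, trace)
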
@@ -218,76 +267,119 @@ def test_get_trace_db_by_live_action_new_trace(self):
traceable_liveaction = copy.copy(self.traceable_liveaction)
# a liveaction without any associated ActionExecution
traceable_liveaction.id = bson.ObjectId()
- created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction)
+ created, trace_db = trace_service.get_trace_db_by_live_action(
+ traceable_liveaction
+ )
self.assertTrue(created)
- self.assertEqual(trace_db.id, None, 'Should be None.')
+ self.assertEqual(trace_db.id, None, "Should be None.")
def test_add_or_update_given_trace_context(self):
- trace_context = {'id_': str(self.trace_empty.id)}
- action_execution_id = 'action_execution_1'
- rule_id = 'rule_1'
- trigger_instance_id = 'trigger_instance_1'
+ trace_context = {"id_": str(self.trace_empty.id)}
+ action_execution_id = "action_execution_1"
+ rule_id = "rule_1"
+ trigger_instance_id = "trigger_instance_1"
trace_service.add_or_update_given_trace_context(
trace_context,
action_executions=[action_execution_id],
rules=[rule_id],
- trigger_instances=[trigger_instance_id])
+ trigger_instances=[trigger_instance_id],
+ )
retrieved_trace_db = Trace.get_by_id(self.trace_empty.id)
- self.assertEqual(len(retrieved_trace_db.action_executions), 1,
- 'Expected updated action_executions.')
- self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id,
- 'Expected updated action_executions.')
-
- self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.')
- self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.')
-
- self.assertEqual(len(retrieved_trace_db.trigger_instances), 1,
- 'Expected updated trigger_instances.')
- self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id,
- 'Expected updated trigger_instances.')
+ self.assertEqual(
+ len(retrieved_trace_db.action_executions),
+ 1,
+ "Expected updated action_executions.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.action_executions[0].object_id,
+ action_execution_id,
+ "Expected updated action_executions.",
+ )
+
+ self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.")
+ self.assertEqual(
+ retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules."
+ )
+
+ self.assertEqual(
+ len(retrieved_trace_db.trigger_instances),
+ 1,
+ "Expected updated trigger_instances.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.trigger_instances[0].object_id,
+ trigger_instance_id,
+ "Expected updated trigger_instances.",
+ )
Trace.delete(retrieved_trace_db)
Trace.add_or_update(self.trace_empty)
def test_add_or_update_given_trace_db(self):
- action_execution_id = 'action_execution_1'
- rule_id = 'rule_1'
- trigger_instance_id = 'trigger_instance_1'
+ action_execution_id = "action_execution_1"
+ rule_id = "rule_1"
+ trigger_instance_id = "trigger_instance_1"
to_save = copy.copy(self.trace_empty)
to_save.id = None
saved = trace_service.add_or_update_given_trace_db(
to_save,
action_executions=[action_execution_id],
rules=[rule_id],
- trigger_instances=[trigger_instance_id])
+ trigger_instances=[trigger_instance_id],
+ )
retrieved_trace_db = Trace.get_by_id(saved.id)
- self.assertEqual(len(retrieved_trace_db.action_executions), 1,
- 'Expected updated action_executions.')
- self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id,
- 'Expected updated action_executions.')
-
- self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.')
- self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.')
-
- self.assertEqual(len(retrieved_trace_db.trigger_instances), 1,
- 'Expected updated trigger_instances.')
- self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id,
- 'Expected updated trigger_instances.')
+ self.assertEqual(
+ len(retrieved_trace_db.action_executions),
+ 1,
+ "Expected updated action_executions.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.action_executions[0].object_id,
+ action_execution_id,
+ "Expected updated action_executions.",
+ )
+
+ self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.")
+ self.assertEqual(
+ retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules."
+ )
+
+ self.assertEqual(
+ len(retrieved_trace_db.trigger_instances),
+ 1,
+ "Expected updated trigger_instances.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.trigger_instances[0].object_id,
+ trigger_instance_id,
+ "Expected updated trigger_instances.",
+ )
        # Now add more TraceComponents and validate that they are added properly.
saved = trace_service.add_or_update_given_trace_db(
retrieved_trace_db,
action_executions=[str(bson.ObjectId()), str(bson.ObjectId())],
rules=[str(bson.ObjectId())],
- trigger_instances=[str(bson.ObjectId()), str(bson.ObjectId()), str(bson.ObjectId())])
+ trigger_instances=[
+ str(bson.ObjectId()),
+ str(bson.ObjectId()),
+ str(bson.ObjectId()),
+ ],
+ )
retrieved_trace_db = Trace.get_by_id(saved.id)
- self.assertEqual(len(retrieved_trace_db.action_executions), 3,
- 'Expected updated action_executions.')
- self.assertEqual(len(retrieved_trace_db.rules), 2, 'Expected updated rules.')
- self.assertEqual(len(retrieved_trace_db.trigger_instances), 4,
- 'Expected updated trigger_instances.')
+ self.assertEqual(
+ len(retrieved_trace_db.action_executions),
+ 3,
+ "Expected updated action_executions.",
+ )
+ self.assertEqual(len(retrieved_trace_db.rules), 2, "Expected updated rules.")
+ self.assertEqual(
+ len(retrieved_trace_db.trigger_instances),
+ 4,
+ "Expected updated trigger_instances.",
+ )
Trace.delete(retrieved_trace_db)
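
Most of the churn in this hunk is mechanical: once a call no longer fits black's 88-column limit, black puts every argument on its own line and adds a "magic" trailing comma, so later argument additions become one-line diffs. A small runnable illustration of the two equivalent spellings, using plain unittest rather than anything st2-specific:

import unittest


class WrapDemo(unittest.TestCase):
    def test_equivalent_call_layouts(self):
        retrieved = {"action_executions": ["action_execution_1"]}
        # Pre-black style: continuation lines aligned under the open paren.
        self.assertEqual(len(retrieved["action_executions"]), 1,
                         "Expected updated action_executions.")
        # Black style: one argument per line plus a trailing comma.
        self.assertEqual(
            len(retrieved["action_executions"]),
            1,
            "Expected updated action_executions.",
        )


if __name__ == "__main__":
    unittest.main()
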
@@ -295,179 +387,238 @@ def test_add_or_update_given_trace_db_fail(self):
self.assertRaises(ValueError, trace_service.add_or_update_given_trace_db, None)
def test_add_or_update_given_trace_context_new(self):
- trace_context = {'trace_tag': 'awesome_test_trace'}
- action_execution_id = 'action_execution_1'
- rule_id = 'rule_1'
- trigger_instance_id = 'trigger_instance_1'
+ trace_context = {"trace_tag": "awesome_test_trace"}
+ action_execution_id = "action_execution_1"
+ rule_id = "rule_1"
+ trigger_instance_id = "trigger_instance_1"
pre_add_or_update_traces = len(Trace.get_all())
trace_db = trace_service.add_or_update_given_trace_context(
trace_context,
action_executions=[action_execution_id],
rules=[rule_id],
- trigger_instances=[trigger_instance_id])
+ trigger_instances=[trigger_instance_id],
+ )
post_add_or_update_traces = len(Trace.get_all())
- self.assertTrue(post_add_or_update_traces > pre_add_or_update_traces,
- 'Expected new Trace to be created.')
+ self.assertTrue(
+ post_add_or_update_traces > pre_add_or_update_traces,
+ "Expected new Trace to be created.",
+ )
retrieved_trace_db = Trace.get_by_id(trace_db.id)
- self.assertEqual(len(retrieved_trace_db.action_executions), 1,
- 'Expected updated action_executions.')
- self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id,
- 'Expected updated action_executions.')
-
- self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.')
- self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.')
-
- self.assertEqual(len(retrieved_trace_db.trigger_instances), 1,
- 'Expected updated trigger_instances.')
- self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id,
- 'Expected updated trigger_instances.')
+ self.assertEqual(
+ len(retrieved_trace_db.action_executions),
+ 1,
+ "Expected updated action_executions.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.action_executions[0].object_id,
+ action_execution_id,
+ "Expected updated action_executions.",
+ )
+
+ self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.")
+ self.assertEqual(
+ retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules."
+ )
+
+ self.assertEqual(
+ len(retrieved_trace_db.trigger_instances),
+ 1,
+ "Expected updated trigger_instances.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.trigger_instances[0].object_id,
+ trigger_instance_id,
+ "Expected updated trigger_instances.",
+ )
Trace.delete(retrieved_trace_db)
def test_add_or_update_given_trace_context_new_with_causals(self):
- trace_context = {'trace_tag': 'causal_test_trace'}
- action_execution_id = 'action_execution_1'
- rule_id = 'rule_1'
- trigger_instance_id = 'trigger_instance_1'
+ trace_context = {"trace_tag": "causal_test_trace"}
+ action_execution_id = "action_execution_1"
+ rule_id = "rule_1"
+ trigger_instance_id = "trigger_instance_1"
pre_add_or_update_traces = len(Trace.get_all())
trace_db = trace_service.add_or_update_given_trace_context(
trace_context,
- action_executions=[{'id': action_execution_id,
- 'caused_by': {'id': '%s:%s' % (rule_id, trigger_instance_id),
- 'type': 'rule'}}],
- rules=[{'id': rule_id,
- 'caused_by': {'id': trigger_instance_id, 'type': 'trigger-instance'}}],
- trigger_instances=[trigger_instance_id])
+ action_executions=[
+ {
+ "id": action_execution_id,
+ "caused_by": {
+ "id": "%s:%s" % (rule_id, trigger_instance_id),
+ "type": "rule",
+ },
+ }
+ ],
+ rules=[
+ {
+ "id": rule_id,
+ "caused_by": {
+ "id": trigger_instance_id,
+ "type": "trigger-instance",
+ },
+ }
+ ],
+ trigger_instances=[trigger_instance_id],
+ )
post_add_or_update_traces = len(Trace.get_all())
- self.assertTrue(post_add_or_update_traces > pre_add_or_update_traces,
- 'Expected new Trace to be created.')
+ self.assertTrue(
+ post_add_or_update_traces > pre_add_or_update_traces,
+ "Expected new Trace to be created.",
+ )
retrieved_trace_db = Trace.get_by_id(trace_db.id)
- self.assertEqual(len(retrieved_trace_db.action_executions), 1,
- 'Expected updated action_executions.')
- self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id,
- 'Expected updated action_executions.')
- self.assertEqual(retrieved_trace_db.action_executions[0].caused_by,
- {'id': '%s:%s' % (rule_id, trigger_instance_id),
- 'type': 'rule'},
- 'Expected updated action_executions.')
-
- self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.')
- self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.')
- self.assertEqual(retrieved_trace_db.rules[0].caused_by,
- {'id': trigger_instance_id, 'type': 'trigger-instance'},
- 'Expected updated rules.')
-
- self.assertEqual(len(retrieved_trace_db.trigger_instances), 1,
- 'Expected updated trigger_instances.')
- self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id,
- 'Expected updated trigger_instances.')
- self.assertEqual(retrieved_trace_db.trigger_instances[0].caused_by, {},
- 'Expected updated rules.')
+ self.assertEqual(
+ len(retrieved_trace_db.action_executions),
+ 1,
+ "Expected updated action_executions.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.action_executions[0].object_id,
+ action_execution_id,
+ "Expected updated action_executions.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.action_executions[0].caused_by,
+ {"id": "%s:%s" % (rule_id, trigger_instance_id), "type": "rule"},
+ "Expected updated action_executions.",
+ )
+
+ self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.")
+ self.assertEqual(
+ retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules."
+ )
+ self.assertEqual(
+ retrieved_trace_db.rules[0].caused_by,
+ {"id": trigger_instance_id, "type": "trigger-instance"},
+ "Expected updated rules.",
+ )
+
+ self.assertEqual(
+ len(retrieved_trace_db.trigger_instances),
+ 1,
+ "Expected updated trigger_instances.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.trigger_instances[0].object_id,
+ trigger_instance_id,
+ "Expected updated trigger_instances.",
+ )
+ self.assertEqual(
+ retrieved_trace_db.trigger_instances[0].caused_by,
+ {},
+ "Expected updated rules.",
+ )
Trace.delete(retrieved_trace_db)
def test_trace_component_for_trigger_instance(self):
# action_trigger
trace_component = trace_service.get_trace_component_for_trigger_instance(
- self.action_trigger)
+ self.action_trigger
+ )
expected = {
- 'id': str(self.action_trigger.id),
- 'ref': self.action_trigger.trigger,
- 'caused_by': {
- 'type': 'action_execution',
- 'id': self.action_trigger.payload['execution_id']
- }
+ "id": str(self.action_trigger.id),
+ "ref": self.action_trigger.trigger,
+ "caused_by": {
+ "type": "action_execution",
+ "id": self.action_trigger.payload["execution_id"],
+ },
}
self.assertEqual(trace_component, expected)
# notify_trigger
trace_component = trace_service.get_trace_component_for_trigger_instance(
- self.notify_trigger)
+ self.notify_trigger
+ )
expected = {
- 'id': str(self.notify_trigger.id),
- 'ref': self.notify_trigger.trigger,
- 'caused_by': {
- 'type': 'action_execution',
- 'id': self.notify_trigger.payload['execution_id']
- }
+ "id": str(self.notify_trigger.id),
+ "ref": self.notify_trigger.trigger,
+ "caused_by": {
+ "type": "action_execution",
+ "id": self.notify_trigger.payload["execution_id"],
+ },
}
self.assertEqual(trace_component, expected)
# non_internal_trigger
trace_component = trace_service.get_trace_component_for_trigger_instance(
- self.non_internal_trigger)
+ self.non_internal_trigger
+ )
expected = {
- 'id': str(self.non_internal_trigger.id),
- 'ref': self.non_internal_trigger.trigger,
- 'caused_by': {}
+ "id": str(self.non_internal_trigger.id),
+ "ref": self.non_internal_trigger.trigger,
+ "caused_by": {},
}
self.assertEqual(trace_component, expected)
def test_trace_component_for_rule(self):
- trace_component = trace_service.get_trace_component_for_rule(self.rule1,
- self.non_internal_trigger)
+ trace_component = trace_service.get_trace_component_for_rule(
+ self.rule1, self.non_internal_trigger
+ )
expected = {
- 'id': str(self.rule1.id),
- 'ref': self.rule1.ref,
- 'caused_by': {
- 'type': 'trigger_instance',
- 'id': str(self.non_internal_trigger.id)
- }
+ "id": str(self.rule1.id),
+ "ref": self.rule1.ref,
+ "caused_by": {
+ "type": "trigger_instance",
+ "id": str(self.non_internal_trigger.id),
+ },
}
self.assertEqual(trace_component, expected)
def test_trace_component_for_action_execution(self):
# no cause
trace_component = trace_service.get_trace_component_for_action_execution(
- self.traceable_execution,
- self.traceable_liveaction)
+ self.traceable_execution, self.traceable_liveaction
+ )
expected = {
- 'id': str(self.traceable_execution.id),
- 'ref': self.traceable_execution.action['ref'],
- 'caused_by': {}
+ "id": str(self.traceable_execution.id),
+ "ref": self.traceable_execution.action["ref"],
+ "caused_by": {},
}
self.assertEqual(trace_component, expected)
# rule_fired_execution
trace_component = trace_service.get_trace_component_for_action_execution(
- self.rule_fired_execution,
- self.traceable_liveaction)
+ self.rule_fired_execution, self.traceable_liveaction
+ )
expected = {
- 'id': str(self.rule_fired_execution.id),
- 'ref': self.rule_fired_execution.action['ref'],
- 'caused_by': {
- 'type': 'rule',
- 'id': '%s:%s' % (self.rule_fired_execution.rule['id'],
- self.rule_fired_execution.trigger_instance['id'])
- }
+ "id": str(self.rule_fired_execution.id),
+ "ref": self.rule_fired_execution.action["ref"],
+ "caused_by": {
+ "type": "rule",
+ "id": "%s:%s"
+ % (
+ self.rule_fired_execution.rule["id"],
+ self.rule_fired_execution.trigger_instance["id"],
+ ),
+ },
}
self.assertEqual(trace_component, expected)
# execution_with_parent
trace_component = trace_service.get_trace_component_for_action_execution(
- self.execution_with_parent,
- self.liveaction_with_parent)
+ self.execution_with_parent, self.liveaction_with_parent
+ )
expected = {
- 'id': str(self.execution_with_parent.id),
- 'ref': self.execution_with_parent.action['ref'],
- 'caused_by': {
- 'type': 'action_execution',
- 'id': self.liveaction_with_parent.context['parent']['execution_id']
- }
+ "id": str(self.execution_with_parent.id),
+ "ref": self.execution_with_parent.action["ref"],
+ "caused_by": {
+ "type": "action_execution",
+ "id": self.liveaction_with_parent.context["parent"]["execution_id"],
+ },
}
self.assertEqual(trace_component, expected)
class TestTraceContext(TestCase):
-
def test_str_method(self):
- trace_context = TraceContext(id_='id', trace_tag='tag')
+ trace_context = TraceContext(id_="id", trace_tag="tag")
self.assertTrue(str(trace_context))
- trace_context = TraceContext(trace_tag='tag')
+ trace_context = TraceContext(trace_tag="tag")
self.assertTrue(str(trace_context))
- trace_context = TraceContext(id_='id')
+ trace_context = TraceContext(id_="id")
self.assertTrue(str(trace_context))
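
The three assertions above pin down a small contract: TraceContext accepts either field alone, and its str() must always be truthy. A hedged stand-in that reproduces just that contract (the real class ships with st2common; this sketch only mirrors what the test checks):

class TraceContext:
    # Illustrative stand-in, not the st2common implementation.
    def __init__(self, id_=None, trace_tag=None):
        self.id_ = id_
        self.trace_tag = trace_tag

    def __str__(self):
        return "TraceContext(id_=%s, trace_tag=%s)" % (self.id_, self.trace_tag)

assert str(TraceContext(id_="id", trace_tag="tag"))
assert str(TraceContext(trace_tag="tag"))
assert str(TraceContext(id_="id"))
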
diff --git a/st2common/tests/unit/services/test_trace_injection_action_services.py b/st2common/tests/unit/services/test_trace_injection_action_services.py
index 8f9570d0e2..4b4fe0d177 100644
--- a/st2common/tests/unit/services/test_trace_injection_action_services.py
+++ b/st2common/tests/unit/services/test_trace_injection_action_services.py
@@ -21,13 +21,13 @@
from st2tests.fixturesloader import FixturesLoader
from st2tests import DbTestCase
-FIXTURES_PACK = 'traces'
+FIXTURES_PACK = "traces"
TEST_MODELS = {
- 'executions': ['traceable_execution.yaml'],
- 'liveactions': ['traceable_liveaction.yaml'],
- 'actions': ['chain1.yaml'],
- 'runners': ['actionchain.yaml']
+ "executions": ["traceable_execution.yaml"],
+ "liveactions": ["traceable_liveaction.yaml"],
+ "actions": ["chain1.yaml"],
+ "runners": ["actionchain.yaml"],
}
@@ -41,44 +41,52 @@ class TraceInjectionTests(DbTestCase):
@classmethod
def setUpClass(cls):
super(TraceInjectionTests, cls).setUpClass()
- cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
+ cls.models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+ )
- cls.traceable_liveaction = cls.models['liveactions']['traceable_liveaction.yaml']
- cls.traceable_execution = cls.models['executions']['traceable_execution.yaml']
- cls.action = cls.models['actions']['chain1.yaml']
+ cls.traceable_liveaction = cls.models["liveactions"][
+ "traceable_liveaction.yaml"
+ ]
+ cls.traceable_execution = cls.models["executions"]["traceable_execution.yaml"]
+ cls.action = cls.models["actions"]["chain1.yaml"]
def test_trace_provided(self):
- self.traceable_liveaction['context']['trace_context'] = {'trace_tag': 'OohLaLaLa'}
+ self.traceable_liveaction["context"]["trace_context"] = {
+ "trace_tag": "OohLaLaLa"
+ }
action_services.request(self.traceable_liveaction)
traces = Trace.get_all()
self.assertEqual(len(traces), 1)
- self.assertEqual(len(traces[0]['action_executions']), 1)
+ self.assertEqual(len(traces[0]["action_executions"]), 1)
        # Let's use an existing trace id in the trace context.
        # We shouldn't create a new trace object.
trace_id = str(traces[0].id)
- self.traceable_liveaction['context']['trace_context'] = {'id_': trace_id}
+ self.traceable_liveaction["context"]["trace_context"] = {"id_": trace_id}
action_services.request(self.traceable_liveaction)
traces = Trace.get_all()
self.assertEqual(len(traces), 1)
- self.assertEqual(len(traces[0]['action_executions']), 2)
+ self.assertEqual(len(traces[0]["action_executions"]), 2)
    def test_trace_tag_reuse(self):
- self.traceable_liveaction['context']['trace_context'] = {'trace_tag': 'blank space'}
+ self.traceable_liveaction["context"]["trace_context"] = {
+ "trace_tag": "blank space"
+ }
action_services.request(self.traceable_liveaction)
        # Let's use the same trace tag again; we should see two trace objects in the DB.
action_services.request(self.traceable_liveaction)
- traces = Trace.query(**{'trace_tag': 'blank space'})
+ traces = Trace.query(**{"trace_tag": "blank space"})
self.assertEqual(len(traces), 2)
def test_invalid_trace_id_provided(self):
liveactions = LiveAction.get_all()
        self.assertEqual(len(liveactions), 1)  # The fixtures loader creates it.
- self.traceable_liveaction['context']['trace_context'] = {'id_': 'balleilaka'}
+ self.traceable_liveaction["context"]["trace_context"] = {"id_": "balleilaka"}
- self.assertRaises(TraceNotFoundException, action_services.request,
- self.traceable_liveaction)
+ self.assertRaises(
+ TraceNotFoundException, action_services.request, self.traceable_liveaction
+ )
# Make sure no liveactions are left behind
liveactions = LiveAction.get_all()
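
One idiom in this file deserves a note: Trace.query(**{"trace_tag": "blank space"}) unpacks a dict into keyword arguments. The two call forms are interchangeable, and the dict form additionally allows the field name to be computed at runtime. A self-contained demonstration with a stand-in query function:

def query(**filters):
    # Stand-in that just echoes the keyword filters it received.
    return filters

# A literal keyword and an unpacked dict are identical calls.
assert query(trace_tag="blank space") == query(**{"trace_tag": "blank space"})

# The dict form also works when the field name is only known at runtime.
field = "trace_tag"
assert query(**{field: "blank space"}) == {"trace_tag": "blank space"}
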
diff --git a/st2common/tests/unit/services/test_workflow.py b/st2common/tests/unit/services/test_workflow.py
index 71cae679ba..23bd4aca60 100644
--- a/st2common/tests/unit/services/test_workflow.py
+++ b/st2common/tests/unit/services/test_workflow.py
@@ -25,6 +25,7 @@
import st2tests
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2common.bootstrap import actionsregistrar
@@ -43,33 +44,35 @@
from st2tests.mocks import liveaction as mock_lv_ac_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
-PACK_7 = 'dummy_pack_7'
-PACK_7_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + PACK_7
+PACK_7 = "dummy_pack_7"
+PACK_7_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + PACK_7
PACKS = [
TEST_PACK_PATH,
PACK_7_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
publishers.CUDPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
class WorkflowExecutionServiceTest(st2tests.WorkflowTestCase):
-
@classmethod
def setUpClass(cls):
super(WorkflowExecutionServiceTest, cls).setUpClass()
@@ -79,18 +82,17 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
def test_request(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Request the workflow execution.
@@ -99,7 +101,9 @@ def test_request(self):
wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx)
# Check workflow execution is saved to the database.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(len(wf_ex_dbs), 1)
# Check required attributes.
@@ -110,10 +114,12 @@ def test_request(self):
self.assertEqual(wf_ex_db.status, wf_statuses.REQUESTED)
def test_request_with_input(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters={'who': 'stan'})
+ lv_ac_db = lv_db_models.LiveActionDB(
+ action=wf_meta["name"], parameters={"who": "stan"}
+ )
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Request the workflow execution.
@@ -122,7 +128,9 @@ def test_request_with_input(self):
wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx)
# Check workflow execution is saved to the database.
- wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))
+ wf_ex_dbs = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )
self.assertEqual(len(wf_ex_dbs), 1)
# Check required attributes.
@@ -133,18 +141,16 @@ def test_request_with_input(self):
self.assertEqual(wf_ex_db.status, wf_statuses.REQUESTED)
# Check input and context.
- expected_input = {
- 'who': 'stan'
- }
+ expected_input = {"who": "stan"}
self.assertDictEqual(wf_ex_db.input, expected_input)
def test_request_bad_action(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
# Manually create the action execution object with the bad action.
ac_ex_db = ex_db_models.ActionExecutionDB(
- action={'ref': 'mock.foobar'}, runner={'name': 'foobar'}
+ action={"ref": "mock.foobar"}, runner={"name": "foobar"}
)
# Request the workflow execution.
@@ -153,14 +159,16 @@ def test_request_bad_action(self):
workflow_service.request,
self.get_wf_def(TEST_PACK_PATH, wf_meta),
ac_ex_db,
- self.mock_st2_context(ac_ex_db)
+ self.mock_st2_context(ac_ex_db),
)
def test_request_wf_def_with_bad_action_ref(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection-action-ref.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-inspection-action-ref.yaml"
+ )
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Exception is expected on request of workflow execution.
@@ -169,14 +177,16 @@ def test_request_wf_def_with_bad_action_ref(self):
workflow_service.request,
self.get_wf_def(TEST_PACK_PATH, wf_meta),
ac_ex_db,
- self.mock_st2_context(ac_ex_db)
+ self.mock_st2_context(ac_ex_db),
)
def test_request_wf_def_with_unregistered_action(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection-action-db.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "fail-inspection-action-db.yaml"
+ )
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Exception is expected on request of workflow execution.
@@ -185,15 +195,15 @@ def test_request_wf_def_with_unregistered_action(self):
workflow_service.request,
self.get_wf_def(TEST_PACK_PATH, wf_meta),
ac_ex_db,
- self.mock_st2_context(ac_ex_db)
+ self.mock_st2_context(ac_ex_db),
)
def test_request_wf_def_with_missing_required_action_param(self):
- wf_name = 'fail-inspection-missing-required-action-param'
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml')
+ wf_name = "fail-inspection-missing-required-action-param"
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Exception is expected on request of workflow execution.
@@ -202,15 +212,15 @@ def test_request_wf_def_with_missing_required_action_param(self):
workflow_service.request,
self.get_wf_def(TEST_PACK_PATH, wf_meta),
ac_ex_db,
- self.mock_st2_context(ac_ex_db)
+ self.mock_st2_context(ac_ex_db),
)
def test_request_wf_def_with_unexpected_action_param(self):
- wf_name = 'fail-inspection-unexpected-action-param'
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml')
+ wf_name = "fail-inspection-unexpected-action-param"
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Exception is expected on request of workflow execution.
@@ -219,44 +229,46 @@ def test_request_wf_def_with_unexpected_action_param(self):
workflow_service.request,
self.get_wf_def(TEST_PACK_PATH, wf_meta),
ac_ex_db,
- self.mock_st2_context(ac_ex_db)
+ self.mock_st2_context(ac_ex_db),
)
def test_request_task_execution(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Request the workflow execution.
wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta)
st2_ctx = self.mock_st2_context(ac_ex_db)
wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx)
- spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog'])
+ spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"])
wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec)
# Manually request task execution.
task_route = 0
- task_id = 'task1'
+ task_id = "task1"
task_spec = wf_spec.tasks.get_task(task_id)
- task_ctx = {'foo': 'bar'}
- st2_ctx = {'execution_id': wf_ex_db.action_execution}
+ task_ctx = {"foo": "bar"}
+ st2_ctx = {"execution_id": wf_ex_db.action_execution}
task_ex_req = {
- 'id': task_id,
- 'route': task_route,
- 'spec': task_spec,
- 'ctx': task_ctx,
- 'actions': [
- {'action': 'core.echo', 'input': {'message': 'Veni, vidi, vici.'}}
- ]
+ "id": task_id,
+ "route": task_route,
+ "spec": task_spec,
+ "ctx": task_ctx,
+ "actions": [
+ {"action": "core.echo", "input": {"message": "Veni, vidi, vici."}}
+ ],
}
workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req)
# Check task execution is saved to the database.
- task_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ task_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(task_ex_dbs), 1)
# Check required attributes.
@@ -267,42 +279,46 @@ def test_request_task_execution(self):
self.assertEqual(task_ex_db.status, wf_statuses.RUNNING)
# Check action execution for the task query with task execution ID.
- ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(task_ex_db.id))
+ ac_ex_dbs = ex_db_access.ActionExecution.query(
+ task_execution=str(task_ex_db.id)
+ )
self.assertEqual(len(ac_ex_dbs), 1)
# Check action execution for the task query with workflow execution ID.
- ac_ex_dbs = ex_db_access.ActionExecution.query(workflow_execution=str(wf_ex_db.id))
+ ac_ex_dbs = ex_db_access.ActionExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(ac_ex_dbs), 1)
def test_request_task_execution_bad_action(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Request the workflow execution.
wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta)
st2_ctx = self.mock_st2_context(ac_ex_db)
wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx)
- spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog'])
+ spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"])
wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec)
# Manually request task execution.
task_route = 0
- task_id = 'task1'
+ task_id = "task1"
task_spec = wf_spec.tasks.get_task(task_id)
- task_ctx = {'foo': 'bar'}
- st2_ctx = {'execution_id': wf_ex_db.action_execution}
+ task_ctx = {"foo": "bar"}
+ st2_ctx = {"execution_id": wf_ex_db.action_execution}
task_ex_req = {
- 'id': task_id,
- 'route': task_route,
- 'spec': task_spec,
- 'ctx': task_ctx,
- 'actions': [
- {'action': 'mock.echo', 'input': {'message': 'Veni, vidi, vici.'}}
- ]
+ "id": task_id,
+ "route": task_route,
+ "spec": task_spec,
+ "ctx": task_ctx,
+ "actions": [
+ {"action": "mock.echo", "input": {"message": "Veni, vidi, vici."}}
+ ],
}
self.assertRaises(
@@ -310,14 +326,14 @@ def test_request_task_execution_bad_action(self):
workflow_service.request_task_execution,
wf_ex_db,
st2_ctx,
- task_ex_req
+ task_ex_req,
)
def test_handle_action_execution_completion(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Request and pre-process the workflow execution.
@@ -327,111 +343,124 @@ def test_handle_action_execution_completion(self):
wf_ex_db = self.prep_wf_ex(wf_ex_db)
# Manually request task execution.
- self.run_workflow_step(wf_ex_db, 'task1', 0, ctx={'foo': 'bar'})
+ self.run_workflow_step(wf_ex_db, "task1", 0, ctx={"foo": "bar"})
# Check that a new task is executed.
- self.assert_task_running('task2', 0)
+ self.assert_task_running("task2", 0)
def test_evaluate_action_execution_delay(self):
- base_task_ex_req = {'task_id': 'task1', 'task_name': 'task1', 'route': 0}
+ base_task_ex_req = {"task_id": "task1", "task_name": "task1", "route": 0}
# No task delay.
task_ex_req = copy.deepcopy(base_task_ex_req)
- ac_ex_req = {'action': 'core.noop', 'input': None}
- actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req)
+ ac_ex_req = {"action": "core.noop", "input": None}
+ actual_delay = workflow_service.eval_action_execution_delay(
+ task_ex_req, ac_ex_req
+ )
self.assertIsNone(actual_delay)
# Simple task delay.
task_ex_req = copy.deepcopy(base_task_ex_req)
- task_ex_req['delay'] = 180
- ac_ex_req = {'action': 'core.noop', 'input': None}
- actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req)
+ task_ex_req["delay"] = 180
+ ac_ex_req = {"action": "core.noop", "input": None}
+ actual_delay = workflow_service.eval_action_execution_delay(
+ task_ex_req, ac_ex_req
+ )
self.assertEqual(actual_delay, 180)
        # Task delay for a with-items task with no concurrency.
task_ex_req = copy.deepcopy(base_task_ex_req)
- task_ex_req['delay'] = 180
- task_ex_req['concurrency'] = None
- ac_ex_req = {'action': 'core.noop', 'input': None, 'items_id': 0}
- actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True)
+ task_ex_req["delay"] = 180
+ task_ex_req["concurrency"] = None
+ ac_ex_req = {"action": "core.noop", "input": None, "items_id": 0}
+ actual_delay = workflow_service.eval_action_execution_delay(
+ task_ex_req, ac_ex_req, True
+ )
self.assertEqual(actual_delay, 180)
        # Task delay for a with-items task, with concurrency, evaluating the first item.
task_ex_req = copy.deepcopy(base_task_ex_req)
- task_ex_req['delay'] = 180
- task_ex_req['concurrency'] = 1
- ac_ex_req = {'action': 'core.noop', 'input': None, 'item_id': 0}
- actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True)
+ task_ex_req["delay"] = 180
+ task_ex_req["concurrency"] = 1
+ ac_ex_req = {"action": "core.noop", "input": None, "item_id": 0}
+ actual_delay = workflow_service.eval_action_execution_delay(
+ task_ex_req, ac_ex_req, True
+ )
self.assertEqual(actual_delay, 180)
        # Task delay for a with-items task, with concurrency, evaluating later items.
task_ex_req = copy.deepcopy(base_task_ex_req)
- task_ex_req['delay'] = 180
- task_ex_req['concurrency'] = 1
- ac_ex_req = {'action': 'core.noop', 'input': None, 'item_id': 1}
- actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True)
+ task_ex_req["delay"] = 180
+ task_ex_req["concurrency"] = 1
+ ac_ex_req = {"action": "core.noop", "input": None, "item_id": 1}
+ actual_delay = workflow_service.eval_action_execution_delay(
+ task_ex_req, ac_ex_req, True
+ )
self.assertIsNone(actual_delay)
def test_request_action_execution_render(self):
# Manually create ConfigDB
- output = 'Testing'
- value = {
- "config_item_one": output
- }
+ output = "Testing"
+ value = {"config_item_one": output}
config_db = pk_db_models.ConfigDB(pack=PACK_7, values=value)
config = pk_db_access.Config.add_or_update(config_db)
self.assertEqual(len(config), 3)
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'render_config_context.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, "render_config_context.yaml"
+ )
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db)
# Request the workflow execution.
wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta)
st2_ctx = self.mock_st2_context(ac_ex_db)
wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx)
- spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog'])
+ spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"])
wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec)
# Pass down appropriate st2 context to the task and action execution(s).
- root_st2_ctx = wf_ex_db.context.get('st2', {})
+ root_st2_ctx = wf_ex_db.context.get("st2", {})
st2_ctx = {
- 'execution_id': wf_ex_db.action_execution,
- 'user': root_st2_ctx.get('user'),
- 'pack': root_st2_ctx.get('pack')
+ "execution_id": wf_ex_db.action_execution,
+ "user": root_st2_ctx.get("user"),
+ "pack": root_st2_ctx.get("pack"),
}
# Manually request task execution.
task_route = 0
- task_id = 'task1'
+ task_id = "task1"
task_spec = wf_spec.tasks.get_task(task_id)
- task_ctx = {'foo': 'bar'}
+ task_ctx = {"foo": "bar"}
task_ex_req = {
- 'id': task_id,
- 'route': task_route,
- 'spec': task_spec,
- 'ctx': task_ctx,
- 'actions': [
- {'action': 'dummy_pack_7.render_config_context', 'input': None}
- ]
+ "id": task_id,
+ "route": task_route,
+ "spec": task_spec,
+ "ctx": task_ctx,
+ "actions": [
+ {"action": "dummy_pack_7.render_config_context", "input": None}
+ ],
}
workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req)
# Check task execution is saved to the database.
- task_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))
+ task_ex_dbs = wf_db_access.TaskExecution.query(
+ workflow_execution=str(wf_ex_db.id)
+ )
self.assertEqual(len(task_ex_dbs), 1)
workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req)
# Manually request action execution
task_ex_db = task_ex_dbs[0]
- action_ex_db = workflow_service.request_action_execution(wf_ex_db, task_ex_db, st2_ctx,
- task_ex_req['actions'][0])
+ action_ex_db = workflow_service.request_action_execution(
+ wf_ex_db, task_ex_db, st2_ctx, task_ex_req["actions"][0]
+ )
# Check required attributes.
self.assertIsNotNone(str(action_ex_db.id))
self.assertEqual(task_ex_db.workflow_execution, str(wf_ex_db.id))
- expected_parameters = {'value1': output}
+ expected_parameters = {"value1": output}
self.assertEqual(expected_parameters, action_ex_db.parameters)
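
The decorator stacks at the top of this file are where black's call-wrapping shows most clearly: a patch that fits in 88 columns collapses onto one line, while a longer one gets one argument per line with a trailing comma. A self-contained example of the same pattern, patching a toy publisher instead of the st2 transport classes:

import unittest
from unittest import mock


class Publisher:
    @staticmethod
    def publish_state(state):
        raise RuntimeError("would talk to a real message bus")


# Short enough for one line, so black keeps it that way.
@mock.patch.object(Publisher, "publish_state", mock.MagicMock(return_value=None))
class PublisherTest(unittest.TestCase):
    def test_publish_is_mocked(self):
        # The MagicMock swallows the call instead of raising.
        self.assertIsNone(Publisher.publish_state("running"))


if __name__ == "__main__":
    unittest.main()
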
diff --git a/st2common/tests/unit/services/test_workflow_cancellation.py b/st2common/tests/unit/services/test_workflow_cancellation.py
index 26455971f0..22694924a3 100644
--- a/st2common/tests/unit/services/test_workflow_cancellation.py
+++ b/st2common/tests/unit/services/test_workflow_cancellation.py
@@ -22,6 +22,7 @@
import st2tests
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2common.bootstrap import actionsregistrar
@@ -35,39 +36,35 @@
TEST_FIXTURES = {
- 'workflows': [
- 'sequential.yaml',
- 'join.yaml'
- ],
- 'actions': [
- 'sequential.yaml',
- 'join.yaml'
- ]
+ "workflows": ["sequential.yaml", "join.yaml"],
+ "actions": ["sequential.yaml", "join.yaml"],
}
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
publishers.CUDPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
class WorkflowExecutionCancellationTest(st2tests.WorkflowTestCase):
-
@classmethod
def setUpClass(cls):
super(WorkflowExecutionCancellationTest, cls).setUpClass()
@@ -77,8 +74,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -86,8 +82,10 @@ def setUpClass(cls):
def test_cancellation(self):
# Manually create the liveaction and action execution objects without publishing.
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, TEST_FIXTURES['workflows'][0])
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(
+ TEST_PACK_PATH, TEST_FIXTURES["workflows"][0]
+ )
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.create_request(lv_ac_db)
# Request and pre-process the workflow execution.
@@ -98,8 +96,8 @@ def test_cancellation(self):
# Manually request task executions.
task_route = 0
- self.run_workflow_step(wf_ex_db, 'task1', task_route)
- self.assert_task_running('task2', task_route)
+ self.run_workflow_step(wf_ex_db, "task1", task_route)
+ self.assert_task_running("task2", task_route)
# Cancel the workflow when there is still active task(s).
wf_ex_db = wf_svc.request_cancellation(ac_ex_db)
@@ -108,8 +106,8 @@ def test_cancellation(self):
self.assertEqual(wf_ex_db.status, wf_statuses.CANCELING)
# Manually complete the task and ensure workflow is canceled.
- self.run_workflow_step(wf_ex_db, 'task2', task_route)
- self.assert_task_not_started('task3', task_route)
+ self.run_workflow_step(wf_ex_db, "task2", task_route)
+ self.assert_task_not_started("task3", task_route)
conductor, wf_ex_db = wf_svc.refresh_conductor(str(wf_ex_db.id))
self.assertEqual(conductor.get_workflow_status(), wf_statuses.CANCELED)
self.assertEqual(wf_ex_db.status, wf_statuses.CANCELED)
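
The cancellation test encodes a two-phase rule: a cancel request while tasks are still active only moves the workflow to CANCELING, and CANCELED is reached once the in-flight task completes. A compact sketch of that state machine (status names mirror the test; the class itself is illustrative, not the st2 conductor):

RUNNING, CANCELING, CANCELED = "running", "canceling", "canceled"


class Workflow:
    def __init__(self, active_tasks):
        self.status = RUNNING
        self.active_tasks = set(active_tasks)

    def request_cancellation(self):
        # With tasks still running we can only *begin* canceling.
        self.status = CANCELING if self.active_tasks else CANCELED

    def complete_task(self, task_id):
        self.active_tasks.discard(task_id)
        if self.status == CANCELING and not self.active_tasks:
            self.status = CANCELED


wf = Workflow(active_tasks=["task2"])
wf.request_cancellation()
assert wf.status == CANCELING
wf.complete_task("task2")
assert wf.status == CANCELED
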
diff --git a/st2common/tests/unit/services/test_workflow_identify_orphans.py b/st2common/tests/unit/services/test_workflow_identify_orphans.py
index d45ba1527f..306e22badd 100644
--- a/st2common/tests/unit/services/test_workflow_identify_orphans.py
+++ b/st2common/tests/unit/services/test_workflow_identify_orphans.py
@@ -24,6 +24,7 @@
# XXX: actionsensor import depends on config being setup.
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2common.bootstrap import actionsregistrar
@@ -48,42 +49,51 @@
LOG = logging.getLogger(__name__)
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class WorkflowServiceIdentifyOrphansTest(st2tests.WorkflowTestCase):
ensure_indexes = True
ensure_indexes_models = [
ex_db_models.ActionExecutionDB,
lv_db_models.LiveActionDB,
wf_db_models.WorkflowExecutionDB,
- wf_db_models.TaskExecutionDB
+ wf_db_models.TaskExecutionDB,
]
@classmethod
@@ -95,8 +105,7 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
@@ -119,8 +128,9 @@ def tearDown(self):
def mock_workflow_records(self, completed=False, expired=True, log=True):
status = (
- ac_const.LIVEACTION_STATUS_SUCCEEDED if completed else
- ac_const.LIVEACTION_STATUS_RUNNING
+ ac_const.LIVEACTION_STATUS_SUCCEEDED
+ if completed
+ else ac_const.LIVEACTION_STATUS_RUNNING
)
        # Identify the start and end timestamps.
@@ -131,18 +141,24 @@ def mock_workflow_records(self, completed=False, expired=True, log=True):
end_timestamp = utc_now_dt if completed else None
# Assign metadata.
- action_ref = 'orquesta_tests.sequential'
- runner = 'orquesta'
- user = 'stanley'
+ action_ref = "orquesta_tests.sequential"
+ runner = "orquesta"
+ user = "stanley"
# Create the WorkflowExecutionDB record first since the ID needs to be
# included in the LiveActionDB and ActionExecutionDB records.
- st2_ctx = {'st2': {'action_execution_id': '123', 'action': 'foobar', 'runner': 'orquesta'}}
+ st2_ctx = {
+ "st2": {
+ "action_execution_id": "123",
+ "action": "foobar",
+ "runner": "orquesta",
+ }
+ }
wf_ex_db = wf_db_models.WorkflowExecutionDB(
context=st2_ctx,
status=status,
start_timestamp=start_timestamp,
- end_timestamp=end_timestamp
+ end_timestamp=end_timestamp,
)
wf_ex_db = wf_db_access.WorkflowExecution.insert(wf_ex_db, publish=False)
@@ -152,13 +168,10 @@ def mock_workflow_records(self, completed=False, expired=True, log=True):
workflow_execution=str(wf_ex_db.id),
action=action_ref,
action_is_workflow=True,
- context={
- 'user': user,
- 'workflow_execution': str(wf_ex_db.id)
- },
+ context={"user": user, "workflow_execution": str(wf_ex_db.id)},
status=status,
start_timestamp=start_timestamp,
- end_timestamp=end_timestamp
+ end_timestamp=end_timestamp,
)
lv_ac_db = lv_db_access.LiveAction.insert(lv_ac_db, publish=False)
@@ -166,30 +179,20 @@ def mock_workflow_records(self, completed=False, expired=True, log=True):
# Create the ActionExecutionDB record.
ac_ex_db = ex_db_models.ActionExecutionDB(
workflow_execution=str(wf_ex_db.id),
- action={
- 'runner_type': runner,
- 'ref': action_ref
- },
- runner={
- 'name': runner
- },
- liveaction={
- 'id': str(lv_ac_db.id)
- },
- context={
- 'user': user,
- 'workflow_execution': str(wf_ex_db.id)
- },
+ action={"runner_type": runner, "ref": action_ref},
+ runner={"name": runner},
+ liveaction={"id": str(lv_ac_db.id)},
+ context={"user": user, "workflow_execution": str(wf_ex_db.id)},
status=status,
start_timestamp=start_timestamp,
- end_timestamp=end_timestamp
+ end_timestamp=end_timestamp,
)
if log:
- ac_ex_db.log = [{'status': 'running', 'timestamp': start_timestamp}]
+ ac_ex_db.log = [{"status": "running", "timestamp": start_timestamp}]
if log and status in ac_const.LIVEACTION_COMPLETED_STATES:
- ac_ex_db.log.append({'status': status, 'timestamp': end_timestamp})
+ ac_ex_db.log.append({"status": status, "timestamp": end_timestamp})
ac_ex_db = ex_db_access.ActionExecution.insert(ac_ex_db, publish=False)
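
The status assignment earlier in this helper shows black's layout for a conditional expression that overflows the line: the break comes before "if" and "else", so each keyword leads its continuation line. Both layouts parse identically; a quick standalone check (the constants are stand-ins for the st2 ones):

LIVEACTION_STATUS_SUCCEEDED = "succeeded"
LIVEACTION_STATUS_RUNNING = "running"


def pick_status(completed):
    # The split black produces when the expression overflows the line.
    return (
        LIVEACTION_STATUS_SUCCEEDED
        if completed
        else LIVEACTION_STATUS_RUNNING
    )


assert pick_status(True) == LIVEACTION_STATUS_SUCCEEDED
assert pick_status(False) == LIVEACTION_STATUS_RUNNING
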
@@ -199,14 +202,16 @@ def mock_workflow_records(self, completed=False, expired=True, log=True):
return wf_ex_db, lv_ac_db, ac_ex_db
- def mock_task_records(self, parent, task_id, task_route=0,
- completed=True, expired=False, log=True):
+ def mock_task_records(
+ self, parent, task_id, task_route=0, completed=True, expired=False, log=True
+ ):
if not completed and expired:
- raise ValueError('Task must be set completed=True if expired=True.')
+ raise ValueError("Task must be set completed=True if expired=True.")
status = (
- ac_const.LIVEACTION_STATUS_SUCCEEDED if completed else
- ac_const.LIVEACTION_STATUS_RUNNING
+ ac_const.LIVEACTION_STATUS_SUCCEEDED
+ if completed
+ else ac_const.LIVEACTION_STATUS_RUNNING
)
parent_wf_ex_db, parent_ac_ex_db = parent[0], parent[2]
@@ -218,9 +223,9 @@ def mock_task_records(self, parent, task_id, task_route=0,
end_timestamp = expiry_dt if expired else utc_now_dt
# Assign metadata.
- action_ref = 'core.local'
- runner = 'local-shell-cmd'
- user = 'stanley'
+ action_ref = "core.local"
+ runner = "local-shell-cmd"
+ user = "stanley"
# Create the TaskExecutionDB record first since the ID needs to be
# included in the LiveActionDB and ActionExecutionDB records.
@@ -229,7 +234,7 @@ def mock_task_records(self, parent, task_id, task_route=0,
task_id=task_id,
task_route=0,
status=status,
- start_timestamp=parent_wf_ex_db.start_timestamp
+ start_timestamp=parent_wf_ex_db.start_timestamp,
)
if status in ac_const.LIVEACTION_COMPLETED_STATES:
@@ -239,18 +244,15 @@ def mock_task_records(self, parent, task_id, task_route=0,
# Build context for LiveActionDB and ActionExecutionDB.
context = {
- 'user': user,
- 'orquesta': {
- 'task_id': tk_ex_db.task_id,
- 'task_name': tk_ex_db.task_id,
- 'workflow_execution_id': str(parent_wf_ex_db.id),
- 'task_execution_id': str(tk_ex_db.id),
- 'task_route': tk_ex_db.task_route
+ "user": user,
+ "orquesta": {
+ "task_id": tk_ex_db.task_id,
+ "task_name": tk_ex_db.task_id,
+ "workflow_execution_id": str(parent_wf_ex_db.id),
+ "task_execution_id": str(tk_ex_db.id),
+ "task_route": tk_ex_db.task_route,
},
- 'parent': {
- 'user': user,
- 'execution_id': str(parent_ac_ex_db.id)
- }
+ "parent": {"user": user, "execution_id": str(parent_ac_ex_db.id)},
}
# Create the LiveActionDB record.
@@ -262,7 +264,7 @@ def mock_task_records(self, parent, task_id, task_route=0,
context=context,
status=status,
start_timestamp=tk_ex_db.start_timestamp,
- end_timestamp=tk_ex_db.end_timestamp
+ end_timestamp=tk_ex_db.end_timestamp,
)
lv_ac_db = lv_db_access.LiveAction.insert(lv_ac_db, publish=False)
@@ -271,27 +273,22 @@ def mock_task_records(self, parent, task_id, task_route=0,
ac_ex_db = ex_db_models.ActionExecutionDB(
workflow_execution=str(parent_wf_ex_db.id),
task_execution=str(tk_ex_db.id),
- action={
- 'runner_type': runner,
- 'ref': action_ref
- },
- runner={
- 'name': runner
- },
- liveaction={
- 'id': str(lv_ac_db.id)
- },
+ action={"runner_type": runner, "ref": action_ref},
+ runner={"name": runner},
+ liveaction={"id": str(lv_ac_db.id)},
context=context,
status=status,
start_timestamp=tk_ex_db.start_timestamp,
- end_timestamp=tk_ex_db.end_timestamp
+ end_timestamp=tk_ex_db.end_timestamp,
)
if log:
- ac_ex_db.log = [{'status': 'running', 'timestamp': tk_ex_db.start_timestamp}]
+ ac_ex_db.log = [
+ {"status": "running", "timestamp": tk_ex_db.start_timestamp}
+ ]
if log and status in ac_const.LIVEACTION_COMPLETED_STATES:
- ac_ex_db.log.append({'status': status, 'timestamp': tk_ex_db.end_timestamp})
+ ac_ex_db.log.append({"status": status, "timestamp": tk_ex_db.end_timestamp})
ac_ex_db = ex_db_access.ActionExecution.insert(ac_ex_db, publish=False)
@@ -303,18 +300,18 @@ def test_no_orphans(self):
# Workflow that is still running with task completed and not expired.
wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=False)
- self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=False)
+ self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=False)
# Workflow that is still running with task running and not expired.
wf_ex_set_3 = self.mock_workflow_records(completed=False, expired=False)
- self.mock_task_records(wf_ex_set_3, 'task1', completed=False, expired=False)
+ self.mock_task_records(wf_ex_set_3, "task1", completed=False, expired=False)
# Workflow that is completed and not expired.
self.mock_workflow_records(completed=True, expired=False)
# Workflow that is completed with task completed and not expired.
wf_ex_set_5 = self.mock_workflow_records(completed=True, expired=False)
- self.mock_task_records(wf_ex_set_5, 'task1', completed=True, expired=False)
+ self.mock_task_records(wf_ex_set_5, "task1", completed=True, expired=False)
orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows()
self.assertEqual(len(orphaned_ac_ex_dbs), 0)
@@ -339,33 +336,33 @@ def test_identify_orphans_with_no_task_executions(self):
def test_identify_orphans_with_task_executions(self):
# Workflow that is still running with task completed and expired.
wf_ex_set_1 = self.mock_workflow_records(completed=False, expired=True)
- self.mock_task_records(wf_ex_set_1, 'task1', completed=True, expired=True)
+ self.mock_task_records(wf_ex_set_1, "task1", completed=True, expired=True)
# Workflow that is still running with task completed and not expired.
wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=False)
- self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=False)
+ self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=False)
# Workflow that is still running with task running and not expired.
wf_ex_set_3 = self.mock_workflow_records(completed=False, expired=False)
- self.mock_task_records(wf_ex_set_3, 'task1', completed=False, expired=False)
+ self.mock_task_records(wf_ex_set_3, "task1", completed=False, expired=False)
# Workflow that is still running with multiple tasks and not expired.
# One of the tasks completed past the expiry date but another task is still running.
wf_ex_set_4 = self.mock_workflow_records(completed=False, expired=False)
- self.mock_task_records(wf_ex_set_4, 'task1', completed=True, expired=True)
- self.mock_task_records(wf_ex_set_4, 'task2', completed=False, expired=False)
+ self.mock_task_records(wf_ex_set_4, "task1", completed=True, expired=True)
+ self.mock_task_records(wf_ex_set_4, "task2", completed=False, expired=False)
# Workflow that is still running with multiple tasks and not expired.
# Both of the tasks are completed, with one completed only recently.
wf_ex_set_5 = self.mock_workflow_records(completed=False, expired=False)
- self.mock_task_records(wf_ex_set_5, 'task1', completed=True, expired=True)
- self.mock_task_records(wf_ex_set_5, 'task2', completed=True, expired=False)
+ self.mock_task_records(wf_ex_set_5, "task1", completed=True, expired=True)
+ self.mock_task_records(wf_ex_set_5, "task2", completed=True, expired=False)
# Workflow that is still running with multiple tasks and not expired.
# One of the tasks completed recently and another task is still running.
wf_ex_set_6 = self.mock_workflow_records(completed=False, expired=False)
- self.mock_task_records(wf_ex_set_6, 'task1', completed=True, expired=False)
- self.mock_task_records(wf_ex_set_6, 'task2', completed=False, expired=False)
+ self.mock_task_records(wf_ex_set_6, "task1", completed=True, expired=False)
+ self.mock_task_records(wf_ex_set_6, "task2", completed=False, expired=False)
orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows()
self.assertEqual(len(orphaned_ac_ex_dbs), 1)
@@ -373,8 +370,10 @@ def test_identify_orphans_with_task_executions(self):
def test_action_execution_with_missing_log_entries(self):
# Workflow that is still running and expired. However, the state change logs are missing.
- wf_ex_set_1 = self.mock_workflow_records(completed=False, expired=True, log=False)
- self.mock_task_records(wf_ex_set_1, 'task1', completed=True, expired=True)
+ wf_ex_set_1 = self.mock_workflow_records(
+ completed=False, expired=True, log=False
+ )
+ self.mock_task_records(wf_ex_set_1, "task1", completed=True, expired=True)
orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows()
self.assertEqual(len(orphaned_ac_ex_dbs), 0)
@@ -385,7 +384,7 @@ def test_garbage_collection(self):
# Workflow that is still running with task completed and expired.
wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=True)
- self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=True)
+ self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=True)
# Ensure these workflows are identified as orphans.
orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows()
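The hunks above are mechanical black changes: string literals are normalized to double quotes, short dict literals are collapsed onto one line, and every multi-line call gains a trailing comma (black's "magic trailing comma") so that appending an argument later touches only one line. A minimal sketch of the convention, with made-up values rather than code from this diff:

# Before black: single quotes, one key per line, no trailing comma.
# record = dict(
#     action={'ref': 'core.local'},
#     status='succeeded'
# )

# After black: double quotes, short literal collapsed, magic trailing comma.
record = dict(
    action={"ref": "core.local"},
    status="succeeded",  # trailing comma keeps future diffs one line long
)
print(record)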
diff --git a/st2common/tests/unit/services/test_workflow_rerun.py b/st2common/tests/unit/services/test_workflow_rerun.py
index 6808991595..f5ff2bc487 100644
--- a/st2common/tests/unit/services/test_workflow_rerun.py
+++ b/st2common/tests/unit/services/test_workflow_rerun.py
@@ -24,6 +24,7 @@
import st2tests
import st2tests.config as tests_config
+
tests_config.parse_args()
from local_runner import local_shell_command_runner
@@ -42,32 +43,38 @@
from st2tests.mocks import liveaction as mock_lv_ac_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
RUNNER_RESULT_FAILED = (action_constants.LIVEACTION_STATUS_FAILED, {}, {})
-RUNNER_RESULT_SUCCEEDED = (action_constants.LIVEACTION_STATUS_SUCCEEDED, {'stdout': 'foobar'}, {})
+RUNNER_RESULT_SUCCEEDED = (
+ action_constants.LIVEACTION_STATUS_SUCCEEDED,
+ {"stdout": "foobar"},
+ {},
+)
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
publishers.CUDPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
class WorkflowExecutionRerunTest(st2tests.WorkflowTestCase):
-
@classmethod
def setUpClass(cls):
super(WorkflowExecutionRerunTest, cls).setUpClass()
@@ -77,18 +84,17 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
def prep_wf_ex_for_rerun(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db1, ac_ex_db1 = action_service.create_request(lv_ac_db1)
# Request the workflow execution.
@@ -99,9 +105,12 @@ def prep_wf_ex_for_rerun(self):
# Fail workflow execution.
self.run_workflow_step(
- wf_ex_db, 'task1', 0,
+ wf_ex_db,
+ "task1",
+ 0,
expected_ac_ex_db_status=action_constants.LIVEACTION_STATUS_FAILED,
- expected_tk_ex_db_status=wf_statuses.FAILED)
+ expected_tk_ex_db_status=wf_statuses.FAILED,
+ )
# Check workflow status.
conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id))
@@ -115,20 +124,22 @@ def prep_wf_ex_for_rerun(self):
return wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]),
+ )
def test_request_rerun(self):
# Create and return a failed workflow execution.
wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun()
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = str(wf_ex_db.id)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
+ st2_ctx["workflow_execution_id"] = str(wf_ex_db.id)
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
wf_ex_db = workflow_service.request_rerun(ac_ex_db2, st2_ctx, rerun_options)
wf_ex_db = self.prep_wf_ex(wf_ex_db)
@@ -138,7 +149,7 @@ def test_request_rerun(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Complete task1.
- self.run_workflow_step(wf_ex_db, 'task1', 0)
+ self.run_workflow_step(wf_ex_db, "task1", 0)
# Check workflow status and make sure it is still running.
conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id))
@@ -150,10 +161,10 @@ def test_request_rerun(self):
self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING)
def test_request_rerun_while_original_is_still_running(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
# Manually create the liveaction and action execution objects without publishing.
- lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db1, ac_ex_db1 = action_service.create_request(lv_ac_db1)
# Request the workflow execution.
@@ -168,16 +179,16 @@ def test_request_rerun_while_original_is_still_running(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = str(wf_ex_db.id)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
+ st2_ctx["workflow_execution_id"] = str(wf_ex_db.id)
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
expected_error = (
- '^Unable to rerun workflow execution \".*\" '
- 'because it is not in a completed state.$'
+ '^Unable to rerun workflow execution ".*" '
+ "because it is not in a completed state.$"
)
self.assertRaisesRegexp(
@@ -186,24 +197,26 @@ def test_request_rerun_while_original_is_still_running(self):
workflow_service.request_rerun,
ac_ex_db2,
st2_ctx,
- rerun_options
+ rerun_options,
)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]),
+ )
def test_request_rerun_again_while_prev_rerun_is_still_running(self):
# Create and return a failed workflow execution.
wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun()
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = str(wf_ex_db.id)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
+ st2_ctx["workflow_execution_id"] = str(wf_ex_db.id)
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
wf_ex_db = workflow_service.request_rerun(ac_ex_db2, st2_ctx, rerun_options)
wf_ex_db = self.prep_wf_ex(wf_ex_db)
@@ -213,7 +226,7 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self):
self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING)
# Complete task1.
- self.run_workflow_step(wf_ex_db, 'task1', 0)
+ self.run_workflow_step(wf_ex_db, "task1", 0)
# Check workflow status and make sure it is still running.
conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id))
@@ -225,16 +238,16 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self):
self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING)
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db3 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db3 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db3, ac_ex_db3 = action_service.create_request(lv_ac_db3)
# Request workflow execution rerun again.
st2_ctx = self.mock_st2_context(ac_ex_db3, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = str(wf_ex_db.id)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
+ st2_ctx["workflow_execution_id"] = str(wf_ex_db.id)
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
expected_error = (
- '^Unable to rerun workflow execution \".*\" '
- 'because it is not in a completed state.$'
+ '^Unable to rerun workflow execution ".*" '
+ "because it is not in a completed state.$"
)
self.assertRaisesRegexp(
@@ -243,26 +256,28 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self):
workflow_service.request_rerun,
ac_ex_db3,
st2_ctx,
- rerun_options
+ rerun_options,
)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(return_value=RUNNER_RESULT_FAILED))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(return_value=RUNNER_RESULT_FAILED),
+ )
def test_request_rerun_with_missing_workflow_execution_id(self):
# Create and return a failed workflow execution.
wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun()
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun without workflow_execution_id.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
expected_error = (
- 'Unable to rerun workflow execution because '
- 'workflow_execution_id is not provided.'
+ "Unable to rerun workflow execution because "
+ "workflow_execution_id is not provided."
)
self.assertRaisesRegexp(
@@ -271,27 +286,28 @@ def test_request_rerun_with_missing_workflow_execution_id(self):
workflow_service.request_rerun,
ac_ex_db2,
st2_ctx,
- rerun_options
+ rerun_options,
)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(return_value=RUNNER_RESULT_FAILED))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(return_value=RUNNER_RESULT_FAILED),
+ )
def test_request_rerun_with_nonexistent_workflow_execution(self):
# Create and return a failed workflow execution.
wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun()
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun with bogus workflow_execution_id.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = uuid.uuid4().hex[0:24]
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
+ st2_ctx["workflow_execution_id"] = uuid.uuid4().hex[0:24]
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
expected_error = (
- '^Unable to rerun workflow execution \".*\" '
- 'because it does not exist.$'
+ '^Unable to rerun workflow execution ".*" ' "because it does not exist.$"
)
self.assertRaisesRegexp(
@@ -300,12 +316,14 @@ def test_request_rerun_with_nonexistent_workflow_execution(self):
workflow_service.request_rerun,
ac_ex_db2,
st2_ctx,
- rerun_options
+ rerun_options,
)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(return_value=RUNNER_RESULT_FAILED))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(return_value=RUNNER_RESULT_FAILED),
+ )
def test_request_rerun_with_workflow_execution_not_abended(self):
# Create and return a failed workflow execution.
wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun()
@@ -315,16 +333,16 @@ def test_request_rerun_with_workflow_execution_not_abended(self):
wf_ex_db = wf_db_access.WorkflowExecution.add_or_update(wf_ex_db)
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun while the workflow execution is not in a completed state.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = str(wf_ex_db.id)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
+ st2_ctx["workflow_execution_id"] = str(wf_ex_db.id)
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
expected_error = (
- '^Unable to rerun workflow execution \".*\" '
- 'because it is not in a completed state.$'
+ '^Unable to rerun workflow execution ".*" '
+ "because it is not in a completed state.$"
)
self.assertRaisesRegexp(
@@ -333,29 +351,33 @@ def test_request_rerun_with_workflow_execution_not_abended(self):
workflow_service.request_rerun,
ac_ex_db2,
st2_ctx,
- rerun_options
+ rerun_options,
)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(return_value=RUNNER_RESULT_FAILED))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(return_value=RUNNER_RESULT_FAILED),
+ )
def test_request_rerun_with_conductor_status_not_abended(self):
# Create and return a failed workflow execution.
wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun()
# Manually set workflow conductor state to paused.
- wf_ex_db.state['status'] = wf_statuses.PAUSED
+ wf_ex_db.state["status"] = wf_statuses.PAUSED
wf_ex_db = wf_db_access.WorkflowExecution.add_or_update(wf_ex_db)
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun while the workflow conductor is paused.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = str(wf_ex_db.id)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
- expected_error = 'Unable to rerun workflow because it is not in a completed state.'
+ st2_ctx["workflow_execution_id"] = str(wf_ex_db.id)
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
+ expected_error = (
+ "Unable to rerun workflow because it is not in a completed state."
+ )
self.assertRaisesRegexp(
wf_exc.WorkflowExecutionRerunException,
@@ -363,25 +385,29 @@ def test_request_rerun_with_conductor_status_not_abended(self):
workflow_service.request_rerun,
ac_ex_db2,
st2_ctx,
- rerun_options
+ rerun_options,
)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(return_value=RUNNER_RESULT_FAILED))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(return_value=RUNNER_RESULT_FAILED),
+ )
def test_request_rerun_with_bad_task_name(self):
# Create and return a failed workflow execution.
wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun()
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun with a task name that does not exist.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = str(wf_ex_db.id)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task5354']}
- expected_error = '^Unable to rerun workflow because one or more tasks is not found: .*$'
+ st2_ctx["workflow_execution_id"] = str(wf_ex_db.id)
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task5354"]}
+ expected_error = (
+ "^Unable to rerun workflow because one or more tasks is not found: .*$"
+ )
self.assertRaisesRegexp(
wf_exc.WorkflowExecutionRerunException,
@@ -389,36 +415,40 @@ def test_request_rerun_with_bad_task_name(self):
workflow_service.request_rerun,
ac_ex_db2,
st2_ctx,
- rerun_options
+ rerun_options,
)
@mock.patch.object(
- local_shell_command_runner.LocalShellCommandRunner, 'run',
- mock.MagicMock(return_value=RUNNER_RESULT_FAILED))
+ local_shell_command_runner.LocalShellCommandRunner,
+ "run",
+ mock.MagicMock(return_value=RUNNER_RESULT_FAILED),
+ )
def test_request_rerun_with_conductor_status_not_resuming(self):
# Create and return a failed workflow execution.
wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun()
# Manually create the liveaction and action execution objects for the rerun.
- lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2)
# Request workflow execution rerun with the conductor status mocked to a non-resuming state.
st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context)
- st2_ctx['workflow_execution_id'] = str(wf_ex_db.id)
- rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']}
+ st2_ctx["workflow_execution_id"] = str(wf_ex_db.id)
+ rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}
expected_error = (
- '^Unable to rerun workflow execution \".*\" '
- 'due to an unknown cause.'
+ '^Unable to rerun workflow execution ".*" ' "due to an unknown cause."
)
- with mock.patch.object(conducting.WorkflowConductor, 'get_workflow_status',
- mock.MagicMock(return_value=wf_statuses.FAILED)):
+ with mock.patch.object(
+ conducting.WorkflowConductor,
+ "get_workflow_status",
+ mock.MagicMock(return_value=wf_statuses.FAILED),
+ ):
self.assertRaisesRegexp(
wf_exc.WorkflowExecutionRerunException,
expected_error,
workflow_service.request_rerun,
ac_ex_db2,
st2_ctx,
- rerun_options
+ rerun_options,
)
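The rerun tests above rely on mock.MagicMock(side_effect=[...]) so that the first simulated runner invocation fails and the rerun succeeds: each call consumes the next item in the list. A self-contained sketch of the same technique; Runner, FAILED, and SUCCEEDED are stand-in names, not st2 identifiers:

from unittest import mock

FAILED, SUCCEEDED = "failed", "succeeded"  # stand-ins for st2 status constants

class Runner:  # stand-in for LocalShellCommandRunner
    def run(self):
        raise NotImplementedError

with mock.patch.object(
    Runner,
    "run",
    mock.MagicMock(
        side_effect=[(FAILED, {}, {}), (SUCCEEDED, {"stdout": "foobar"}, {})]
    ),
):
    runner = Runner()
    assert runner.run()[0] == FAILED  # original execution fails
    assert runner.run()[0] == SUCCEEDED  # rerun succeeds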
diff --git a/st2common/tests/unit/services/test_workflow_service_retries.py b/st2common/tests/unit/services/test_workflow_service_retries.py
index 35fafc1213..baa79c6954 100644
--- a/st2common/tests/unit/services/test_workflow_service_retries.py
+++ b/st2common/tests/unit/services/test_workflow_service_retries.py
@@ -27,6 +27,7 @@
# XXX: actionsensor import depends on config being set up.
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2common.bootstrap import actionsregistrar
@@ -49,12 +50,14 @@
from st2tests.mocks import workflow as mock_wf_ex_xport
-TEST_PACK = 'orquesta_tests'
-TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK
+TEST_PACK = "orquesta_tests"
+TEST_PACK_PATH = (
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK
+)
PACKS = [
TEST_PACK_PATH,
- st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'
+ st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core",
]
@@ -63,11 +66,11 @@
def mock_wf_db_update_conflict(wf_ex_db, publish=True, dispatch_trigger=True, **kwargs):
- seq_len = len(wf_ex_db.state['sequence'])
+ seq_len = len(wf_ex_db.state["sequence"])
if seq_len > 0:
- current_task_id = wf_ex_db.state['sequence'][seq_len - 1:][0]['id']
- temp_file_path = TEMP_DIR_PATH + '/' + current_task_id
+ current_task_id = wf_ex_db.state["sequence"][seq_len - 1 :][0]["id"]
+ temp_file_path = TEMP_DIR_PATH + "/" + current_task_id
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
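One subtle rule shows up in the hunk above: when a slice bound is a complex expression, black (following PEP 8) surrounds the colon with spaces, hence seq_len - 1 : rather than seq_len - 1:. A tiny illustration with made-up data:

seq = ["task1", "task2", "task3"]
seq_len = len(seq)

# Simple bound: no space before the colon.
last = seq[-1:]

# Complex expression as a bound: black adds a space before the colon.
also_last = seq[seq_len - 1 :]

assert last == also_last == ["task3"]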
@@ -77,31 +80,38 @@ def mock_wf_db_update_conflict(wf_ex_db, publish=True, dispatch_trigger=True, **
@mock.patch.object(
- publishers.CUDPublisher,
- 'publish_update',
- mock.MagicMock(return_value=None))
+ publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None)
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create),
+)
@mock.patch.object(
lv_ac_xport.LiveActionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_create',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create))
+ "publish_create",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create
+ ),
+)
@mock.patch.object(
wf_ex_xport.WorkflowExecutionPublisher,
- 'publish_state',
- mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state))
+ "publish_state",
+ mock.MagicMock(
+ side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state
+ ),
+)
class OrquestaServiceRetryTest(st2tests.WorkflowTestCase):
ensure_indexes = True
ensure_indexes_models = [
wf_db_models.WorkflowExecutionDB,
wf_db_models.TaskExecutionDB,
- ex_q_db_models.ActionExecutionSchedulingQueueItemDB
+ ex_q_db_models.ActionExecutionSchedulingQueueItemDB,
]
@classmethod
@@ -113,30 +123,38 @@ def setUpClass(cls):
# Register test pack(s).
actions_registrar = actionsregistrar.ActionsRegistrar(
- use_pack_cache=False,
- fail_on_failure=True
+ use_pack_cache=False, fail_on_failure=True
)
for pack in PACKS:
actions_registrar.register_from_pack(pack)
@mock.patch.object(
- coord_svc.NoOpDriver, 'get_lock',
- mock.MagicMock(side_effect=[
- coordination.ToozConnectionError('foobar'),
- coordination.ToozConnectionError('fubar'),
- coord_svc.NoOpLock(name='noop')]))
+ coord_svc.NoOpDriver,
+ "get_lock",
+ mock.MagicMock(
+ side_effect=[
+ coordination.ToozConnectionError("foobar"),
+ coordination.ToozConnectionError("fubar"),
+ coord_svc.NoOpLock(name="noop"),
+ ]
+ ),
+ )
def test_recover_from_coordinator_connection_error(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
# Process task1 and expect acquiring the lock to return a few connection errors.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
@@ -145,45 +163,60 @@ def test_recover_from_coordinator_connection_error(self):
self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED)
@mock.patch.object(
- coord_svc.NoOpDriver, 'get_lock',
- mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar')))
+ coord_svc.NoOpDriver,
+ "get_lock",
+ mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")),
+ )
def test_retries_exhausted_from_coordinator_connection_error(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
# Process task1, but retries are exhausted due to connection errors.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# The connection error should be raised if retries are exhausted.
self.assertRaises(
coordination.ToozConnectionError,
wf_svc.handle_action_execution_completion,
- tk1_ac_ex_db
+ tk1_ac_ex_db,
)
@mock.patch.object(
- wf_svc, 'update_task_state',
- mock.MagicMock(side_effect=[
- mongoengine.connection.MongoEngineConnectionError(),
- mongoengine.connection.MongoEngineConnectionError(),
- None]))
+ wf_svc,
+ "update_task_state",
+ mock.MagicMock(
+ side_effect=[
+ mongoengine.connection.MongoEngineConnectionError(),
+ mongoengine.connection.MongoEngineConnectionError(),
+ None,
+ ]
+ ),
+ )
def test_recover_from_database_connection_error(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
# Process task1 and expect acquiring the lock to return a few connection errors.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
wf_svc.handle_action_execution_completion(tk1_ac_ex_db)
@@ -192,61 +225,71 @@ def test_recover_from_database_connection_error(self):
self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED)
@mock.patch.object(
- wf_svc, 'update_task_state',
- mock.MagicMock(side_effect=mongoengine.connection.MongoEngineConnectionError()))
+ wf_svc,
+ "update_task_state",
+ mock.MagicMock(side_effect=mongoengine.connection.MongoEngineConnectionError()),
+ )
def test_retries_exhausted_from_database_connection_error(self):
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
# Process task1, but retries are exhausted due to connection errors.
- query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'}
+ query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"}
tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0]
- tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0]
- tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id'])
+ tk1_ac_ex_db = ex_db_access.ActionExecution.query(
+ task_execution=str(tk1_ex_db.id)
+ )[0]
+ tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"])
self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
# The connection error should be raised if retries are exhausted.
self.assertRaises(
mongoengine.connection.MongoEngineConnectionError,
wf_svc.handle_action_execution_completion,
- tk1_ac_ex_db
+ tk1_ac_ex_db,
)
@mock.patch.object(
- wf_db_access.WorkflowExecution, 'update',
- mock.MagicMock(side_effect=mock_wf_db_update_conflict))
+ wf_db_access.WorkflowExecution,
+ "update",
+ mock.MagicMock(side_effect=mock_wf_db_update_conflict),
+ )
def test_recover_from_database_write_conflicts(self):
# Create a temporary file which will be used to signal
# which task(s) should trigger the mocked DB write conflict.
- temp_file_path = TEMP_DIR_PATH + '/task4'
+ temp_file_path = TEMP_DIR_PATH + "/task4"
if not os.path.exists(temp_file_path):
- with open(temp_file_path, 'w'):
+ with open(temp_file_path, "w"):
pass
# Manually create the liveaction and action execution objects without publishing.
- wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'join.yaml')
- lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'])
+ wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "join.yaml")
+ lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"])
lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db)
- wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0]
+ wf_ex_db = wf_db_access.WorkflowExecution.query(
+ action_execution=str(ac_ex_db.id)
+ )[0]
# Manually request task executions.
task_route = 0
- self.run_workflow_step(wf_ex_db, 'task1', task_route)
- self.assert_task_running('task2', task_route)
- self.assert_task_running('task4', task_route)
- self.run_workflow_step(wf_ex_db, 'task2', task_route)
- self.assert_task_running('task3', task_route)
- self.run_workflow_step(wf_ex_db, 'task4', task_route)
- self.assert_task_running('task5', task_route)
- self.run_workflow_step(wf_ex_db, 'task3', task_route)
- self.assert_task_not_started('task6', task_route)
- self.run_workflow_step(wf_ex_db, 'task5', task_route)
- self.assert_task_running('task6', task_route)
- self.run_workflow_step(wf_ex_db, 'task6', task_route)
- self.assert_task_running('task7', task_route)
- self.run_workflow_step(wf_ex_db, 'task7', task_route)
+ self.run_workflow_step(wf_ex_db, "task1", task_route)
+ self.assert_task_running("task2", task_route)
+ self.assert_task_running("task4", task_route)
+ self.run_workflow_step(wf_ex_db, "task2", task_route)
+ self.assert_task_running("task3", task_route)
+ self.run_workflow_step(wf_ex_db, "task4", task_route)
+ self.assert_task_running("task5", task_route)
+ self.run_workflow_step(wf_ex_db, "task3", task_route)
+ self.assert_task_not_started("task6", task_route)
+ self.run_workflow_step(wf_ex_db, "task5", task_route)
+ self.assert_task_running("task6", task_route)
+ self.run_workflow_step(wf_ex_db, "task6", task_route)
+ self.assert_task_running("task7", task_route)
+ self.run_workflow_step(wf_ex_db, "task7", task_route)
self.assert_workflow_completed(str(wf_ex_db.id), status=wf_statuses.SUCCEEDED)
# Ensure retry happened.
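The retry tests above use side_effect lists that mix exceptions and values: an item that is an exception instance is raised, anything else is returned, so two ToozConnectionError items followed by a lock simulate transient failures that a retry loop survives. A self-contained sketch of that behavior; ConnectionError_, get_lock, and acquire_with_retries are illustrative names, not st2 APIs:

from unittest import mock

class ConnectionError_(Exception):  # stand-in for coordination.ToozConnectionError
    pass

# The first two calls raise; the third returns a value.
get_lock = mock.MagicMock(
    side_effect=[ConnectionError_("foobar"), ConnectionError_("fubar"), "lock"]
)

def acquire_with_retries(attempts=3):
    for attempt in range(attempts):
        try:
            return get_lock()
        except ConnectionError_:
            if attempt == attempts - 1:
                raise  # retries exhausted: surface the connection error

assert acquire_with_retries() == "lock"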
diff --git a/st2common/tests/unit/test_action_alias_utils.py b/st2common/tests/unit/test_action_alias_utils.py
index 33b78981a5..daad0fbe1e 100644
--- a/st2common/tests/unit/test_action_alias_utils.py
+++ b/st2common/tests/unit/test_action_alias_utils.py
@@ -14,281 +14,312 @@
# limitations under the License.
from __future__ import absolute_import
-from sre_parse import (parse, AT, AT_BEGINNING, AT_BEGINNING_STRING, AT_END, AT_END_STRING)
+from sre_parse import (
+ parse,
+ AT,
+ AT_BEGINNING,
+ AT_BEGINNING_STRING,
+ AT_END,
+ AT_END_STRING,
+)
from mock import Mock
from unittest2 import TestCase
from st2common.exceptions.content import ParseException
from st2common.models.utils.action_alias_utils import (
- ActionAliasFormatParser, search_regex_tokens,
- inject_immutable_parameters
+ ActionAliasFormatParser,
+ search_regex_tokens,
+ inject_immutable_parameters,
)
class TestActionAliasParser(TestCase):
def test_empty_string(self):
- alias_format = ''
- param_stream = ''
+ alias_format = ""
+ param_stream = ""
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
self.assertEqual(extracted_values, {})
def test_arbitrary_pairs(self):
# single-word param
- alias_format = ''
- param_stream = 'a=foobar1'
+ alias_format = ""
+ param_stream = "a=foobar1"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'foobar1'})
+ self.assertEqual(extracted_values, {"a": "foobar1"})
# multi-word double-quoted param
- alias_format = 'foo'
+ alias_format = "foo"
param_stream = 'foo a="foobar2 poonies bar"'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'foobar2 poonies bar'})
+ self.assertEqual(extracted_values, {"a": "foobar2 poonies bar"})
# multi-word single-quoted param
- alias_format = 'foo'
- param_stream = 'foo a=\'foobar2 poonies bar\''
+ alias_format = "foo"
+ param_stream = "foo a='foobar2 poonies bar'"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'foobar2 poonies bar'})
+ self.assertEqual(extracted_values, {"a": "foobar2 poonies bar"})
# JSON param
- alias_format = 'foo'
+ alias_format = "foo"
param_stream = 'foo a={"foobar2": "poonies"}'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': '{"foobar2": "poonies"}'})
+ self.assertEqual(extracted_values, {"a": '{"foobar2": "poonies"}'})
# Multiple mixed params
- alias_format = ''
+ alias_format = ""
param_stream = 'a=foobar1 b="boobar2 3 4" c=\'coobar3 4\' d={"a": "b"}'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'foobar1',
- 'b': 'boobar2 3 4',
- 'c': 'coobar3 4',
- 'd': '{"a": "b"}'})
+ self.assertEqual(
+ extracted_values,
+ {"a": "foobar1", "b": "boobar2 3 4", "c": "coobar3 4", "d": '{"a": "b"}'},
+ )
# Params along with a "normal" alias format
- alias_format = '{{ captain }} is my captain'
+ alias_format = "{{ captain }} is my captain"
param_stream = 'Malcolm Reynolds is my captain weirdo="River Tam"'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'captain': 'Malcolm Reynolds',
- 'weirdo': 'River Tam'})
+ self.assertEqual(
+ extracted_values, {"captain": "Malcolm Reynolds", "weirdo": "River Tam"}
+ )
def test_simple_parsing(self):
- alias_format = 'skip {{a}} more skip {{b}} and skip more.'
- param_stream = 'skip a1 more skip b1 and skip more.'
+ alias_format = "skip {{a}} more skip {{b}} and skip more."
+ param_stream = "skip a1 more skip b1 and skip more."
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'a1', 'b': 'b1'})
+ self.assertEqual(extracted_values, {"a": "a1", "b": "b1"})
def test_end_string_parsing(self):
- alias_format = 'skip {{a}} more skip {{b}}'
- param_stream = 'skip a1 more skip b1'
+ alias_format = "skip {{a}} more skip {{b}}"
+ param_stream = "skip a1 more skip b1"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'a1', 'b': 'b1'})
+ self.assertEqual(extracted_values, {"a": "a1", "b": "b1"})
def test_spaced_parsing(self):
- alias_format = 'skip {{a}} more skip {{b}} and skip more.'
+ alias_format = "skip {{a}} more skip {{b}} and skip more."
param_stream = 'skip "a1 a2" more skip b1 and skip more.'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'a1 a2', 'b': 'b1'})
+ self.assertEqual(extracted_values, {"a": "a1 a2", "b": "b1"})
def test_default_values(self):
- alias_format = 'acl {{a}} {{b}} {{c}} {{d=1}}'
+ alias_format = "acl {{a}} {{b}} {{c}} {{d=1}}"
param_stream = 'acl "a1 a2" "b1" "c1"'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'a1 a2', 'b': 'b1',
- 'c': 'c1', 'd': '1'})
+ self.assertEqual(
+ extracted_values, {"a": "a1 a2", "b": "b1", "c": "c1", "d": "1"}
+ )
def test_spacing(self):
- alias_format = 'acl {{a=test}}'
- param_stream = 'acl'
+ alias_format = "acl {{a=test}}"
+ param_stream = "acl"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'test'})
+ self.assertEqual(extracted_values, {"a": "test"})
def test_json_parsing(self):
- alias_format = 'skip {{a}} more skip.'
+ alias_format = "skip {{a}} more skip."
param_stream = 'skip {"a": "b", "c": "d"} more skip.'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': '{"a": "b", "c": "d"}'})
+ self.assertEqual(extracted_values, {"a": '{"a": "b", "c": "d"}'})
def test_mixed_parsing(self):
- alias_format = 'skip {{a}} more skip {{b}}.'
+ alias_format = "skip {{a}} more skip {{b}}."
param_stream = 'skip {"a": "b", "c": "d"} more skip x.'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': '{"a": "b", "c": "d"}',
- 'b': 'x'})
+ self.assertEqual(extracted_values, {"a": '{"a": "b", "c": "d"}', "b": "x"})
def test_param_spaces(self):
- alias_format = 's {{a}} more {{ b }} more {{ c=99 }} more {{ d = 99 }}'
- param_stream = 's one more two more three more'
+ alias_format = "s {{a}} more {{ b }} more {{ c=99 }} more {{ d = 99 }}"
+ param_stream = "s one more two more three more"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'one', 'b': 'two',
- 'c': 'three', 'd': '99'})
+ self.assertEqual(
+ extracted_values, {"a": "one", "b": "two", "c": "three", "d": "99"}
+ )
def test_enclosed_defaults(self):
- alias_format = 'skip {{ a = value }} more'
- param_stream = 'skip one more'
+ alias_format = "skip {{ a = value }} more"
+ param_stream = "skip one more"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'one'})
+ self.assertEqual(extracted_values, {"a": "one"})
- alias_format = 'skip {{ a = value }} more'
- param_stream = 'skip more'
+ alias_format = "skip {{ a = value }} more"
+ param_stream = "skip more"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'value'})
+ self.assertEqual(extracted_values, {"a": "value"})
def test_template_defaults(self):
- alias_format = 'two by two hands of {{ color = {{ colors.default_color }} }}'
- param_stream = 'two by two hands of'
+ alias_format = "two by two hands of {{ color = {{ colors.default_color }} }}"
+ param_stream = "two by two hands of"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'color': '{{ colors.default_color }}'})
+ self.assertEqual(extracted_values, {"color": "{{ colors.default_color }}"})
def test_key_value_combinations(self):
# one-word value, single extra pair
- alias_format = 'testing {{ a }}'
- param_stream = 'testing value b=value2'
+ alias_format = "testing {{ a }}"
+ param_stream = "testing value b=value2"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'value',
- 'b': 'value2'})
+ self.assertEqual(extracted_values, {"a": "value", "b": "value2"})
# default value, single extra pair with quotes
- alias_format = 'testing {{ a=new }}'
+ alias_format = "testing {{ a=new }}"
param_stream = 'testing b="another value"'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'a': 'new',
- 'b': 'another value'})
+ self.assertEqual(extracted_values, {"a": "new", "b": "another value"})
# multiple values and multiple extra pairs
- alias_format = 'testing {{ b=abc }} {{ c=xyz }}'
+ alias_format = "testing {{ b=abc }} {{ c=xyz }}"
param_stream = 'testing newvalue d={"1": "2"} e="long value"'
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'b': 'newvalue',
- 'c': 'xyz',
- 'd': '{"1": "2"}',
- 'e': 'long value'})
+ self.assertEqual(
+ extracted_values,
+ {"b": "newvalue", "c": "xyz", "d": '{"1": "2"}', "e": "long value"},
+ )
def test_stream_is_none_with_all_default_values(self):
- alias_format = 'skip {{d=test1}} more skip {{e=test1}}.'
- param_stream = 'skip more skip'
+ alias_format = "skip {{d=test1}} more skip {{e=test1}}."
+ param_stream = "skip more skip"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'d': 'test1', 'e': 'test1'})
+ self.assertEqual(extracted_values, {"d": "test1", "e": "test1"})
def test_stream_is_not_none_some_default_values(self):
- alias_format = 'skip {{d=test}} more skip {{e=test}}'
- param_stream = 'skip ponies more skip'
+ alias_format = "skip {{d=test}} more skip {{e=test}}"
+ param_stream = "skip ponies more skip"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'d': 'ponies', 'e': 'test'})
+ self.assertEqual(extracted_values, {"d": "ponies", "e": "test"})
def test_stream_is_none_no_default_values(self):
- alias_format = 'skip {{d}} more skip {{e}}.'
+ alias_format = "skip {{d}} more skip {{e}}."
param_stream = None
parser = ActionAliasFormatParser(alias_format, param_stream)
- expected_msg = 'Command "" doesn\'t match format string "skip {{d}} more skip {{e}}."'
- self.assertRaisesRegexp(ParseException, expected_msg,
- parser.get_extracted_param_value)
+ expected_msg = (
+ 'Command "" doesn\'t match format string "skip {{d}} more skip {{e}}."'
+ )
+ self.assertRaisesRegexp(
+ ParseException, expected_msg, parser.get_extracted_param_value
+ )
def test_all_the_things(self):
# this is the most insane example I could come up with
- alias_format = "{{ p0='http' }} g {{ p1=p }} a " + \
- "{{ url }} {{ p2={'a':'b'} }} {{ p3={{ e.i }} }}"
- param_stream = "g a http://google.com {{ execution.id }} p4='testing' p5={'a':'c'}"
+ alias_format = (
+ "{{ p0='http' }} g {{ p1=p }} a "
+ + "{{ url }} {{ p2={'a':'b'} }} {{ p3={{ e.i }} }}"
+ )
+ param_stream = (
+ "g a http://google.com {{ execution.id }} p4='testing' p5={'a':'c'}"
+ )
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'p0': 'http', 'p1': 'p',
- 'url': 'http://google.com',
- 'p2': '{{ execution.id }}',
- 'p3': '{{ e.i }}',
- 'p4': 'testing', 'p5': "{'a':'c'}"})
+ self.assertEqual(
+ extracted_values,
+ {
+ "p0": "http",
+ "p1": "p",
+ "url": "http://google.com",
+ "p2": "{{ execution.id }}",
+ "p3": "{{ e.i }}",
+ "p4": "testing",
+ "p5": "{'a':'c'}",
+ },
+ )
def test_command_doesnt_match_format_string(self):
- alias_format = 'foo bar ponies'
- param_stream = 'foo lulz ponies'
+ alias_format = "foo bar ponies"
+ param_stream = "foo lulz ponies"
parser = ActionAliasFormatParser(alias_format, param_stream)
- expected_msg = 'Command "foo lulz ponies" doesn\'t match format string "foo bar ponies"'
- self.assertRaisesRegexp(ParseException, expected_msg,
- parser.get_extracted_param_value)
+ expected_msg = (
+ 'Command "foo lulz ponies" doesn\'t match format string "foo bar ponies"'
+ )
+ self.assertRaisesRegexp(
+ ParseException, expected_msg, parser.get_extracted_param_value
+ )
def test_ending_parameters_matching(self):
- alias_format = 'foo bar'
- param_stream = 'foo bar pony1=foo pony2=bar'
+ alias_format = "foo bar"
+ param_stream = "foo bar pony1=foo pony2=bar"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'pony1': 'foo', 'pony2': 'bar'})
+ self.assertEqual(extracted_values, {"pony1": "foo", "pony2": "bar"})
def test_regex_beginning_anchors(self):
- alias_format = r'^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)'
- param_stream = 'foo ASDF-1234'
+ alias_format = r"^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)"
+ param_stream = "foo ASDF-1234"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'})
+ self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"})
def test_regex_beginning_anchors_dont_match(self):
- alias_format = r'^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)'
- param_stream = 'bar foo ASDF-1234'
+ alias_format = r"^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)"
+ param_stream = "bar foo ASDF-1234"
parser = ActionAliasFormatParser(alias_format, param_stream)
- expected_msg = r'''Command "bar foo ASDF-1234" doesn't match format string '''\
- r'''"^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)"'''
+ expected_msg = (
+ r"""Command "bar foo ASDF-1234" doesn't match format string """
+ r'''"^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)"'''
+ )
with self.assertRaises(ParseException) as e:
parser.get_extracted_param_value()
self.assertEqual(e.msg, expected_msg)
def test_regex_ending_anchors(self):
- alias_format = r'foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$'
- param_stream = 'foo ASDF-1234'
+ alias_format = r"foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$"
+ param_stream = "foo ASDF-1234"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'})
+ self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"})
def test_regex_ending_anchors_dont_match(self):
- alias_format = r'foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$'
- param_stream = 'foo ASDF-1234 bar'
+ alias_format = r"foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$"
+ param_stream = "foo ASDF-1234 bar"
parser = ActionAliasFormatParser(alias_format, param_stream)
- expected_msg = r'''Command "foo ASDF-1234 bar" doesn't match format string '''\
- r'''"foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$"'''
+ expected_msg = (
+ r"""Command "foo ASDF-1234 bar" doesn't match format string """
+ r'''"foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$"'''
+ )
with self.assertRaises(ParseException) as e:
parser.get_extracted_param_value()
self.assertEqual(e.msg, expected_msg)
def test_regex_beginning_and_ending_anchors(self):
- alias_format = r'^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+) bar\s*$'
- param_stream = 'foo ASDF-1234 bar'
+ alias_format = r"^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+) bar\s*$"
+ param_stream = "foo ASDF-1234 bar"
parser = ActionAliasFormatParser(alias_format, param_stream)
extracted_values = parser.get_extracted_param_value()
- self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'})
+ self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"})
def test_regex_beginning_and_ending_anchors_dont_match(self):
- alias_format = r'^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$'
- param_stream = 'bar ASDF-1234'
+ alias_format = r"^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$"
+ param_stream = "bar ASDF-1234"
parser = ActionAliasFormatParser(alias_format, param_stream)
- expected_msg = r'''Command "bar ASDF-1234" doesn't match format string '''\
- r'''"^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$"'''
+ expected_msg = (
+ r"""Command "bar ASDF-1234" doesn't match format string """
+ r'''"^\s*foo (?P<issue_key>[A-Z][A-Z0-9]+-[0-9]+)\s*$"'''
+ )
with self.assertRaises(ParseException) as e:
parser.get_extracted_param_value()
@@ -332,8 +363,8 @@ def test_immutable_parameters_are_injected(self):
exec_params = [{"param1": "value1", "param2": "value2"}]
inject_immutable_parameters(action_alias_db, exec_params, {})
self.assertEqual(
- exec_params,
- [{"param1": "value1", "param2": "value2", "env": "dev"}])
+ exec_params, [{"param1": "value1", "param2": "value2", "env": "dev"}]
+ )
def test_immutable_parameters_with_jinja(self):
action_alias_db = Mock()
@@ -341,8 +372,8 @@ def test_immutable_parameters_with_jinja(self):
exec_params = [{"param1": "value1", "param2": "value2"}]
inject_immutable_parameters(action_alias_db, exec_params, {})
self.assertEqual(
- exec_params,
- [{"param1": "value1", "param2": "value2", "env": "dev1"}])
+ exec_params, [{"param1": "value1", "param2": "value2", "env": "dev1"}]
+ )
def test_override_raises_error(self):
action_alias_db = Mock()
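The alias-parser hunks above replace backslash line continuations in the expected_msg assignments with parenthesized adjacent string literals; Python concatenates adjacent literals at compile time, so the value is unchanged. Black also keeps single quotes when a literal contains double quotes, since switching would force escaping. A minimal sketch with an invented message:

# Before black: backslash continuation.
# expected_msg = 'Command "foo" doesn\'t match format string '\
#                '"foo bar"'

# After black: parenthesized adjacent literals, concatenated at compile time.
expected_msg = (
    'Command "foo" doesn\'t match format string '
    '"foo bar"'
)
assert expected_msg == 'Command "foo" doesn\'t match format string "foo bar"'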
diff --git a/st2common/tests/unit/test_action_api_validator.py b/st2common/tests/unit/test_action_api_validator.py
index 1cf16d3f14..5be1ca13ba 100644
--- a/st2common/tests/unit/test_action_api_validator.py
+++ b/st2common/tests/unit/test_action_api_validator.py
@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except ImportError:
@@ -29,66 +30,83 @@
from st2tests import DbTestCase
from st2tests.fixtures.packs import executions as fixture
-__all__ = [
- 'TestActionAPIValidator'
-]
+__all__ = ["TestActionAPIValidator"]
class TestActionAPIValidator(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(TestActionAPIValidator, cls).setUpClass()
runners_registrar.register_runners()
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
def test_validate_runner_type_happy_case(self):
- action_api_dict = fixture.ARTIFACTS['actions']['local']
+ action_api_dict = fixture.ARTIFACTS["actions"]["local"]
action_api = ActionAPI(**action_api_dict)
try:
action_validator.validate_action(action_api)
except:
- self.fail('Exception validating action: %s' % json.dumps(action_api_dict))
+ self.fail("Exception validating action: %s" % json.dumps(action_api_dict))
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
def test_validate_runner_type_invalid_runner(self):
- action_api_dict = fixture.ARTIFACTS['actions']['action-with-invalid-runner']
+ action_api_dict = fixture.ARTIFACTS["actions"]["action-with-invalid-runner"]
action_api = ActionAPI(**action_api_dict)
try:
action_validator.validate_action(action_api)
- self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict))
+ self.fail(
+ "Action validation should not have passed. %s"
+ % json.dumps(action_api_dict)
+ )
except ValueValidationException:
pass
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
def test_validate_override_immutable_runner_param(self):
- action_api_dict = fixture.ARTIFACTS['actions']['remote-override-runner-immutable']
+ action_api_dict = fixture.ARTIFACTS["actions"][
+ "remote-override-runner-immutable"
+ ]
action_api = ActionAPI(**action_api_dict)
try:
action_validator.validate_action(action_api)
- self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict))
+ self.fail(
+ "Action validation should not have passed. %s"
+ % json.dumps(action_api_dict)
+ )
except ValueValidationException as e:
- self.assertIn('Cannot override in action.', six.text_type(e))
+ self.assertIn("Cannot override in action.", six.text_type(e))
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
def test_validate_action_param_immutable(self):
- action_api_dict = fixture.ARTIFACTS['actions']['action-immutable-param-no-default']
+ action_api_dict = fixture.ARTIFACTS["actions"][
+ "action-immutable-param-no-default"
+ ]
action_api = ActionAPI(**action_api_dict)
try:
action_validator.validate_action(action_api)
- self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict))
+ self.fail(
+ "Action validation should not have passed. %s"
+ % json.dumps(action_api_dict)
+ )
except ValueValidationException as e:
- self.assertIn('requires a default value.', six.text_type(e))
+ self.assertIn("requires a default value.", six.text_type(e))
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
def test_validate_action_param_immutable_no_default(self):
- action_api_dict = fixture.ARTIFACTS['actions']['action-immutable-runner-param-no-default']
+ action_api_dict = fixture.ARTIFACTS["actions"][
+ "action-immutable-runner-param-no-default"
+ ]
action_api = ActionAPI(**action_api_dict)
# Runner param sudo is declared immutable in action but no default value
@@ -97,30 +115,44 @@ def test_validate_action_param_immutable_no_default(self):
action_validator.validate_action(action_api)
except ValueValidationException as e:
print(e)
- self.fail('Action validation should have passed. %s' % json.dumps(action_api_dict))
+ self.fail(
+ "Action validation should have passed. %s" % json.dumps(action_api_dict)
+ )
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
def test_validate_action_param_position_values_unique(self):
- action_api_dict = fixture.ARTIFACTS['actions']['action-with-non-unique-positions']
+ action_api_dict = fixture.ARTIFACTS["actions"][
+ "action-with-non-unique-positions"
+ ]
action_api = ActionAPI(**action_api_dict)
try:
action_validator.validate_action(action_api)
- self.fail('Action validation should have failed ' +
- 'because position values are not unique.' % json.dumps(action_api_dict))
+ self.fail(
+ "Action validation should have failed "
+ + "because position values are not unique."
+ % json.dumps(action_api_dict)
+ )
except ValueValidationException as e:
- self.assertIn('have same position', six.text_type(e))
+ self.assertIn("have same position", six.text_type(e))
- @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(
- return_value=True))
+ @mock.patch.object(
+ action_validator, "_is_valid_pack", mock.MagicMock(return_value=True)
+ )
def test_validate_action_param_position_values_contiguous(self):
- action_api_dict = fixture.ARTIFACTS['actions']['action-with-non-contiguous-positions']
+ action_api_dict = fixture.ARTIFACTS["actions"][
+ "action-with-non-contiguous-positions"
+ ]
action_api = ActionAPI(**action_api_dict)
try:
action_validator.validate_action(action_api)
- self.fail('Action validation should have failed ' +
- 'because position values are not contiguous.' % json.dumps(action_api_dict))
+ self.fail(
+ "Action validation should have failed "
+ + "because position values are not contiguous."
+ % json.dumps(action_api_dict)
+ )
except ValueValidationException as e:
- self.assertIn('are not contiguous', six.text_type(e))
+ self.assertIn("are not contiguous", six.text_type(e))
diff --git a/st2common/tests/unit/test_action_db_utils.py b/st2common/tests/unit/test_action_db_utils.py
index ba2dcef018..f7a114b85b 100644
--- a/st2common/tests/unit/test_action_db_utils.py
+++ b/st2common/tests/unit/test_action_db_utils.py
@@ -35,7 +35,7 @@
from st2tests.base import DbTestCase
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class ActionDBUtilsTestCase(DbTestCase):
runnertype_db = None
action_db = None
@@ -48,26 +48,39 @@ def setUpClass(cls):
def test_get_runnertype_nonexisting(self):
# By id.
- self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_runnertype_by_id,
- 'somedummyrunnerid')
+ self.assertRaises(
+ StackStormDBObjectNotFoundError,
+ action_db_utils.get_runnertype_by_id,
+ "somedummyrunnerid",
+ )
# By name.
- self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_runnertype_by_name,
- 'somedummyrunnername')
+ self.assertRaises(
+ StackStormDBObjectNotFoundError,
+ action_db_utils.get_runnertype_by_name,
+ "somedummyrunnername",
+ )
def test_get_runnertype_existing(self):
# Lookup by id and verify name equals.
- runner = action_db_utils.get_runnertype_by_id(ActionDBUtilsTestCase.runnertype_db.id)
+ runner = action_db_utils.get_runnertype_by_id(
+ ActionDBUtilsTestCase.runnertype_db.id
+ )
self.assertEqual(runner.name, ActionDBUtilsTestCase.runnertype_db.name)
# Lookup by name and verify id equals.
- runner = action_db_utils.get_runnertype_by_name(ActionDBUtilsTestCase.runnertype_db.name)
+ runner = action_db_utils.get_runnertype_by_name(
+ ActionDBUtilsTestCase.runnertype_db.name
+ )
self.assertEqual(runner.id, ActionDBUtilsTestCase.runnertype_db.id)
def test_get_action_nonexisting(self):
# By id.
- self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_action_by_id,
- 'somedummyactionid')
+ self.assertRaises(
+ StackStormDBObjectNotFoundError,
+ action_db_utils.get_action_by_id,
+ "somedummyactionid",
+ )
# By ref.
- action = action_db_utils.get_action_by_ref('packaintexist.somedummyactionname')
+ action = action_db_utils.get_action_by_ref("packaintexist.somedummyactionname")
self.assertIsNone(action)
def test_get_action_existing(self):
@@ -77,50 +90,57 @@ def test_get_action_existing(self):
# Lookup by reference as string.
action_ref = ResourceReference.to_string_reference(
pack=ActionDBUtilsTestCase.action_db.pack,
- name=ActionDBUtilsTestCase.action_db.name)
+ name=ActionDBUtilsTestCase.action_db.name,
+ )
action = action_db_utils.get_action_by_ref(action_ref)
self.assertEqual(action.id, ActionDBUtilsTestCase.action_db.id)
def test_get_actionexec_nonexisting(self):
# By id.
- self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_liveaction_by_id,
- 'somedummyactionexecid')
+ self.assertRaises(
+ StackStormDBObjectNotFoundError,
+ action_db_utils.get_liveaction_by_id,
+ "somedummyactionexecid",
+ )
def test_get_actionexec_existing(self):
- liveaction = action_db_utils.get_liveaction_by_id(ActionDBUtilsTestCase.liveaction_db.id)
+ liveaction = action_db_utils.get_liveaction_by_id(
+ ActionDBUtilsTestCase.liveaction_db.id
+ )
self.assertEqual(liveaction, ActionDBUtilsTestCase.liveaction_db)
- @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock())
+ @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock())
def test_update_liveaction_with_incorrect_output_schema(self):
liveaction_db = LiveActionDB()
- liveaction_db.status = 'initializing'
+ liveaction_db.status = "initializing"
liveaction_db.start_timestamp = get_datetime_utc_now()
liveaction_db.action = ResourceReference(
name=ActionDBUtilsTestCase.action_db.name,
- pack=ActionDBUtilsTestCase.action_db.pack).ref
+ pack=ActionDBUtilsTestCase.action_db.pack,
+ ).ref
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
}
liveaction_db.parameters = params
runner = mock.MagicMock()
- runner.output_schema = {
- "notaparam": {
- "type": "boolean"
- }
- }
+ runner.output_schema = {"notaparam": {"type": "boolean"}}
liveaction_db.runner = runner
liveaction_db = LiveAction.add_or_update(liveaction_db)
origliveaction_db = copy.copy(liveaction_db)
now = get_datetime_utc_now()
- status = 'succeeded'
- result = 'Work is done.'
- context = {'third_party_id': uuid.uuid4().hex}
+ status = "succeeded"
+ result = "Work is done."
+ context = {"third_party_id": uuid.uuid4().hex}
newliveaction_db = action_db_utils.update_liveaction_status(
- status=status, result=result, context=context, end_timestamp=now,
- liveaction_id=liveaction_db.id)
+ status=status,
+ result=result,
+ context=context,
+ end_timestamp=now,
+ liveaction_id=liveaction_db.id,
+ )
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
self.assertEqual(newliveaction_db.status, status)
@@ -128,18 +148,19 @@ def test_update_liveaction_with_incorrect_output_schema(self):
self.assertDictEqual(newliveaction_db.context, context)
self.assertEqual(newliveaction_db.end_timestamp, now)
- @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock())
+ @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock())
def test_update_liveaction_status(self):
liveaction_db = LiveActionDB()
- liveaction_db.status = 'initializing'
+ liveaction_db.status = "initializing"
liveaction_db.start_timestamp = get_datetime_utc_now()
liveaction_db.action = ResourceReference(
name=ActionDBUtilsTestCase.action_db.name,
- pack=ActionDBUtilsTestCase.action_db.pack).ref
+ pack=ActionDBUtilsTestCase.action_db.pack,
+ ).ref
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
}
liveaction_db.parameters = params
liveaction_db = LiveAction.add_or_update(liveaction_db)
@@ -147,24 +168,31 @@ def test_update_liveaction_status(self):
# Update by id.
newliveaction_db = action_db_utils.update_liveaction_status(
- status='running', liveaction_id=liveaction_db.id)
+ status="running", liveaction_id=liveaction_db.id
+ )
# Verify id didn't change.
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
- self.assertEqual(newliveaction_db.status, 'running')
+ self.assertEqual(newliveaction_db.status, "running")
# Verify that state is published.
self.assertTrue(LiveActionPublisher.publish_state.called)
- LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running')
+ LiveActionPublisher.publish_state.assert_called_once_with(
+ newliveaction_db, "running"
+ )
# Update status, result, context, and end timestamp.
now = get_datetime_utc_now()
- status = 'succeeded'
- result = 'Work is done.'
- context = {'third_party_id': uuid.uuid4().hex}
+ status = "succeeded"
+ result = "Work is done."
+ context = {"third_party_id": uuid.uuid4().hex}
newliveaction_db = action_db_utils.update_liveaction_status(
- status=status, result=result, context=context, end_timestamp=now,
- liveaction_id=liveaction_db.id)
+ status=status,
+ result=result,
+ context=context,
+ end_timestamp=now,
+ liveaction_id=liveaction_db.id,
+ )
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
self.assertEqual(newliveaction_db.status, status)
@@ -172,18 +200,19 @@ def test_update_liveaction_status(self):
self.assertDictEqual(newliveaction_db.context, context)
self.assertEqual(newliveaction_db.end_timestamp, now)
- @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock())
+ @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock())
def test_update_canceled_liveaction(self):
liveaction_db = LiveActionDB()
- liveaction_db.status = 'initializing'
+ liveaction_db.status = "initializing"
liveaction_db.start_timestamp = get_datetime_utc_now()
liveaction_db.action = ResourceReference(
name=ActionDBUtilsTestCase.action_db.name,
- pack=ActionDBUtilsTestCase.action_db.pack).ref
+ pack=ActionDBUtilsTestCase.action_db.pack,
+ ).ref
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
}
liveaction_db.parameters = params
liveaction_db = LiveAction.add_or_update(liveaction_db)
@@ -191,21 +220,25 @@ def test_update_canceled_liveaction(self):
# Update by id.
newliveaction_db = action_db_utils.update_liveaction_status(
- status='running', liveaction_id=liveaction_db.id)
+ status="running", liveaction_id=liveaction_db.id
+ )
# Verify id didn't change.
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
- self.assertEqual(newliveaction_db.status, 'running')
+ self.assertEqual(newliveaction_db.status, "running")
# Verify that state is published.
self.assertTrue(LiveActionPublisher.publish_state.called)
- LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running')
+ LiveActionPublisher.publish_state.assert_called_once_with(
+ newliveaction_db, "running"
+ )
# Cancel liveaction.
now = get_datetime_utc_now()
- status = 'canceled'
+ status = "canceled"
newliveaction_db = action_db_utils.update_liveaction_status(
- status=status, end_timestamp=now, liveaction_id=liveaction_db.id)
+ status=status, end_timestamp=now, liveaction_id=liveaction_db.id
+ )
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
self.assertEqual(newliveaction_db.status, status)
self.assertEqual(newliveaction_db.end_timestamp, now)
@@ -213,31 +246,36 @@ def test_update_canceled_liveaction(self):
# Since liveaction has already been canceled, check that any further updates of
# status, result, context, and end timestamp are not processed.
now = get_datetime_utc_now()
- status = 'succeeded'
- result = 'Work is done.'
- context = {'third_party_id': uuid.uuid4().hex}
+ status = "succeeded"
+ result = "Work is done."
+ context = {"third_party_id": uuid.uuid4().hex}
newliveaction_db = action_db_utils.update_liveaction_status(
- status=status, result=result, context=context, end_timestamp=now,
- liveaction_id=liveaction_db.id)
+ status=status,
+ result=result,
+ context=context,
+ end_timestamp=now,
+ liveaction_id=liveaction_db.id,
+ )
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
- self.assertEqual(newliveaction_db.status, 'canceled')
+ self.assertEqual(newliveaction_db.status, "canceled")
self.assertNotEqual(newliveaction_db.result, result)
self.assertNotEqual(newliveaction_db.context, context)
self.assertNotEqual(newliveaction_db.end_timestamp, now)
- @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock())
+ @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock())
def test_update_liveaction_result_with_dotted_key(self):
liveaction_db = LiveActionDB()
- liveaction_db.status = 'initializing'
+ liveaction_db.status = "initializing"
liveaction_db.start_timestamp = get_datetime_utc_now()
liveaction_db.action = ResourceReference(
name=ActionDBUtilsTestCase.action_db.name,
- pack=ActionDBUtilsTestCase.action_db.pack).ref
+ pack=ActionDBUtilsTestCase.action_db.pack,
+ ).ref
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
}
liveaction_db.parameters = params
liveaction_db = LiveAction.add_or_update(liveaction_db)
@@ -245,66 +283,79 @@ def test_update_liveaction_result_with_dotted_key(self):
# Update by id.
newliveaction_db = action_db_utils.update_liveaction_status(
- status='running', liveaction_id=liveaction_db.id)
+ status="running", liveaction_id=liveaction_db.id
+ )
# Verify id didn't change.
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
- self.assertEqual(newliveaction_db.status, 'running')
+ self.assertEqual(newliveaction_db.status, "running")
# Verify that state is published.
self.assertTrue(LiveActionPublisher.publish_state.called)
- LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running')
+ LiveActionPublisher.publish_state.assert_called_once_with(
+ newliveaction_db, "running"
+ )
now = get_datetime_utc_now()
- status = 'succeeded'
- result = {'a': 1, 'b': True, 'a.b.c': 'abc'}
- context = {'third_party_id': uuid.uuid4().hex}
+ status = "succeeded"
+ result = {"a": 1, "b": True, "a.b.c": "abc"}
+ context = {"third_party_id": uuid.uuid4().hex}
newliveaction_db = action_db_utils.update_liveaction_status(
- status=status, result=result, context=context, end_timestamp=now,
- liveaction_id=liveaction_db.id)
+ status=status,
+ result=result,
+ context=context,
+ end_timestamp=now,
+ liveaction_id=liveaction_db.id,
+ )
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
self.assertEqual(newliveaction_db.status, status)
- self.assertIn('a.b.c', list(result.keys()))
+ self.assertIn("a.b.c", list(result.keys()))
self.assertDictEqual(newliveaction_db.result, result)
self.assertDictEqual(newliveaction_db.context, context)
self.assertEqual(newliveaction_db.end_timestamp, now)
- @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock())
+ @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock())
def test_update_LiveAction_status_invalid(self):
liveaction_db = LiveActionDB()
- liveaction_db.status = 'initializing'
+ liveaction_db.status = "initializing"
liveaction_db.start_timestamp = get_datetime_utc_now()
liveaction_db.action = ResourceReference(
name=ActionDBUtilsTestCase.action_db.name,
- pack=ActionDBUtilsTestCase.action_db.pack).ref
+ pack=ActionDBUtilsTestCase.action_db.pack,
+ ).ref
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
}
liveaction_db.parameters = params
liveaction_db = LiveAction.add_or_update(liveaction_db)
# Update by id.
- self.assertRaises(ValueError, action_db_utils.update_liveaction_status,
- status='mea culpa', liveaction_id=liveaction_db.id)
+ self.assertRaises(
+ ValueError,
+ action_db_utils.update_liveaction_status,
+ status="mea culpa",
+ liveaction_id=liveaction_db.id,
+ )
# Verify that state is not published.
self.assertFalse(LiveActionPublisher.publish_state.called)
- @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock())
+ @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock())
def test_update_same_liveaction_status(self):
liveaction_db = LiveActionDB()
- liveaction_db.status = 'requested'
+ liveaction_db.status = "requested"
liveaction_db.start_timestamp = get_datetime_utc_now()
liveaction_db.action = ResourceReference(
name=ActionDBUtilsTestCase.action_db.name,
- pack=ActionDBUtilsTestCase.action_db.pack).ref
+ pack=ActionDBUtilsTestCase.action_db.pack,
+ ).ref
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
}
liveaction_db.parameters = params
liveaction_db = LiveAction.add_or_update(liveaction_db)
@@ -312,141 +363,150 @@ def test_update_same_liveaction_status(self):
# Update by id.
newliveaction_db = action_db_utils.update_liveaction_status(
- status='requested', liveaction_id=liveaction_db.id)
+ status="requested", liveaction_id=liveaction_db.id
+ )
# Verify id didn't change.
self.assertEqual(origliveaction_db.id, newliveaction_db.id)
- self.assertEqual(newliveaction_db.status, 'requested')
+ self.assertEqual(newliveaction_db.status, "requested")
# Verify that state is not published.
self.assertFalse(LiveActionPublisher.publish_state.called)
def test_get_args(self):
- params = {
- 'actionstr': 'foo',
- 'actionint': 20,
- 'runnerint': 555
- }
- pos_args, named_args = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db)
- self.assertListEqual(pos_args, ['20', '', 'foo', '', '', '', '', ''],
- 'Positional args not parsed correctly.')
- self.assertNotIn('actionint', named_args)
- self.assertNotIn('actionstr', named_args)
- self.assertEqual(named_args.get('runnerint'), 555)
+ params = {"actionstr": "foo", "actionint": 20, "runnerint": 555}
+ pos_args, named_args = action_db_utils.get_args(
+ params, ActionDBUtilsTestCase.action_db
+ )
+ self.assertListEqual(
+ pos_args,
+ ["20", "", "foo", "", "", "", "", ""],
+ "Positional args not parsed correctly.",
+ )
+ self.assertNotIn("actionint", named_args)
+ self.assertNotIn("actionstr", named_args)
+ self.assertEqual(named_args.get("runnerint"), 555)
# Test serialization for different positional argument types and values
# Test all the values provided
params = {
- 'actionint': 1,
- 'actionfloat': 1.5,
- 'actionstr': 'string value',
- 'actionbool': True,
- 'actionarray': ['foo', 'bar', 'baz', 'qux'],
- 'actionlist': ['foo', 'bar', 'baz'],
- 'actionobject': {'a': 1, 'b': '2'},
+ "actionint": 1,
+ "actionfloat": 1.5,
+ "actionstr": "string value",
+ "actionbool": True,
+ "actionarray": ["foo", "bar", "baz", "qux"],
+ "actionlist": ["foo", "bar", "baz"],
+ "actionobject": {"a": 1, "b": "2"},
}
expected_pos_args = [
- '1',
- '1.5',
- 'string value',
- '1',
- 'foo,bar,baz,qux',
- 'foo,bar,baz',
+ "1",
+ "1.5",
+ "string value",
+ "1",
+ "foo,bar,baz,qux",
+ "foo,bar,baz",
'{"a": 1, "b": "2"}',
- ''
+ "",
]
pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db)
- self.assertListEqual(pos_args, expected_pos_args,
- 'Positional args not parsed / serialized correctly.')
+ self.assertListEqual(
+ pos_args,
+ expected_pos_args,
+ "Positional args not parsed / serialized correctly.",
+ )
params = {
- 'actionint': 1,
- 'actionfloat': 1.5,
- 'actionstr': 'string value',
- 'actionbool': False,
- 'actionarray': [],
- 'actionlist': [],
- 'actionobject': {'a': 1, 'b': '2'},
+ "actionint": 1,
+ "actionfloat": 1.5,
+ "actionstr": "string value",
+ "actionbool": False,
+ "actionarray": [],
+ "actionlist": [],
+ "actionobject": {"a": 1, "b": "2"},
}
expected_pos_args = [
- '1',
- '1.5',
- 'string value',
- '0',
- '',
- '',
+ "1",
+ "1.5",
+ "string value",
+ "0",
+ "",
+ "",
'{"a": 1, "b": "2"}',
- ''
+ "",
]
pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db)
- self.assertListEqual(pos_args, expected_pos_args,
- 'Positional args not parsed / serialized correctly.')
+ self.assertListEqual(
+ pos_args,
+ expected_pos_args,
+ "Positional args not parsed / serialized correctly.",
+ )
# Test none values
params = {
- 'actionint': None,
- 'actionfloat': None,
- 'actionstr': None,
- 'actionbool': None,
- 'actionarray': None,
- 'actionlist': None,
- 'actionobject': None,
+ "actionint": None,
+ "actionfloat": None,
+ "actionstr": None,
+ "actionbool": None,
+ "actionarray": None,
+ "actionlist": None,
+ "actionobject": None,
}
- expected_pos_args = [
- '',
- '',
- '',
- '',
- '',
- '',
- '',
- ''
- ]
+ expected_pos_args = ["", "", "", "", "", "", "", ""]
pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db)
- self.assertListEqual(pos_args, expected_pos_args,
- 'Positional args not parsed / serialized correctly.')
+ self.assertListEqual(
+ pos_args,
+ expected_pos_args,
+ "Positional args not parsed / serialized correctly.",
+ )
# Test unicode values
params = {
- 'actionstr': 'bar č š hello đ č p ž Ž a 💩😁',
- 'actionint': 20,
- 'runnerint': 555
+ "actionstr": "bar č š hello đ č p ž Ž a 💩😁",
+ "actionint": 20,
+ "runnerint": 555,
}
expected_pos_args = [
- '20',
- '',
- u'bar č š hello đ č p ž Ž a 💩😁',
- '',
- '',
- '',
- '',
- ''
+ "20",
+ "",
+ "bar č š hello đ č p ž Ž a 💩😁",
+ "",
+ "",
+ "",
+ "",
+ "",
]
- pos_args, named_args = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db)
- self.assertListEqual(pos_args, expected_pos_args, 'Positional args not parsed correctly.')
+ pos_args, named_args = action_db_utils.get_args(
+ params, ActionDBUtilsTestCase.action_db
+ )
+ self.assertListEqual(
+ pos_args, expected_pos_args, "Positional args not parsed correctly."
+ )
# Test arrays and lists with values of different types
params = {
- 'actionarray': [None, False, 1, 4.2e1, '1e3', 'foo'],
- 'actionlist': [None, False, 1, 73e-2, '1e2', 'bar']
+ "actionarray": [None, False, 1, 4.2e1, "1e3", "foo"],
+ "actionlist": [None, False, 1, 73e-2, "1e2", "bar"],
}
expected_pos_args = [
- '',
- '',
- '',
- '',
- 'None,False,1,42.0,1e3,foo',
- 'None,False,1,0.73,1e2,bar',
- '',
- ''
+ "",
+ "",
+ "",
+ "",
+ "None,False,1,42.0,1e3,foo",
+ "None,False,1,0.73,1e2,bar",
+ "",
+ "",
]
pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db)
- self.assertListEqual(pos_args, expected_pos_args,
- 'Positional args not parsed / serialized correctly.')
+ self.assertListEqual(
+ pos_args,
+ expected_pos_args,
+ "Positional args not parsed / serialized correctly.",
+ )
- self.assertNotIn('actionint', named_args)
- self.assertNotIn('actionstr', named_args)
- self.assertEqual(named_args.get('runnerint'), 555)
+ self.assertNotIn("actionint", named_args)
+ self.assertNotIn("actionstr", named_args)
+ self.assertEqual(named_args.get("runnerint"), 555)
@classmethod
def _setup_test_models(cls):
@@ -456,63 +516,65 @@ def _setup_test_models(cls):
@classmethod
def setup_runner(cls):
test_runner = {
- 'name': 'test-runner',
- 'description': 'A test runner.',
- 'enabled': True,
- 'runner_parameters': {
- 'runnerstr': {
- 'description': 'Foo str param.',
- 'type': 'string',
- 'default': 'defaultfoo'
+ "name": "test-runner",
+ "description": "A test runner.",
+ "enabled": True,
+ "runner_parameters": {
+ "runnerstr": {
+ "description": "Foo str param.",
+ "type": "string",
+ "default": "defaultfoo",
},
- 'runnerint': {
- 'description': 'Foo int param.',
- 'type': 'number'
+ "runnerint": {"description": "Foo int param.", "type": "number"},
+ "runnerdummy": {
+ "description": "Dummy param.",
+ "type": "string",
+ "default": "runnerdummy",
},
- 'runnerdummy': {
- 'description': 'Dummy param.',
- 'type': 'string',
- 'default': 'runnerdummy'
- }
},
- 'runner_module': 'tests.test_runner'
+ "runner_module": "tests.test_runner",
}
runnertype_api = RunnerTypeAPI(**test_runner)
ActionDBUtilsTestCase.runnertype_db = RunnerType.add_or_update(
- RunnerTypeAPI.to_model(runnertype_api))
+ RunnerTypeAPI.to_model(runnertype_api)
+ )
@classmethod
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def setup_action_models(cls):
- pack = 'wolfpack'
- name = 'action-1'
+ pack = "wolfpack"
+ name = "action-1"
parameters = {
- 'actionint': {'type': 'number', 'default': 10, 'position': 0},
- 'actionfloat': {'type': 'float', 'required': False, 'position': 1},
- 'actionstr': {'type': 'string', 'required': True, 'position': 2},
- 'actionbool': {'type': 'boolean', 'required': False, 'position': 3},
- 'actionarray': {'type': 'array', 'required': False, 'position': 4},
- 'actionlist': {'type': 'list', 'required': False, 'position': 5},
- 'actionobject': {'type': 'object', 'required': False, 'position': 6},
- 'actionnull': {'type': 'null', 'required': False, 'position': 7},
-
- 'runnerdummy': {'type': 'string', 'default': 'actiondummy'}
+ "actionint": {"type": "number", "default": 10, "position": 0},
+ "actionfloat": {"type": "float", "required": False, "position": 1},
+ "actionstr": {"type": "string", "required": True, "position": 2},
+ "actionbool": {"type": "boolean", "required": False, "position": 3},
+ "actionarray": {"type": "array", "required": False, "position": 4},
+ "actionlist": {"type": "list", "required": False, "position": 5},
+ "actionobject": {"type": "object", "required": False, "position": 6},
+ "actionnull": {"type": "null", "required": False, "position": 7},
+ "runnerdummy": {"type": "string", "default": "actiondummy"},
}
- action_db = ActionDB(pack=pack, name=name, description='awesomeness',
- enabled=True,
- ref=ResourceReference(name=name, pack=pack).ref,
- entry_point='', runner_type={'name': 'test-runner'},
- parameters=parameters)
+ action_db = ActionDB(
+ pack=pack,
+ name=name,
+ description="awesomeness",
+ enabled=True,
+ ref=ResourceReference(name=name, pack=pack).ref,
+ entry_point="",
+ runner_type={"name": "test-runner"},
+ parameters=parameters,
+ )
ActionDBUtilsTestCase.action_db = Action.add_or_update(action_db)
liveaction_db = LiveActionDB()
- liveaction_db.status = 'initializing'
+ liveaction_db.status = "initializing"
liveaction_db.start_timestamp = get_datetime_utc_now()
liveaction_db.action = ActionDBUtilsTestCase.action_db.ref
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
}
liveaction_db.parameters = params
ActionDBUtilsTestCase.liveaction_db = LiveAction.add_or_update(liveaction_db)
diff --git a/st2common/tests/unit/test_action_param_utils.py b/st2common/tests/unit/test_action_param_utils.py
index 08a6654f21..5eecf018dc 100644
--- a/st2common/tests/unit/test_action_param_utils.py
+++ b/st2common/tests/unit/test_action_param_utils.py
@@ -28,23 +28,16 @@
TEST_FIXTURES = {
- 'actions': [
- 'action1.yaml',
- 'action3.yaml'
- ],
- 'runners': [
- 'testrunner1.yaml',
- 'testrunner3.yaml'
- ]
+ "actions": ["action1.yaml", "action3.yaml"],
+ "runners": ["testrunner1.yaml", "testrunner3.yaml"],
}
-PACK = 'generic'
+PACK = "generic"
LOADER = FixturesLoader()
FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
class ActionParamsUtilsTest(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(ActionParamsUtilsTest, cls).setUpClass()
@@ -54,86 +47,105 @@ def setUpClass(cls):
cls.runnertype_dbs = {}
cls.action_dbs = {}
- for _, fixture in six.iteritems(FIXTURES['runners']):
+ for _, fixture in six.iteritems(FIXTURES["runners"]):
instance = RunnerTypeAPI(**fixture)
runnertype_db = RunnerType.add_or_update(RunnerTypeAPI.to_model(instance))
cls.runnertype_dbs[runnertype_db.name] = runnertype_db
- for _, fixture in six.iteritems(FIXTURES['actions']):
+ for _, fixture in six.iteritems(FIXTURES["actions"]):
instance = ActionAPI(**fixture)
action_db = Action.add_or_update(ActionAPI.to_model(instance))
cls.action_dbs[action_db.name] = action_db
def test_merge_action_runner_params_meta(self):
required, optional, immutable = action_param_utils.get_params_view(
- action_db=self.action_dbs['action-1'],
- runner_db=self.runnertype_dbs['test-runner-1'])
+ action_db=self.action_dbs["action-1"],
+ runner_db=self.runnertype_dbs["test-runner-1"],
+ )
merged = {}
merged.update(required)
merged.update(optional)
merged.update(immutable)
consolidated = action_param_utils.get_params_view(
- action_db=self.action_dbs['action-1'],
- runner_db=self.runnertype_dbs['test-runner-1'],
- merged_only=True)
+ action_db=self.action_dbs["action-1"],
+ runner_db=self.runnertype_dbs["test-runner-1"],
+ merged_only=True,
+ )
# Validate that merged_only view works.
self.assertEqual(merged, consolidated)
# Validate required params.
- self.assertEqual(len(required), 1, 'Required should contain only one param.')
- self.assertIn('actionstr', required, 'actionstr param is a required param.')
- self.assertNotIn('actionstr', optional, 'actionstr should not be in optional parameters')
- self.assertNotIn('actionstr', immutable, 'actionstr should not be in immutable parameters')
- self.assertIn('actionstr', merged, 'actionstr should be in action parameters')
+ self.assertEqual(len(required), 1, "Required should contain only one param.")
+ self.assertIn("actionstr", required, "actionstr param is a required param.")
+ self.assertNotIn(
+ "actionstr", optional, "actionstr should not be in optional parameters"
+ )
+ self.assertNotIn(
+ "actionstr", immutable, "actionstr should not be in immutable parameters"
+ )
+ self.assertIn("actionstr", merged, "actionstr should be in action parameters")
# Validate immutable params.
- self.assertIn('runnerimmutable', immutable, 'runnerimmutable should be in immutable.')
- self.assertIn('actionimmutable', immutable, 'actionimmutable should be in immutable.')
+ self.assertIn(
+ "runnerimmutable", immutable, "runnerimmutable should be in immutable."
+ )
+ self.assertIn(
+ "actionimmutable", immutable, "actionimmutable should be in immutable."
+ )
# Validate optional params.
for opt in optional:
- self.assertIn(opt, merged, 'Optional %s should be in action parameters' % opt)
- self.assertNotIn(opt, required, 'Optional %s should not be in required params' % opt)
- self.assertNotIn(opt, immutable, 'Optional %s should not be in immutable params' % opt)
+ self.assertIn(
+ opt, merged, "Optional %s should be in action parameters" % opt
+ )
+ self.assertNotIn(
+ opt, required, "Optional %s should not be in required params" % opt
+ )
+ self.assertNotIn(
+ opt, immutable, "Optional %s should not be in immutable params" % opt
+ )
def test_merge_param_meta_values(self):
runner_meta = copy.deepcopy(
- self.runnertype_dbs['test-runner-1'].runner_parameters['runnerdummy'])
- action_meta = copy.deepcopy(self.action_dbs['action-1'].parameters['runnerdummy'])
- merged_meta = action_param_utils._merge_param_meta_values(action_meta=action_meta,
- runner_meta=runner_meta)
+ self.runnertype_dbs["test-runner-1"].runner_parameters["runnerdummy"]
+ )
+ action_meta = copy.deepcopy(
+ self.action_dbs["action-1"].parameters["runnerdummy"]
+ )
+ merged_meta = action_param_utils._merge_param_meta_values(
+ action_meta=action_meta, runner_meta=runner_meta
+ )
# Description is in runner meta but not in action meta.
- self.assertEqual(merged_meta['description'], runner_meta['description'])
+ self.assertEqual(merged_meta["description"], runner_meta["description"])
# Default value is overridden in action.
- self.assertEqual(merged_meta['default'], action_meta['default'])
+ self.assertEqual(merged_meta["default"], action_meta["default"])
# Immutability is set in action.
- self.assertEqual(merged_meta['immutable'], action_meta['immutable'])
+ self.assertEqual(merged_meta["immutable"], action_meta["immutable"])
def test_merge_param_meta_require_override(self):
- action_meta = {
- 'required': False
- }
- runner_meta = {
- 'required': True
- }
- merged_meta = action_param_utils._merge_param_meta_values(action_meta=action_meta,
- runner_meta=runner_meta)
+ action_meta = {"required": False}
+ runner_meta = {"required": True}
+ merged_meta = action_param_utils._merge_param_meta_values(
+ action_meta=action_meta, runner_meta=runner_meta
+ )
- self.assertEqual(merged_meta['required'], action_meta['required'])
+ self.assertEqual(merged_meta["required"], action_meta["required"])
def test_validate_action_inputs(self):
requires, unexpected = action_param_utils.validate_action_parameters(
- self.action_dbs['action-1'].ref, {'foo': 'bar'})
+ self.action_dbs["action-1"].ref, {"foo": "bar"}
+ )
- self.assertListEqual(requires, ['actionstr'])
- self.assertListEqual(unexpected, ['foo'])
+ self.assertListEqual(requires, ["actionstr"])
+ self.assertListEqual(unexpected, ["foo"])
def test_validate_overridden_action_inputs(self):
requires, unexpected = action_param_utils.validate_action_parameters(
- self.action_dbs['action-3'].ref, {'k1': 'foo'})
+ self.action_dbs["action-3"].ref, {"k1": "foo"}
+ )
self.assertListEqual(requires, [])
self.assertListEqual(unexpected, [])
diff --git a/st2common/tests/unit/test_action_system_models.py b/st2common/tests/unit/test_action_system_models.py
index 8098759b56..c8812acf38 100644
--- a/st2common/tests/unit/test_action_system_models.py
+++ b/st2common/tests/unit/test_action_system_models.py
@@ -19,24 +19,30 @@
from st2common.models.system.action import RemoteAction
from st2common.models.system.action import RemoteScriptAction
-__all__ = [
- 'RemoteActionTestCase',
- 'RemoteScriptActionTestCase'
-]
+__all__ = ["RemoteActionTestCase", "RemoteScriptActionTestCase"]
class RemoteActionTestCase(unittest2.TestCase):
def test_instantiation(self):
- action = RemoteAction(name='name', action_exec_id='aeid', command='ls -la',
- env_vars={'a': 1}, on_behalf_user='onbehalf', user='user',
- hosts=['127.0.0.1'], parallel=False, sudo=True, timeout=10)
- self.assertEqual(action.name, 'name')
- self.assertEqual(action.action_exec_id, 'aeid')
- self.assertEqual(action.command, 'ls -la')
- self.assertEqual(action.env_vars, {'a': 1})
- self.assertEqual(action.on_behalf_user, 'onbehalf')
- self.assertEqual(action.user, 'user')
- self.assertEqual(action.hosts, ['127.0.0.1'])
+ action = RemoteAction(
+ name="name",
+ action_exec_id="aeid",
+ command="ls -la",
+ env_vars={"a": 1},
+ on_behalf_user="onbehalf",
+ user="user",
+ hosts=["127.0.0.1"],
+ parallel=False,
+ sudo=True,
+ timeout=10,
+ )
+ self.assertEqual(action.name, "name")
+ self.assertEqual(action.action_exec_id, "aeid")
+ self.assertEqual(action.command, "ls -la")
+ self.assertEqual(action.env_vars, {"a": 1})
+ self.assertEqual(action.on_behalf_user, "onbehalf")
+ self.assertEqual(action.user, "user")
+ self.assertEqual(action.hosts, ["127.0.0.1"])
self.assertEqual(action.parallel, False)
self.assertEqual(action.sudo, True)
self.assertEqual(action.timeout, 10)
@@ -44,26 +50,35 @@ def test_instantiation(self):
class RemoteScriptActionTestCase(unittest2.TestCase):
def test_instantiation(self):
- action = RemoteScriptAction(name='name', action_exec_id='aeid',
- script_local_path_abs='/tmp/sc/ma_script.sh',
- script_local_libs_path_abs='/tmp/sc/libs', named_args=None,
- positional_args=None, env_vars={'a': 1},
- on_behalf_user='onbehalf', user='user',
- remote_dir='/home/mauser', hosts=['127.0.0.1'],
- parallel=False, sudo=True, timeout=10)
- self.assertEqual(action.name, 'name')
- self.assertEqual(action.action_exec_id, 'aeid')
- self.assertEqual(action.script_local_libs_path_abs, '/tmp/sc/libs')
- self.assertEqual(action.env_vars, {'a': 1})
- self.assertEqual(action.on_behalf_user, 'onbehalf')
- self.assertEqual(action.user, 'user')
- self.assertEqual(action.remote_dir, '/home/mauser')
- self.assertEqual(action.hosts, ['127.0.0.1'])
+ action = RemoteScriptAction(
+ name="name",
+ action_exec_id="aeid",
+ script_local_path_abs="/tmp/sc/ma_script.sh",
+ script_local_libs_path_abs="/tmp/sc/libs",
+ named_args=None,
+ positional_args=None,
+ env_vars={"a": 1},
+ on_behalf_user="onbehalf",
+ user="user",
+ remote_dir="/home/mauser",
+ hosts=["127.0.0.1"],
+ parallel=False,
+ sudo=True,
+ timeout=10,
+ )
+ self.assertEqual(action.name, "name")
+ self.assertEqual(action.action_exec_id, "aeid")
+ self.assertEqual(action.script_local_libs_path_abs, "/tmp/sc/libs")
+ self.assertEqual(action.env_vars, {"a": 1})
+ self.assertEqual(action.on_behalf_user, "onbehalf")
+ self.assertEqual(action.user, "user")
+ self.assertEqual(action.remote_dir, "/home/mauser")
+ self.assertEqual(action.hosts, ["127.0.0.1"])
self.assertEqual(action.parallel, False)
self.assertEqual(action.sudo, True)
self.assertEqual(action.timeout, 10)
- self.assertEqual(action.script_local_dir, '/tmp/sc')
- self.assertEqual(action.script_name, 'ma_script.sh')
- self.assertEqual(action.remote_script, '/home/mauser/ma_script.sh')
- self.assertEqual(action.command, '/home/mauser/ma_script.sh')
+ self.assertEqual(action.script_local_dir, "/tmp/sc")
+ self.assertEqual(action.script_name, "ma_script.sh")
+ self.assertEqual(action.remote_script, "/home/mauser/ma_script.sh")
+ self.assertEqual(action.command, "/home/mauser/ma_script.sh")
diff --git a/st2common/tests/unit/test_actionchain_schema.py b/st2common/tests/unit/test_actionchain_schema.py
index 5c968c9a11..e5bba6c0e2 100644
--- a/st2common/tests/unit/test_actionchain_schema.py
+++ b/st2common/tests/unit/test_actionchain_schema.py
@@ -20,42 +20,48 @@
from st2common.models.system import actionchain
from st2tests.fixturesloader import FixturesLoader
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
TEST_FIXTURES = {
- 'actionchains': ['chain1.yaml', 'malformedchain.yaml', 'no_default_chain.yaml',
- 'chain_with_vars.yaml', 'chain_with_publish.yaml']
+ "actionchains": [
+ "chain1.yaml",
+ "malformedchain.yaml",
+ "no_default_chain.yaml",
+ "chain_with_vars.yaml",
+ "chain_with_publish.yaml",
+ ]
}
-FIXTURES = FixturesLoader().load_fixtures(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_FIXTURES)
-CHAIN_1 = FIXTURES['actionchains']['chain1.yaml']
-MALFORMED_CHAIN = FIXTURES['actionchains']['malformedchain.yaml']
-NO_DEFAULT_CHAIN = FIXTURES['actionchains']['no_default_chain.yaml']
-CHAIN_WITH_VARS = FIXTURES['actionchains']['chain_with_vars.yaml']
-CHAIN_WITH_PUBLISH = FIXTURES['actionchains']['chain_with_publish.yaml']
+FIXTURES = FixturesLoader().load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+)
+CHAIN_1 = FIXTURES["actionchains"]["chain1.yaml"]
+MALFORMED_CHAIN = FIXTURES["actionchains"]["malformedchain.yaml"]
+NO_DEFAULT_CHAIN = FIXTURES["actionchains"]["no_default_chain.yaml"]
+CHAIN_WITH_VARS = FIXTURES["actionchains"]["chain_with_vars.yaml"]
+CHAIN_WITH_PUBLISH = FIXTURES["actionchains"]["chain_with_publish.yaml"]
class ActionChainSchemaTest(unittest2.TestCase):
-
def test_actionchain_schema_valid(self):
chain = actionchain.ActionChain(**CHAIN_1)
- self.assertEqual(len(chain.chain), len(CHAIN_1['chain']))
- self.assertEqual(chain.default, CHAIN_1['default'])
+ self.assertEqual(len(chain.chain), len(CHAIN_1["chain"]))
+ self.assertEqual(chain.default, CHAIN_1["default"])
def test_actionchain_no_default(self):
chain = actionchain.ActionChain(**NO_DEFAULT_CHAIN)
- self.assertEqual(len(chain.chain), len(NO_DEFAULT_CHAIN['chain']))
+ self.assertEqual(len(chain.chain), len(NO_DEFAULT_CHAIN["chain"]))
self.assertEqual(chain.default, None)
def test_actionchain_with_vars(self):
chain = actionchain.ActionChain(**CHAIN_WITH_VARS)
- self.assertEqual(len(chain.chain), len(CHAIN_WITH_VARS['chain']))
- self.assertEqual(len(chain.vars), len(CHAIN_WITH_VARS['vars']))
+ self.assertEqual(len(chain.chain), len(CHAIN_WITH_VARS["chain"]))
+ self.assertEqual(len(chain.vars), len(CHAIN_WITH_VARS["vars"]))
def test_actionchain_with_publish(self):
chain = actionchain.ActionChain(**CHAIN_WITH_PUBLISH)
- self.assertEqual(len(chain.chain), len(CHAIN_WITH_PUBLISH['chain']))
- self.assertEqual(len(chain.chain[0].publish),
- len(CHAIN_WITH_PUBLISH['chain'][0]['publish']))
+ self.assertEqual(len(chain.chain), len(CHAIN_WITH_PUBLISH["chain"]))
+ self.assertEqual(
+ len(chain.chain[0].publish), len(CHAIN_WITH_PUBLISH["chain"][0]["publish"])
+ )
def test_actionchain_schema_invalid(self):
with self.assertRaises(ValidationError):
diff --git a/st2common/tests/unit/test_aliasesregistrar.py b/st2common/tests/unit/test_aliasesregistrar.py
index b827830594..4f17246dcf 100644
--- a/st2common/tests/unit/test_aliasesregistrar.py
+++ b/st2common/tests/unit/test_aliasesregistrar.py
@@ -22,22 +22,20 @@
from st2tests import DbTestCase
from st2tests import fixturesloader
-__all__ = [
- 'TestAliasRegistrar'
-]
+__all__ = ["TestAliasRegistrar"]
-ALIASES_FIXTURE_PACK_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(),
- 'dummy_pack_1')
-ALIASES_FIXTURE_PATH = os.path.join(ALIASES_FIXTURE_PACK_PATH, 'aliases')
+ALIASES_FIXTURE_PACK_PATH = os.path.join(
+ fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1"
+)
+ALIASES_FIXTURE_PATH = os.path.join(ALIASES_FIXTURE_PACK_PATH, "aliases")
class TestAliasRegistrar(DbTestCase):
-
def test_alias_registration(self):
count = aliasesregistrar.register_aliases(pack_dir=ALIASES_FIXTURE_PACK_PATH)
# expect all files to be aliases
self.assertEqual(count, len(os.listdir(ALIASES_FIXTURE_PATH)))
action_alias_dbs = ActionAlias.get_all()
- self.assertEqual(action_alias_dbs[0].metadata_file, 'aliases/alias1.yaml')
+ self.assertEqual(action_alias_dbs[0].metadata_file, "aliases/alias1.yaml")
diff --git a/st2common/tests/unit/test_api_model_validation.py b/st2common/tests/unit/test_api_model_validation.py
index d5f250482b..20eb98ce6c 100644
--- a/st2common/tests/unit/test_api_model_validation.py
+++ b/st2common/tests/unit/test_api_model_validation.py
@@ -18,196 +18,197 @@
from st2common.models.api.base import BaseAPI
-__all__ = [
- 'APIModelValidationTestCase'
-]
+__all__ = ["APIModelValidationTestCase"]
class MockAPIModel1(BaseAPI):
model = None
schema = {
- 'title': 'MockAPIModel',
- 'description': 'Test',
- 'type': 'object',
- 'properties': {
- 'id': {
- 'description': 'The unique identifier for the action runner.',
- 'type': ['string', 'null'],
- 'default': None
+ "title": "MockAPIModel",
+ "description": "Test",
+ "type": "object",
+ "properties": {
+ "id": {
+ "description": "The unique identifier for the action runner.",
+ "type": ["string", "null"],
+ "default": None,
},
- 'name': {
- 'description': 'The name of the action runner.',
- 'type': 'string',
- 'required': True
+ "name": {
+ "description": "The name of the action runner.",
+ "type": "string",
+ "required": True,
},
- 'description': {
- 'description': 'The description of the action runner.',
- 'type': 'string'
+ "description": {
+ "description": "The description of the action runner.",
+ "type": "string",
},
- 'enabled': {
- 'type': 'boolean',
- 'default': True
- },
- 'parameters': {
- 'type': 'object'
- },
- 'permission_grants': {
- 'type': 'array',
- 'items': {
- 'type': 'object',
- 'properties': {
- 'resource_uid': {
- 'type': 'string',
- 'description': 'UID of a resource to which this grant applies to.',
- 'required': False,
- 'default': 'unknown'
+ "enabled": {"type": "boolean", "default": True},
+ "parameters": {"type": "object"},
+ "permission_grants": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "resource_uid": {
+ "type": "string",
+ "description": "UID of a resource to which this grant applies to.",
+ "required": False,
+ "default": "unknown",
},
- 'enabled': {
- 'type': 'boolean',
- 'default': True
+ "enabled": {"type": "boolean", "default": True},
+ "description": {
+ "type": "string",
+ "description": "Description",
+ "required": False,
},
- 'description': {
- 'type': 'string',
- 'description': 'Description',
- 'required': False
- }
- }
+ },
},
- 'default': []
- }
+ "default": [],
+ },
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
class MockAPIModel2(BaseAPI):
model = None
schema = {
- 'title': 'MockAPIModel2',
- 'description': 'Test',
- 'type': 'object',
- 'properties': {
- 'id': {
- 'description': 'The unique identifier for the action runner.',
- 'type': 'string',
- 'default': None
+ "title": "MockAPIModel2",
+ "description": "Test",
+ "type": "object",
+ "properties": {
+ "id": {
+ "description": "The unique identifier for the action runner.",
+ "type": "string",
+ "default": None,
},
- 'permission_grants': {
- 'type': 'array',
- 'items': {
- 'type': 'object',
- 'properties': {
- 'resource_uid': {
- 'type': 'string',
- 'description': 'UID of a resource to which this grant applies to.',
- 'required': False,
- 'default': None
+ "permission_grants": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "resource_uid": {
+ "type": "string",
+ "description": "UID of a resource to which this grant applies to.",
+ "required": False,
+ "default": None,
},
- 'description': {
- 'type': 'string',
- 'required': True
- }
- }
+ "description": {"type": "string", "required": True},
+ },
},
- 'default': []
+ "default": [],
},
- 'parameters': {
- 'type': 'object',
- 'properties': {
- 'id': {
- 'type': 'string',
- 'default': None
- },
- 'name': {
- 'type': 'string',
- 'required': True
- }
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "id": {"type": "string", "default": None},
+ "name": {"type": "string", "required": True},
},
- 'additionalProperties': False,
- }
+ "additionalProperties": False,
+ },
},
- 'additionalProperties': False
+ "additionalProperties": False,
}
class APIModelValidationTestCase(unittest2.TestCase):
def test_validate_default_values_are_set(self):
# no "permission_grants" attribute
- mock_model_api = MockAPIModel1(name='name')
- self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset')
- self.assertEqual(mock_model_api.name, 'name')
- self.assertEqual(getattr(mock_model_api, 'enabled', None), None)
- self.assertEqual(getattr(mock_model_api, 'permission_grants', None), None)
+ mock_model_api = MockAPIModel1(name="name")
+ self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset")
+ self.assertEqual(mock_model_api.name, "name")
+ self.assertEqual(getattr(mock_model_api, "enabled", None), None)
+ self.assertEqual(getattr(mock_model_api, "permission_grants", None), None)
mock_model_api_validated = mock_model_api.validate()
# Validate it doesn't modify object in place
- self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset')
- self.assertEqual(mock_model_api.name, 'name')
- self.assertEqual(getattr(mock_model_api, 'enabled', None), None)
+ self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset")
+ self.assertEqual(mock_model_api.name, "name")
+ self.assertEqual(getattr(mock_model_api, "enabled", None), None)
# Verify cleaned object
self.assertEqual(mock_model_api_validated.id, None)
- self.assertEqual(mock_model_api_validated.name, 'name')
+ self.assertEqual(mock_model_api_validated.name, "name")
self.assertEqual(mock_model_api_validated.enabled, True)
self.assertEqual(mock_model_api_validated.permission_grants, [])
# "permission_grants" attribute present, but child missing
- mock_model_api = MockAPIModel1(name='name', enabled=False,
- permission_grants=[{}, {'description': 'test'}])
- self.assertEqual(mock_model_api.name, 'name')
+ mock_model_api = MockAPIModel1(
+ name="name", enabled=False, permission_grants=[{}, {"description": "test"}]
+ )
+ self.assertEqual(mock_model_api.name, "name")
self.assertEqual(mock_model_api.enabled, False)
- self.assertEqual(mock_model_api.permission_grants, [{}, {'description': 'test'}])
+ self.assertEqual(
+ mock_model_api.permission_grants, [{}, {"description": "test"}]
+ )
mock_model_api_validated = mock_model_api.validate()
# Validate it doesn't modify object in place
- self.assertEqual(mock_model_api.name, 'name')
+ self.assertEqual(mock_model_api.name, "name")
self.assertEqual(mock_model_api.enabled, False)
- self.assertEqual(mock_model_api.permission_grants, [{}, {'description': 'test'}])
+ self.assertEqual(
+ mock_model_api.permission_grants, [{}, {"description": "test"}]
+ )
# Verify cleaned object
self.assertEqual(mock_model_api_validated.id, None)
- self.assertEqual(mock_model_api_validated.name, 'name')
+ self.assertEqual(mock_model_api_validated.name, "name")
self.assertEqual(mock_model_api_validated.enabled, False)
- self.assertEqual(mock_model_api_validated.permission_grants,
- [{'resource_uid': 'unknown', 'enabled': True},
- {'resource_uid': 'unknown', 'enabled': True, 'description': 'test'}])
+ self.assertEqual(
+ mock_model_api_validated.permission_grants,
+ [
+ {"resource_uid": "unknown", "enabled": True},
+ {"resource_uid": "unknown", "enabled": True, "description": "test"},
+ ],
+ )
def test_validate_nested_attribute_with_default_not_provided(self):
mock_model_api = MockAPIModel2()
- self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset')
- self.assertEqual(getattr(mock_model_api, 'permission_grants', 'notset'), 'notset')
- self.assertEqual(getattr(mock_model_api, 'parameters', 'notset'), 'notset')
+ self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset")
+ self.assertEqual(
+ getattr(mock_model_api, "permission_grants", "notset"), "notset"
+ )
+ self.assertEqual(getattr(mock_model_api, "parameters", "notset"), "notset")
mock_model_api_validated = mock_model_api.validate()
# Validate it doesn't modify object in place
- self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset')
- self.assertEqual(getattr(mock_model_api, 'permission_grants', 'notset'), 'notset')
- self.assertEqual(getattr(mock_model_api, 'parameters', 'notset'), 'notset')
+ self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset")
+ self.assertEqual(
+ getattr(mock_model_api, "permission_grants", "notset"), "notset"
+ )
+ self.assertEqual(getattr(mock_model_api, "parameters", "notset"), "notset")
# Verify cleaned object
self.assertEqual(mock_model_api_validated.id, None)
self.assertEqual(mock_model_api_validated.permission_grants, [])
- self.assertEqual(getattr(mock_model_api_validated, 'parameters', 'notset'), 'notset')
+ self.assertEqual(
+ getattr(mock_model_api_validated, "parameters", "notset"), "notset"
+ )
def test_validate_allow_default_none_for_any_type(self):
- mock_model_api = MockAPIModel2(permission_grants=[{'description': 'test'}],
- parameters={'name': 'test'})
- self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset')
- self.assertEqual(mock_model_api.permission_grants, [{'description': 'test'}])
- self.assertEqual(mock_model_api.parameters, {'name': 'test'})
+ mock_model_api = MockAPIModel2(
+ permission_grants=[{"description": "test"}], parameters={"name": "test"}
+ )
+ self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset")
+ self.assertEqual(mock_model_api.permission_grants, [{"description": "test"}])
+ self.assertEqual(mock_model_api.parameters, {"name": "test"})
mock_model_api_validated = mock_model_api.validate()
# Validate it doesn't modify object in place
- self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset')
- self.assertEqual(mock_model_api.permission_grants, [{'description': 'test'}])
- self.assertEqual(mock_model_api.parameters, {'name': 'test'})
+ self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset")
+ self.assertEqual(mock_model_api.permission_grants, [{"description": "test"}])
+ self.assertEqual(mock_model_api.parameters, {"name": "test"})
# Verify cleaned object
self.assertEqual(mock_model_api_validated.id, None)
- self.assertEqual(mock_model_api_validated.permission_grants,
- [{'description': 'test', 'resource_uid': None}])
- self.assertEqual(mock_model_api_validated.parameters, {'id': None, 'name': 'test'})
+ self.assertEqual(
+ mock_model_api_validated.permission_grants,
+ [{"description": "test", "resource_uid": None}],
+ )
+ self.assertEqual(
+ mock_model_api_validated.parameters, {"id": None, "name": "test"}
+ )
diff --git a/st2common/tests/unit/test_casts.py b/st2common/tests/unit/test_casts.py
index 55e95ca781..62bf0ac4e8 100644
--- a/st2common/tests/unit/test_casts.py
+++ b/st2common/tests/unit/test_casts.py
@@ -23,19 +23,19 @@
class CastsTestCase(unittest2.TestCase):
def test_cast_string(self):
- cast_func = get_cast('string')
+ cast_func = get_cast("string")
- value = 'test1'
+ value = "test1"
result = cast_func(value)
- self.assertEqual(result, 'test1')
+ self.assertEqual(result, "test1")
- value = u'test2'
+ value = "test2"
result = cast_func(value)
- self.assertEqual(result, u'test2')
+ self.assertEqual(result, "test2")
- value = ''
+ value = ""
result = cast_func(value)
- self.assertEqual(result, '')
+ self.assertEqual(result, "")
# None should be preserved
value = None
@@ -48,7 +48,7 @@ def test_cast_string(self):
self.assertRaisesRegexp(ValueError, expected_msg, cast_func, value)
def test_cast_array(self):
- cast_func = get_cast('array')
+ cast_func = get_cast("array")
# Python literal
value = str([1, 2, 3])
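The get_cast tests assume a registry of cast functions keyed by type name, where the "string" cast passes None through and the "array" cast accepts a Python-literal string such as str([1, 2, 3]). A rough sketch under those assumptions; the registry layout and error wording are illustrative, not the st2common implementation:

```python
import ast
import json


def _cast_string(value):
    if value is None:
        return None  # None is passed through untouched
    if not isinstance(value, str):
        raise ValueError(
            "Value %r must either be a string or None, got %s"
            % (value, type(value).__name__)
        )
    return value


def _cast_array(value):
    # Try a Python literal ("[1, 2, 3]") first, then fall back to JSON.
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        return json.loads(value)


CASTS = {"string": _cast_string, "array": _cast_array}


def get_cast(cast_type):
    return CASTS[cast_type]


assert get_cast("string")("test1") == "test1"
assert get_cast("string")(None) is None
assert get_cast("array")(str([1, 2, 3])) == [1, 2, 3]
```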
diff --git a/st2common/tests/unit/test_config_loader.py b/st2common/tests/unit/test_config_loader.py
index f59e3efe4f..e1849d7868 100644
--- a/st2common/tests/unit/test_config_loader.py
+++ b/st2common/tests/unit/test_config_loader.py
@@ -24,9 +24,7 @@
from st2tests.base import CleanDbTestCase
-__all__ = [
- 'ContentPackConfigLoaderTestCase'
-]
+__all__ = ["ContentPackConfigLoaderTestCase"]
class ContentPackConfigLoaderTestCase(CleanDbTestCase):
@@ -37,7 +35,7 @@ def test_ensure_local_pack_config_feature_removed(self):
# Test a scenario where all the values are loaded from pack local
# config and pack global config (pack name.yaml) doesn't exist.
# Test a scenario where no values are overridden in the datastore
- loader = ContentPackConfigLoader(pack_name='dummy_pack_4')
+ loader = ContentPackConfigLoader(pack_name="dummy_pack_4")
config = loader.get_config()
expected_config = {}
@@ -46,35 +44,39 @@ def test_ensure_local_pack_config_feature_removed(self):
def test_get_config_some_values_overriden_in_datastore(self):
# Test a scenario where some values are overridden in the datastore via the
# pack global config
- kvp_db = set_datastore_value_for_config_key(pack_name='dummy_pack_5',
- key_name='api_secret',
- value='some_api_secret',
- secret=True,
- user='joe')
+ kvp_db = set_datastore_value_for_config_key(
+ pack_name="dummy_pack_5",
+ key_name="api_secret",
+ value="some_api_secret",
+ secret=True,
+ user="joe",
+ )
# This is a secret so a value should be encrypted
- self.assertTrue(kvp_db.value != 'some_api_secret')
- self.assertTrue(len(kvp_db.value) > len('some_api_secret') * 2)
+ self.assertTrue(kvp_db.value != "some_api_secret")
+ self.assertTrue(len(kvp_db.value) > len("some_api_secret") * 2)
self.assertTrue(kvp_db.secret)
- kvp_db = set_datastore_value_for_config_key(pack_name='dummy_pack_5',
- key_name='private_key_path',
- value='some_private_key')
- self.assertEqual(kvp_db.value, 'some_private_key')
+ kvp_db = set_datastore_value_for_config_key(
+ pack_name="dummy_pack_5",
+ key_name="private_key_path",
+ value="some_private_key",
+ )
+ self.assertEqual(kvp_db.value, "some_private_key")
self.assertFalse(kvp_db.secret)
- loader = ContentPackConfigLoader(pack_name='dummy_pack_5', user='joe')
+ loader = ContentPackConfigLoader(pack_name="dummy_pack_5", user="joe")
config = loader.get_config()
# regions is provided in the pack global config
# api_secret is dynamically loaded from the datastore for a particular user
expected_config = {
- 'api_key': 'some_api_key',
- 'api_secret': 'some_api_secret',
- 'regions': ['us-west-1'],
- 'region': 'default-region-value',
- 'private_key_path': 'some_private_key',
- 'non_required_with_default_value': 'config value'
+ "api_key": "some_api_key",
+ "api_secret": "some_api_secret",
+ "regions": ["us-west-1"],
+ "region": "default-region-value",
+ "private_key_path": "some_private_key",
+ "non_required_with_default_value": "config value",
}
self.assertEqual(config, expected_config)
@@ -82,26 +84,26 @@ def test_get_config_some_values_overriden_in_datastore(self):
def test_get_config_default_value_from_config_schema_is_used(self):
# No value is provided for "region" in the config, default value from config schema
# should be used
- loader = ContentPackConfigLoader(pack_name='dummy_pack_5')
+ loader = ContentPackConfigLoader(pack_name="dummy_pack_5")
config = loader.get_config()
- self.assertEqual(config['region'], 'default-region-value')
+ self.assertEqual(config["region"], "default-region-value")
# Here a default value is specified in schema but an explicit value is provided in the
# config
- loader = ContentPackConfigLoader(pack_name='dummy_pack_1')
+ loader = ContentPackConfigLoader(pack_name="dummy_pack_1")
config = loader.get_config()
- self.assertEqual(config['region'], 'us-west-1')
+ self.assertEqual(config["region"], "us-west-1")
# Config item attribute has required: false
# Value is provided in the config - it should be used as provided
- pack_name = 'dummy_pack_5'
+ pack_name = "dummy_pack_5"
loader = ContentPackConfigLoader(pack_name=pack_name)
config = loader.get_config()
- self.assertEqual(config['non_required_with_default_value'], 'config value')
+ self.assertEqual(config["non_required_with_default_value"], "config value")
config_db = Config.get_by_pack(pack_name)
- del config_db['values']['non_required_with_default_value']
+ del config_db["values"]["non_required_with_default_value"]
Config.add_or_update(config_db)
# No value in the config - default value should be used
@@ -111,10 +113,12 @@ def test_get_config_default_value_from_config_schema_is_used(self):
# No config exists for that pack - default value should be used
loader = ContentPackConfigLoader(pack_name=pack_name)
config = loader.get_config()
- self.assertEqual(config['non_required_with_default_value'], 'some default value')
+ self.assertEqual(
+ config["non_required_with_default_value"], "some default value"
+ )
def test_default_values_from_schema_are_used_when_no_config_exists(self):
- pack_name = 'dummy_pack_5'
+ pack_name = "dummy_pack_5"
config_db = Config.get_by_pack(pack_name)
# Delete the existing config loaded in setUp
@@ -122,37 +126,37 @@ def test_default_values_from_schema_are_used_when_no_config_exists(self):
config_db.delete()
# Verify config has been deleted from the database
- self.assertRaises(StackStormDBObjectNotFoundError, Config.get_by_pack, pack_name)
+ self.assertRaises(
+ StackStormDBObjectNotFoundError, Config.get_by_pack, pack_name
+ )
loader = ContentPackConfigLoader(pack_name=pack_name)
config = loader.get_config()
- self.assertEqual(config['region'], 'default-region-value')
+ self.assertEqual(config["region"], "default-region-value")
def test_default_values_are_used_when_default_values_are_falsey(self):
- pack_name = 'dummy_pack_17'
+ pack_name = "dummy_pack_17"
loader = ContentPackConfigLoader(pack_name=pack_name)
config = loader.get_config()
# 1. Default values are used
- self.assertEqual(config['key_with_default_falsy_value_1'], False)
- self.assertEqual(config['key_with_default_falsy_value_2'], None)
- self.assertEqual(config['key_with_default_falsy_value_3'], {})
- self.assertEqual(config['key_with_default_falsy_value_4'], '')
- self.assertEqual(config['key_with_default_falsy_value_5'], 0)
- self.assertEqual(config['key_with_default_falsy_value_6']['key_1'], False)
- self.assertEqual(config['key_with_default_falsy_value_6']['key_2'], 0)
+ self.assertEqual(config["key_with_default_falsy_value_1"], False)
+ self.assertEqual(config["key_with_default_falsy_value_2"], None)
+ self.assertEqual(config["key_with_default_falsy_value_3"], {})
+ self.assertEqual(config["key_with_default_falsy_value_4"], "")
+ self.assertEqual(config["key_with_default_falsy_value_5"], 0)
+ self.assertEqual(config["key_with_default_falsy_value_6"]["key_1"], False)
+ self.assertEqual(config["key_with_default_falsy_value_6"]["key_2"], 0)
# 2. Default values are overwritten with config values which are also falsy
values = {
- 'key_with_default_falsy_value_1': 0,
- 'key_with_default_falsy_value_2': '',
- 'key_with_default_falsy_value_3': False,
- 'key_with_default_falsy_value_4': None,
- 'key_with_default_falsy_value_5': {},
- 'key_with_default_falsy_value_6': {
- 'key_2': False
- }
+ "key_with_default_falsy_value_1": 0,
+ "key_with_default_falsy_value_2": "",
+ "key_with_default_falsy_value_3": False,
+ "key_with_default_falsy_value_4": None,
+ "key_with_default_falsy_value_5": {},
+ "key_with_default_falsy_value_6": {"key_2": False},
}
config_db = ConfigDB(pack=pack_name, values=values)
config_db = Config.add_or_update(config_db)
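The assertions in the next hunk hinge on a presence-based merge: a key explicitly set in the config must win even when its value is 0, "", False, None, or {}, which a truthiness-based (`or`) merge would silently discard. A from-scratch sketch of that merge rule, not the loader's actual code:

```python
def merge_defaults(schema_defaults, config_values):
    """Presence-based merge: any key present in config_values overrides the
    schema default, even when the value is falsy. Nested dicts are merged
    recursively so sibling defaults survive."""
    result = dict(schema_defaults)
    for key, value in config_values.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = merge_defaults(result[key], value)
        else:
            result[key] = value
    return result


defaults = {"a": False, "b": None, "nested": {"key_1": False, "key_2": 0}}
values = {"a": 0, "nested": {"key_2": False}}
merged = merge_defaults(defaults, values)
assert merged == {"a": 0, "b": None, "nested": {"key_1": False, "key_2": False}}
```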
@@ -160,301 +164,296 @@ def test_default_values_are_used_when_default_values_are_falsey(self):
loader = ContentPackConfigLoader(pack_name=pack_name)
config = loader.get_config()
- self.assertEqual(config['key_with_default_falsy_value_1'], 0)
- self.assertEqual(config['key_with_default_falsy_value_2'], '')
- self.assertEqual(config['key_with_default_falsy_value_3'], False)
- self.assertEqual(config['key_with_default_falsy_value_4'], None)
- self.assertEqual(config['key_with_default_falsy_value_5'], {})
- self.assertEqual(config['key_with_default_falsy_value_6']['key_1'], False)
- self.assertEqual(config['key_with_default_falsy_value_6']['key_2'], False)
+ self.assertEqual(config["key_with_default_falsy_value_1"], 0)
+ self.assertEqual(config["key_with_default_falsy_value_2"], "")
+ self.assertEqual(config["key_with_default_falsy_value_3"], False)
+ self.assertEqual(config["key_with_default_falsy_value_4"], None)
+ self.assertEqual(config["key_with_default_falsy_value_5"], {})
+ self.assertEqual(config["key_with_default_falsy_value_6"]["key_1"], False)
+ self.assertEqual(config["key_with_default_falsy_value_6"]["key_2"], False)
def test_get_config_nested_schema_default_values_from_config_schema_are_used(self):
# Special case for more complex config schemas with attribute nesting.
# Validate that the default values are also used for one-level nested object properties.
- pack_name = 'dummy_pack_schema_with_nested_object_1'
+ pack_name = "dummy_pack_schema_with_nested_object_1"
# 1. None of the nested object values are provided
loader = ContentPackConfigLoader(pack_name=pack_name)
config = loader.get_config()
expected_config = {
- 'api_key': '',
- 'api_secret': '',
- 'regions': ['us-west-1', 'us-east-1'],
- 'auth_settings': {
- 'host': '127.0.0.3',
- 'port': 8080,
- 'device_uids': ['a', 'b', 'c']
- }
+ "api_key": "",
+ "api_secret": "",
+ "regions": ["us-west-1", "us-east-1"],
+ "auth_settings": {
+ "host": "127.0.0.3",
+ "port": 8080,
+ "device_uids": ["a", "b", "c"],
+ },
}
self.assertEqual(config, expected_config)
# 2. Some of the nested object values are provided (host, port)
- pack_name = 'dummy_pack_schema_with_nested_object_2'
+ pack_name = "dummy_pack_schema_with_nested_object_2"
loader = ContentPackConfigLoader(pack_name=pack_name)
config = loader.get_config()
expected_config = {
- 'api_key': '',
- 'api_secret': '',
- 'regions': ['us-west-1', 'us-east-1'],
- 'auth_settings': {
- 'host': '127.0.0.6',
- 'port': 9090,
- 'device_uids': ['a', 'b', 'c']
- }
+ "api_key": "",
+ "api_secret": "",
+ "regions": ["us-west-1", "us-east-1"],
+ "auth_settings": {
+ "host": "127.0.0.6",
+ "port": 9090,
+ "device_uids": ["a", "b", "c"],
+ },
}
self.assertEqual(config, expected_config)
# 3. Nested attribute (auth_settings.token) references a non-secret datastore value
- pack_name = 'dummy_pack_schema_with_nested_object_3'
-
- kvp_db = set_datastore_value_for_config_key(pack_name=pack_name,
- key_name='auth_settings_token',
- value='some_auth_settings_token')
- self.assertEqual(kvp_db.value, 'some_auth_settings_token')
+ pack_name = "dummy_pack_schema_with_nested_object_3"
+
+ kvp_db = set_datastore_value_for_config_key(
+ pack_name=pack_name,
+ key_name="auth_settings_token",
+ value="some_auth_settings_token",
+ )
+ self.assertEqual(kvp_db.value, "some_auth_settings_token")
self.assertFalse(kvp_db.secret)
loader = ContentPackConfigLoader(pack_name=pack_name)
config = loader.get_config()
expected_config = {
- 'api_key': '',
- 'api_secret': '',
- 'regions': ['us-west-1', 'us-east-1'],
- 'auth_settings': {
- 'host': '127.0.0.10',
- 'port': 8080,
- 'device_uids': ['a', 'b', 'c'],
- 'token': 'some_auth_settings_token'
- }
+ "api_key": "",
+ "api_secret": "",
+ "regions": ["us-west-1", "us-east-1"],
+ "auth_settings": {
+ "host": "127.0.0.10",
+ "port": 8080,
+ "device_uids": ["a", "b", "c"],
+ "token": "some_auth_settings_token",
+ },
}
self.assertEqual(config, expected_config)
# 4. Nested attribute (auth_settings.token) references a secret datastore value
- pack_name = 'dummy_pack_schema_with_nested_object_4'
-
- kvp_db = set_datastore_value_for_config_key(pack_name=pack_name,
- key_name='auth_settings_token',
- value='joe_token_secret',
- secret=True,
- user='joe')
- self.assertTrue(kvp_db.value != 'joe_token_secret')
- self.assertTrue(len(kvp_db.value) > len('joe_token_secret') * 2)
+ pack_name = "dummy_pack_schema_with_nested_object_4"
+
+ kvp_db = set_datastore_value_for_config_key(
+ pack_name=pack_name,
+ key_name="auth_settings_token",
+ value="joe_token_secret",
+ secret=True,
+ user="joe",
+ )
+ self.assertTrue(kvp_db.value != "joe_token_secret")
+ self.assertTrue(len(kvp_db.value) > len("joe_token_secret") * 2)
self.assertTrue(kvp_db.secret)
- kvp_db = set_datastore_value_for_config_key(pack_name=pack_name,
- key_name='auth_settings_token',
- value='alice_token_secret',
- secret=True,
- user='alice')
- self.assertTrue(kvp_db.value != 'alice_token_secret')
- self.assertTrue(len(kvp_db.value) > len('alice_token_secret') * 2)
+ kvp_db = set_datastore_value_for_config_key(
+ pack_name=pack_name,
+ key_name="auth_settings_token",
+ value="alice_token_secret",
+ secret=True,
+ user="alice",
+ )
+ self.assertTrue(kvp_db.value != "alice_token_secret")
+ self.assertTrue(len(kvp_db.value) > len("alice_token_secret") * 2)
self.assertTrue(kvp_db.secret)
- loader = ContentPackConfigLoader(pack_name=pack_name, user='joe')
+ loader = ContentPackConfigLoader(pack_name=pack_name, user="joe")
config = loader.get_config()
expected_config = {
- 'api_key': '',
- 'api_secret': '',
- 'regions': ['us-west-1', 'us-east-1'],
- 'auth_settings': {
- 'host': '127.0.0.11',
- 'port': 8080,
- 'device_uids': ['a', 'b', 'c'],
- 'token': 'joe_token_secret'
- }
+ "api_key": "",
+ "api_secret": "",
+ "regions": ["us-west-1", "us-east-1"],
+ "auth_settings": {
+ "host": "127.0.0.11",
+ "port": 8080,
+ "device_uids": ["a", "b", "c"],
+ "token": "joe_token_secret",
+ },
}
self.assertEqual(config, expected_config)
- loader = ContentPackConfigLoader(pack_name=pack_name, user='alice')
+ loader = ContentPackConfigLoader(pack_name=pack_name, user="alice")
config = loader.get_config()
expected_config = {
- 'api_key': '',
- 'api_secret': '',
- 'regions': ['us-west-1', 'us-east-1'],
- 'auth_settings': {
- 'host': '127.0.0.11',
- 'port': 8080,
- 'device_uids': ['a', 'b', 'c'],
- 'token': 'alice_token_secret'
- }
+ "api_key": "",
+ "api_secret": "",
+ "regions": ["us-west-1", "us-east-1"],
+ "auth_settings": {
+ "host": "127.0.0.11",
+ "port": 8080,
+ "device_uids": ["a", "b", "c"],
+ "token": "alice_token_secret",
+ },
}
self.assertEqual(config, expected_config)
- def test_get_config_dynamic_config_item_render_fails_user_friendly_exception_is_thrown(self):
- pack_name = 'dummy_pack_schema_with_nested_object_5'
+ def test_get_config_dynamic_config_item_render_fails_user_friendly_exception_is_thrown(
+ self,
+ ):
+ pack_name = "dummy_pack_schema_with_nested_object_5"
loader = ContentPackConfigLoader(pack_name=pack_name)
# Render fails on top-level item
- values = {
- 'level0_key': '{{st2kvXX.invalid}}'
- }
+ values = {"level0_key": "{{st2kvXX.invalid}}"}
config_db = ConfigDB(pack=pack_name, values=values)
config_db = Config.add_or_update(config_db)
- expected_msg = ('Failed to render dynamic configuration value for key "level0_key" with '
- 'value "{{st2kvXX.invalid}}" for pack ".*?" config: '
- ' '
- '\'st2kvXX\' is undefined')
+ expected_msg = (
+ 'Failed to render dynamic configuration value for key "level0_key" with '
+ 'value "{{st2kvXX.invalid}}" for pack ".*?" config: '
+ " "
+ "'st2kvXX' is undefined"
+ )
self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config)
config_db.delete()
# Render fails on first-level item
- values = {
- 'level0_object': {
- 'level1_key': '{{st2kvXX.invalid}}'
- }
- }
+ values = {"level0_object": {"level1_key": "{{st2kvXX.invalid}}"}}
config_db = ConfigDB(pack=pack_name, values=values)
Config.add_or_update(config_db)
- expected_msg = ('Failed to render dynamic configuration value for key '
- '"level0_object.level1_key" with value "{{st2kvXX.invalid}}"'
- ' for pack ".*?" config: '
- ' \'st2kvXX\' is undefined')
+ expected_msg = (
+ "Failed to render dynamic configuration value for key "
+ '"level0_object.level1_key" with value "{{st2kvXX.invalid}}"'
+ " for pack \".*?\" config: "
+ " 'st2kvXX' is undefined"
+ )
self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config)
config_db.delete()
# Render fails on second-level item
values = {
- 'level0_object': {
- 'level1_object': {
- 'level2_key': '{{st2kvXX.invalid}}'
- }
- }
+ "level0_object": {"level1_object": {"level2_key": "{{st2kvXX.invalid}}"}}
}
config_db = ConfigDB(pack=pack_name, values=values)
Config.add_or_update(config_db)
- expected_msg = ('Failed to render dynamic configuration value for key '
- '"level0_object.level1_object.level2_key" with value "{{st2kvXX.invalid}}"'
- ' for pack ".*?" config: '
- ' \'st2kvXX\' is undefined')
+ expected_msg = (
+ "Failed to render dynamic configuration value for key "
+ '"level0_object.level1_object.level2_key" with value "{{st2kvXX.invalid}}"'
+ " for pack \".*?\" config: "
+ " 'st2kvXX' is undefined"
+ )
self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config)
config_db.delete()
# Render fails on list item
- values = {
- 'level0_object': [
- 'abc',
- '{{st2kvXX.invalid}}'
- ]
- }
+ values = {"level0_object": ["abc", "{{st2kvXX.invalid}}"]}
config_db = ConfigDB(pack=pack_name, values=values)
Config.add_or_update(config_db)
- expected_msg = ('Failed to render dynamic configuration value for key '
- '"level0_object.1" with value "{{st2kvXX.invalid}}"'
- ' for pack ".*?" config: '
- ' \'st2kvXX\' is undefined')
+ expected_msg = (
+ "Failed to render dynamic configuration value for key "
+ '"level0_object.1" with value "{{st2kvXX.invalid}}"'
+ " for pack \".*?\" config: "
+ " 'st2kvXX' is undefined"
+ )
self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config)
config_db.delete()
# Render fails on nested object in list item
- values = {
- 'level0_object': [
- {'level2_key': '{{st2kvXX.invalid}}'}
- ]
- }
+ values = {"level0_object": [{"level2_key": "{{st2kvXX.invalid}}"}]}
config_db = ConfigDB(pack=pack_name, values=values)
Config.add_or_update(config_db)
- expected_msg = ('Failed to render dynamic configuration value for key '
- '"level0_object.0.level2_key" with value "{{st2kvXX.invalid}}"'
- ' for pack ".*?" config: '
- ' \'st2kvXX\' is undefined')
+ expected_msg = (
+ "Failed to render dynamic configuration value for key "
+ '"level0_object.0.level2_key" with value "{{st2kvXX.invalid}}"'
+ " for pack \".*?\" config: "
+ " 'st2kvXX' is undefined"
+ )
self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config)
config_db.delete()
# Render fails on invalid syntax
- values = {
- 'level0_key': '{{ this is some invalid Jinja }}'
- }
+ values = {"level0_key": "{{ this is some invalid Jinja }}"}
config_db = ConfigDB(pack=pack_name, values=values)
Config.add_or_update(config_db)
- expected_msg = ('Failed to render dynamic configuration value for key '
- '"level0_key" with value "{{ this is some invalid Jinja }}"'
- ' for pack ".*?" config: '
- ' expected token \'end of print statement\', got \'Jinja\'')
+ expected_msg = (
+ "Failed to render dynamic configuration value for key "
+ '"level0_key" with value "{{ this is some invalid Jinja }}"'
+ " for pack \".*?\" config: "
+ " expected token 'end of print statement', got 'Jinja'"
+ )
self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config)
config_db.delete()
def test_get_config_dynamic_config_item(self):
- pack_name = 'dummy_pack_schema_with_nested_object_6'
+ pack_name = "dummy_pack_schema_with_nested_object_6"
loader = ContentPackConfigLoader(pack_name=pack_name)
####################
# value in top level item
- KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1'))
- values = {
- 'level0_key': '{{st2kv.system.k1}}'
- }
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1"))
+ values = {"level0_key": "{{st2kv.system.k1}}"}
config_db = ConfigDB(pack=pack_name, values=values)
config_db = Config.add_or_update(config_db)
config_rendered = loader.get_config()
- self.assertEqual(config_rendered, {'level0_key': 'v1'})
+ self.assertEqual(config_rendered, {"level0_key": "v1"})
config_db.delete()
def test_get_config_dynamic_config_item_nested_dict(self):
- pack_name = 'dummy_pack_schema_with_nested_object_7'
+ pack_name = "dummy_pack_schema_with_nested_object_7"
loader = ContentPackConfigLoader(pack_name=pack_name)
- KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0'))
- KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1'))
- KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2'))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0"))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1"))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2"))
####################
# values nested dictionaries
values = {
- 'level0_key': '{{st2kv.system.k0}}',
- 'level0_object': {
- 'level1_key': '{{st2kv.system.k1}}',
- 'level1_object': {
- 'level2_key': '{{st2kv.system.k2}}'
- }
- }
+ "level0_key": "{{st2kv.system.k0}}",
+ "level0_object": {
+ "level1_key": "{{st2kv.system.k1}}",
+ "level1_object": {"level2_key": "{{st2kv.system.k2}}"},
+ },
}
config_db = ConfigDB(pack=pack_name, values=values)
config_db = Config.add_or_update(config_db)
config_rendered = loader.get_config()
- self.assertEqual(config_rendered,
- {
- 'level0_key': 'v0',
- 'level0_object': {
- 'level1_key': 'v1',
- 'level1_object': {
- 'level2_key': 'v2'
- }
- }
- })
+ self.assertEqual(
+ config_rendered,
+ {
+ "level0_key": "v0",
+ "level0_object": {
+ "level1_key": "v1",
+ "level1_object": {"level2_key": "v2"},
+ },
+ },
+ )
config_db.delete()
def test_get_config_dynamic_config_item_list(self):
- pack_name = 'dummy_pack_schema_with_nested_object_7'
+ pack_name = "dummy_pack_schema_with_nested_object_7"
loader = ContentPackConfigLoader(pack_name=pack_name)
- KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0'))
- KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1'))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0"))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1"))
####################
# values in list
values = {
- 'level0_key': [
- 'a',
- '{{st2kv.system.k0}}',
- 'b',
- '{{st2kv.system.k1}}',
+ "level0_key": [
+ "a",
+ "{{st2kv.system.k0}}",
+ "b",
+ "{{st2kv.system.k1}}",
]
}
config_db = ConfigDB(pack=pack_name, values=values)
@@ -462,44 +461,34 @@ def test_get_config_dynamic_config_item_list(self):
config_rendered = loader.get_config()
- self.assertEqual(config_rendered,
- {
- 'level0_key': [
- 'a',
- 'v0',
- 'b',
- 'v1'
- ]
- })
+ self.assertEqual(config_rendered, {"level0_key": ["a", "v0", "b", "v1"]})
config_db.delete()
def test_get_config_dynamic_config_item_nested_list(self):
- pack_name = 'dummy_pack_schema_with_nested_object_8'
+ pack_name = "dummy_pack_schema_with_nested_object_8"
loader = ContentPackConfigLoader(pack_name=pack_name)
- KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0'))
- KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1'))
- KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2'))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0"))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1"))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2"))
####################
# values in objects embedded in lists and nested lists
values = {
- 'level0_key': [
- {
- 'level1_key0': '{{st2kv.system.k0}}'
- },
- '{{st2kv.system.k1}}',
+ "level0_key": [
+ {"level1_key0": "{{st2kv.system.k0}}"},
+ "{{st2kv.system.k1}}",
[
- '{{st2kv.system.k0}}',
- '{{st2kv.system.k1}}',
- '{{st2kv.system.k2}}',
+ "{{st2kv.system.k0}}",
+ "{{st2kv.system.k1}}",
+ "{{st2kv.system.k2}}",
],
{
- 'level1_key2': [
- '{{st2kv.system.k2}}',
+ "level1_key2": [
+ "{{st2kv.system.k2}}",
]
- }
+ },
]
}
config_db = ConfigDB(pack=pack_name, values=values)
@@ -507,30 +496,30 @@ def test_get_config_dynamic_config_item_nested_list(self):
config_rendered = loader.get_config()
- self.assertEqual(config_rendered,
- {
- 'level0_key': [
- {
- 'level1_key0': 'v0'
- },
- 'v1',
- [
- 'v0',
- 'v1',
- 'v2',
- ],
- {
- 'level1_key2': [
- 'v2',
- ]
- }
- ]
- })
+ self.assertEqual(
+ config_rendered,
+ {
+ "level0_key": [
+ {"level1_key0": "v0"},
+ "v1",
+ [
+ "v0",
+ "v1",
+ "v2",
+ ],
+ {
+ "level1_key2": [
+ "v2",
+ ]
+ },
+ ]
+ },
+ )
config_db.delete()
def test_empty_config_object_in_the_database(self):
- pack_name = 'dummy_pack_empty_config'
+ pack_name = "dummy_pack_empty_config"
config_db = ConfigDB(pack=pack_name)
config_db = Config.add_or_update(config_db)
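Taken together, the dynamic-config tests above pin down a recursive rendering pass: walk dicts and lists, render leaf strings as Jinja templates against an st2kv context, and wrap failures in a RuntimeError that carries the dotted key path. A compact sketch using jinja2 directly; the in-memory KV_STORE and the exact error wording are assumptions for illustration:

```python
import jinja2

KV_STORE = {"system": {"k0": "v0", "k1": "v1", "k2": "v2"}}  # assumed stand-in
ENV = jinja2.Environment(undefined=jinja2.StrictUndefined)


def render_value(value, context, key_path=""):
    # Recurse into dicts and lists; render leaf strings as Jinja templates
    # and report failures with the full dotted key path, as the expected
    # error messages in the tests above do.
    if isinstance(value, dict):
        return {
            k: render_value(v, context, (key_path + "." + k).lstrip("."))
            for k, v in value.items()
        }
    if isinstance(value, list):
        return [
            render_value(v, context, key_path + "." + str(i))
            for i, v in enumerate(value)
        ]
    if isinstance(value, str):
        try:
            return ENV.from_string(value).render(context)
        except jinja2.exceptions.TemplateError as e:
            raise RuntimeError(
                'Failed to render dynamic configuration value for key "%s" '
                'with value "%s": %s' % (key_path, value, e)
            )
    return value


values = {
    "level0_key": "{{st2kv.system.k0}}",
    "level0_object": {"level1_key": "{{st2kv.system.k1}}"},
    "level0_list": ["a", "{{st2kv.system.k2}}"],
}
rendered = render_value(values, {"st2kv": KV_STORE})
assert rendered == {
    "level0_key": "v0",
    "level0_object": {"level1_key": "v1"},
    "level0_list": ["a", "v2"],
}
```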
diff --git a/st2common/tests/unit/test_config_parser.py b/st2common/tests/unit/test_config_parser.py
index 6dc690b746..fde0385369 100644
--- a/st2common/tests/unit/test_config_parser.py
+++ b/st2common/tests/unit/test_config_parser.py
@@ -27,27 +27,27 @@ def setUp(self):
tests_config.parse_args()
def test_get_config_inexistent_pack(self):
- parser = ContentPackConfigParser(pack_name='inexistent')
+ parser = ContentPackConfigParser(pack_name="inexistent")
config = parser.get_config()
self.assertEqual(config, None)
def test_get_config_no_config(self):
- pack_name = 'dummy_pack_1'
+ pack_name = "dummy_pack_1"
parser = ContentPackConfigParser(pack_name=pack_name)
config = parser.get_config()
self.assertEqual(config, None)
def test_get_config_existing_config(self):
- pack_name = 'dummy_pack_2'
+ pack_name = "dummy_pack_2"
parser = ContentPackConfigParser(pack_name=pack_name)
config = parser.get_config()
- self.assertEqual(config.config['section1']['key1'], 'value1')
- self.assertEqual(config.config['section2']['key10'], 'value10')
+ self.assertEqual(config.config["section1"]["key1"], "value1")
+ self.assertEqual(config.config["section2"]["key10"], "value10")
def test_get_config_for_unicode_char(self):
- pack_name = 'dummy_pack_18'
+ pack_name = "dummy_pack_18"
parser = ContentPackConfigParser(pack_name=pack_name)
config = parser.get_config()
- self.assertEqual(config.config['section1']['key1'], u'测试')
+ self.assertEqual(config.config["section1"]["key1"], "测试")
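The parser behaviour exercised above is simple: return None when the pack or its config file is missing, otherwise load the YAML as UTF-8 so values like 测试 survive. A hypothetical stand-in, not ContentPackConfigParser itself:

```python
import os
import tempfile

import yaml


class MiniConfigParser:
    """Hypothetical sketch: load <packs_dir>/<pack>/config.yaml if present."""

    def __init__(self, packs_dir, pack_name):
        self.path = os.path.join(packs_dir, pack_name, "config.yaml")

    def get_config(self):
        if not os.path.isfile(self.path):
            return None  # covers both "inexistent pack" and "no config"
        # Read as UTF-8 so unicode values round-trip intact.
        with open(self.path, encoding="utf-8") as f:
            return yaml.safe_load(f)


with tempfile.TemporaryDirectory() as packs_dir:
    os.makedirs(os.path.join(packs_dir, "dummy_pack_2"))
    with open(
        os.path.join(packs_dir, "dummy_pack_2", "config.yaml"),
        "w",
        encoding="utf-8",
    ) as f:
        f.write("section1:\n  key1: value1\n")

    assert MiniConfigParser(packs_dir, "inexistent").get_config() is None
    assert MiniConfigParser(packs_dir, "dummy_pack_2").get_config() == {
        "section1": {"key1": "value1"}
    }
```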
diff --git a/st2common/tests/unit/test_configs_registrar.py b/st2common/tests/unit/test_configs_registrar.py
index 09d002eb6a..821cec75fa 100644
--- a/st2common/tests/unit/test_configs_registrar.py
+++ b/st2common/tests/unit/test_configs_registrar.py
@@ -30,15 +30,23 @@
from st2tests import fixturesloader
-__all__ = [
- 'ConfigsRegistrarTestCase'
-]
-
-PACK_1_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_1')
-PACK_6_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_6')
-PACK_19_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_19')
-PACK_11_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_11')
-PACK_22_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_22')
+__all__ = ["ConfigsRegistrarTestCase"]
+
+PACK_1_PATH = os.path.join(
+ fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1"
+)
+PACK_6_PATH = os.path.join(
+ fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_6"
+)
+PACK_19_PATH = os.path.join(
+ fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_19"
+)
+PACK_11_PATH = os.path.join(
+ fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_11"
+)
+PACK_22_PATH = os.path.join(
+ fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_22"
+)
class ConfigsRegistrarTestCase(CleanDbTestCase):
@@ -52,7 +60,7 @@ def test_register_configs_for_all_packs(self):
registrar = ConfigsRegistrar(use_pack_cache=False)
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_1': PACK_1_PATH}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_1": PACK_1_PATH}
packs_base_paths = content_utils.get_packs_base_paths()
registrar.register_from_packs(base_dirs=packs_base_paths)
@@ -64,9 +72,9 @@ def test_register_configs_for_all_packs(self):
self.assertEqual(len(config_dbs), 1)
config_db = config_dbs[0]
- self.assertEqual(config_db.values['api_key'], '{{st2kv.user.api_key}}')
- self.assertEqual(config_db.values['api_secret'], SUPER_SECRET_PARAMETER)
- self.assertEqual(config_db.values['region'], 'us-west-1')
+ self.assertEqual(config_db.values["api_key"], "{{st2kv.user.api_key}}")
+ self.assertEqual(config_db.values["api_secret"], SUPER_SECRET_PARAMETER)
+ self.assertEqual(config_db.values["region"], "us-west-1")
def test_register_all_configs_invalid_config_no_config_schema(self):
# verify_configs is on, but ConfigSchema for the pack doesn't exist so
@@ -81,7 +89,7 @@ def test_register_all_configs_invalid_config_no_config_schema(self):
registrar = ConfigsRegistrar(use_pack_cache=False, validate_configs=False)
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_6': PACK_6_PATH}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_6": PACK_6_PATH}
packs_base_paths = content_utils.get_packs_base_paths()
registrar.register_from_packs(base_dirs=packs_base_paths)
@@ -92,7 +100,9 @@ def test_register_all_configs_invalid_config_no_config_schema(self):
self.assertEqual(len(pack_dbs), 1)
self.assertEqual(len(config_dbs), 1)
- def test_register_all_configs_with_config_schema_validation_validation_failure_1(self):
+ def test_register_all_configs_with_config_schema_validation_validation_failure_1(
+ self,
+ ):
# Verify DB is empty
pack_dbs = Pack.get_all()
config_dbs = Config.get_all()
@@ -100,28 +110,38 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_1
self.assertEqual(len(pack_dbs), 0)
self.assertEqual(len(config_dbs), 0)
- registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True,
- validate_configs=True)
+ registrar = ConfigsRegistrar(
+ use_pack_cache=False, fail_on_failure=True, validate_configs=True
+ )
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_6': PACK_6_PATH}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_6": PACK_6_PATH}
# Register ConfigSchema for pack
registrar._register_pack_db = mock.Mock()
- registrar._register_pack(pack_name='dummy_pack_5', pack_dir=PACK_6_PATH)
+ registrar._register_pack(pack_name="dummy_pack_5", pack_dir=PACK_6_PATH)
packs_base_paths = content_utils.get_packs_base_paths()
if six.PY3:
- expected_msg = ('Failed validating attribute "regions" in config for pack '
- '"dummy_pack_6" (.*?): 1000 is not of type \'array\'')
+ expected_msg = (
+ 'Failed validating attribute "regions" in config for pack '
+ "\"dummy_pack_6\" (.*?): 1000 is not of type 'array'"
+ )
else:
- expected_msg = ('Failed validating attribute "regions" in config for pack '
- '"dummy_pack_6" (.*?): 1000 is not of type u\'array\'')
-
- self.assertRaisesRegexp(ValueError, expected_msg,
- registrar.register_from_packs,
- base_dirs=packs_base_paths)
-
- def test_register_all_configs_with_config_schema_validation_validation_failure_2(self):
+ expected_msg = (
+ 'Failed validating attribute "regions" in config for pack '
+ "\"dummy_pack_6\" (.*?): 1000 is not of type u'array'"
+ )
+
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar.register_from_packs,
+ base_dirs=packs_base_paths,
+ )
+
+ def test_register_all_configs_with_config_schema_validation_validation_failure_2(
+ self,
+ ):
# Verify DB is empty
pack_dbs = Pack.get_all()
config_dbs = Config.get_all()
@@ -129,30 +149,40 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_2
self.assertEqual(len(pack_dbs), 0)
self.assertEqual(len(config_dbs), 0)
- registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True,
- validate_configs=True)
+ registrar = ConfigsRegistrar(
+ use_pack_cache=False, fail_on_failure=True, validate_configs=True
+ )
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_19': PACK_19_PATH}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_19": PACK_19_PATH}
# Register ConfigSchema for pack
registrar._register_pack_db = mock.Mock()
- registrar._register_pack(pack_name='dummy_pack_19', pack_dir=PACK_19_PATH)
+ registrar._register_pack(pack_name="dummy_pack_19", pack_dir=PACK_19_PATH)
packs_base_paths = content_utils.get_packs_base_paths()
if six.PY3:
- expected_msg = ('Failed validating attribute "instances.0.alias" in config for pack '
- '"dummy_pack_19" (.*?): {\'not\': \'string\'} is not of type '
- '\'string\'')
+ expected_msg = (
+ 'Failed validating attribute "instances.0.alias" in config for pack '
+ "\"dummy_pack_19\" (.*?): {'not': 'string'} is not of type "
+ "'string'"
+ )
else:
- expected_msg = ('Failed validating attribute "instances.0.alias" in config for pack '
- '"dummy_pack_19" (.*?): {\'not\': \'string\'} is not of type '
- 'u\'string\'')
-
- self.assertRaisesRegexp(ValueError, expected_msg,
- registrar.register_from_packs,
- base_dirs=packs_base_paths)
-
- def test_register_all_configs_with_config_schema_validation_validation_failure_3(self):
+ expected_msg = (
+ 'Failed validating attribute "instances.0.alias" in config for pack '
+ "\"dummy_pack_19\" (.*?): {'not': 'string'} is not of type "
+ "u'string'"
+ )
+
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar.register_from_packs,
+ base_dirs=packs_base_paths,
+ )
+
+ def test_register_all_configs_with_config_schema_validation_validation_failure_3(
+ self,
+ ):
# This test checks for values containing "decrypt_kv" jinja filter in the config
# object where keys have "secret: True" set in the schema.
@@ -163,26 +193,34 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_3
self.assertEqual(len(pack_dbs), 0)
self.assertEqual(len(config_dbs), 0)
- registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True,
- validate_configs=True)
+ registrar = ConfigsRegistrar(
+ use_pack_cache=False, fail_on_failure=True, validate_configs=True
+ )
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_11': PACK_11_PATH}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_11": PACK_11_PATH}
# Register ConfigSchema for pack
registrar._register_pack_db = mock.Mock()
- registrar._register_pack(pack_name='dummy_pack_11', pack_dir=PACK_11_PATH)
+ registrar._register_pack(pack_name="dummy_pack_11", pack_dir=PACK_11_PATH)
packs_base_paths = content_utils.get_packs_base_paths()
- expected_msg = ('Values specified as "secret: True" in config schema are automatically '
- 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed '
- 'for such values. Please check the specified values in the config or '
- 'the default values in the schema.')
-
- self.assertRaisesRegexp(ValueError, expected_msg,
- registrar.register_from_packs,
- base_dirs=packs_base_paths)
-
- def test_register_all_configs_with_config_schema_validation_validation_failure_4(self):
+ expected_msg = (
+ 'Values specified as "secret: True" in config schema are automatically '
+ 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed '
+ "for such values. Please check the specified values in the config or "
+ "the default values in the schema."
+ )
+
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar.register_from_packs,
+ base_dirs=packs_base_paths,
+ )
+
+ def test_register_all_configs_with_config_schema_validation_validation_failure_4(
+ self,
+ ):
# This test checks for default values containing "decrypt_kv" jinja filter for
# keys which have "secret: True" set.
@@ -193,21 +231,27 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_4
self.assertEqual(len(pack_dbs), 0)
self.assertEqual(len(config_dbs), 0)
- registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True,
- validate_configs=True)
+ registrar = ConfigsRegistrar(
+ use_pack_cache=False, fail_on_failure=True, validate_configs=True
+ )
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_22': PACK_22_PATH}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_22": PACK_22_PATH}
# Register ConfigSchema for pack
registrar._register_pack_db = mock.Mock()
- registrar._register_pack(pack_name='dummy_pack_22', pack_dir=PACK_22_PATH)
+ registrar._register_pack(pack_name="dummy_pack_22", pack_dir=PACK_22_PATH)
packs_base_paths = content_utils.get_packs_base_paths()
- expected_msg = ('Values specified as "secret: True" in config schema are automatically '
- 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed '
- 'for such values. Please check the specified values in the config or '
- 'the default values in the schema.')
-
- self.assertRaisesRegexp(ValueError, expected_msg,
- registrar.register_from_packs,
- base_dirs=packs_base_paths)
+ expected_msg = (
+ 'Values specified as "secret: True" in config schema are automatically '
+ 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed '
+ "for such values. Please check the specified values in the config or "
+ "the default values in the schema."
+ )
+
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar.register_from_packs,
+ base_dirs=packs_base_paths,
+ )
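The registrar failure tests all reduce to validating a config dict against the pack's config schema and raising a ValueError whose message names the offending attribute by dotted path. A sketch with jsonschema; the message format only approximates the one asserted above:

```python
import jsonschema

schema = {"type": "object", "properties": {"regions": {"type": "array"}}}


def validate_config(pack, values, schema):
    # Raise on the first validation error, naming the attribute by its
    # dotted path (e.g. "instances.0.alias" for nested failures).
    validator = jsonschema.Draft4Validator(schema)
    for error in validator.iter_errors(values):
        attribute = ".".join(str(part) for part in error.absolute_path)
        raise ValueError(
            'Failed validating attribute "%s" in config for pack "%s": %s'
            % (attribute, pack, error.message)
        )
    return values


try:
    validate_config("dummy_pack_6", {"regions": 1000}, schema)
except ValueError as e:
    assert "1000 is not of type 'array'" in str(e)
else:
    raise AssertionError("expected a validation failure")
```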
diff --git a/st2common/tests/unit/test_connection_retry_wrapper.py b/st2common/tests/unit/test_connection_retry_wrapper.py
index 8c75ff4955..831ac8c22e 100644
--- a/st2common/tests/unit/test_connection_retry_wrapper.py
+++ b/st2common/tests/unit/test_connection_retry_wrapper.py
@@ -21,19 +21,18 @@
class TestClusterRetryContext(unittest.TestCase):
-
def test_single_node_cluster_retry(self):
retry_context = ClusterRetryContext(cluster_size=1)
should_stop, wait = retry_context.test_should_stop()
- self.assertFalse(should_stop, 'Not done trying.')
+ self.assertFalse(should_stop, "Not done trying.")
self.assertEqual(wait, 10)
should_stop, wait = retry_context.test_should_stop()
- self.assertFalse(should_stop, 'Not done trying.')
+ self.assertFalse(should_stop, "Not done trying.")
self.assertEqual(wait, 10)
should_stop, wait = retry_context.test_should_stop()
- self.assertTrue(should_stop, 'Done trying.')
+ self.assertTrue(should_stop, "Done trying.")
self.assertEqual(wait, -1)
def test_should_stop_second_channel_open_error_should_be_non_fatal(self):
@@ -58,10 +57,10 @@ def test_multiple_node_cluster_retry(self):
for i in range(last_index + 1):
should_stop, wait = retry_context.test_should_stop()
if i == last_index:
- self.assertTrue(should_stop, 'Done trying.')
+ self.assertTrue(should_stop, "Done trying.")
self.assertEqual(wait, -1)
else:
- self.assertFalse(should_stop, 'Not done trying.')
+ self.assertFalse(should_stop, "Not done trying.")
# on cluster boundaries the wait is longer. Short wait when switching
# to a different server within a cluster.
if (i + 1) % cluster_size == 0:
@@ -72,5 +71,5 @@ def test_multiple_node_cluster_retry(self):
def test_zero_node_cluster_retry(self):
retry_context = ClusterRetryContext(cluster_size=0)
should_stop, wait = retry_context.test_should_stop()
- self.assertTrue(should_stop, 'Done trying.')
+ self.assertTrue(should_stop, "Done trying.")
self.assertEqual(wait, -1)
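These retry tests encode a small state machine: a single-node cluster gets two tries with a 10-second wait before giving up with (True, -1), a zero-node cluster stops immediately, and larger clusters wait longer on cluster boundaries than when switching servers within a cluster. A sketch under assumed constants (two rounds per cluster, 10s boundary wait, 0s within-cluster wait); the real ClusterRetryContext may use different numbers:

```python
class MiniRetryContext:
    """Sketch of the retry bookkeeping exercised above. The constants are
    assumptions; they are chosen only to satisfy the assertions shown."""

    def __init__(self, cluster_size, cluster_retries=2):
        self.cluster_size = cluster_size
        self.max_tries = cluster_size * cluster_retries
        self.tries = 0

    def test_should_stop(self):
        self.tries += 1
        if self.tries > self.max_tries:
            return True, -1  # done trying
        # Wrapping around to the first server again warrants a longer wait.
        if self.tries % self.cluster_size == 0:
            return False, 10
        return False, 0


ctx = MiniRetryContext(cluster_size=1)
assert ctx.test_should_stop() == (False, 10)
assert ctx.test_should_stop() == (False, 10)
assert ctx.test_should_stop() == (True, -1)
assert MiniRetryContext(cluster_size=0).test_should_stop() == (True, -1)
```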
diff --git a/st2common/tests/unit/test_content_loader.py b/st2common/tests/unit/test_content_loader.py
index c20afda87a..8b8e650afb 100644
--- a/st2common/tests/unit/test_content_loader.py
+++ b/st2common/tests/unit/test_content_loader.py
@@ -23,64 +23,81 @@
from st2common.content.loader import LOG
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
-RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources'))
+RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources"))
class ContentLoaderTest(unittest2.TestCase):
def test_get_sensors(self):
- packs_base_path = os.path.join(RESOURCES_DIR, 'packs/')
+ packs_base_path = os.path.join(RESOURCES_DIR, "packs/")
loader = ContentPackLoader()
- pack_sensors = loader.get_content(base_dirs=[packs_base_path], content_type='sensors')
- self.assertIsNotNone(pack_sensors.get('pack1', None))
+ pack_sensors = loader.get_content(
+ base_dirs=[packs_base_path], content_type="sensors"
+ )
+ self.assertIsNotNone(pack_sensors.get("pack1", None))
def test_get_sensors_pack_missing_sensors(self):
loader = ContentPackLoader()
- fail_pack_path = os.path.join(RESOURCES_DIR, 'packs/pack2')
+ fail_pack_path = os.path.join(RESOURCES_DIR, "packs/pack2")
self.assertTrue(os.path.exists(fail_pack_path))
self.assertEqual(loader._get_sensors(fail_pack_path), None)
def test_invalid_content_type(self):
- packs_base_path = os.path.join(RESOURCES_DIR, 'packs/')
+ packs_base_path = os.path.join(RESOURCES_DIR, "packs/")
loader = ContentPackLoader()
- self.assertRaises(ValueError, loader.get_content, base_dirs=[packs_base_path],
- content_type='stuff')
+ self.assertRaises(
+ ValueError,
+ loader.get_content,
+ base_dirs=[packs_base_path],
+ content_type="stuff",
+ )
def test_get_content_multiple_directories(self):
- packs_base_path_1 = os.path.join(RESOURCES_DIR, 'packs/')
- packs_base_path_2 = os.path.join(RESOURCES_DIR, 'packs2/')
+ packs_base_path_1 = os.path.join(RESOURCES_DIR, "packs/")
+ packs_base_path_2 = os.path.join(RESOURCES_DIR, "packs2/")
base_dirs = [packs_base_path_1, packs_base_path_2]
LOG.warning = Mock()
loader = ContentPackLoader()
- sensors = loader.get_content(base_dirs=base_dirs, content_type='sensors')
- self.assertIn('pack1', sensors) # from packs/
- self.assertIn('pack3', sensors) # from packs2/
+ sensors = loader.get_content(base_dirs=base_dirs, content_type="sensors")
+ self.assertIn("pack1", sensors) # from packs/
+ self.assertIn("pack3", sensors) # from packs2/
# Assert that a warning is emitted when a duplicated pack is found
- expected_msg = ('Pack "pack1" already found in '
- '"%s/packs/", ignoring content from '
- '"%s/packs2/"' % (RESOURCES_DIR, RESOURCES_DIR))
+ expected_msg = (
+ 'Pack "pack1" already found in '
+ '"%s/packs/", ignoring content from '
+ '"%s/packs2/"' % (RESOURCES_DIR, RESOURCES_DIR)
+ )
LOG.warning.assert_called_once_with(expected_msg)
def test_get_content_from_pack_success(self):
loader = ContentPackLoader()
- pack_path = os.path.join(RESOURCES_DIR, 'packs/pack1')
+ pack_path = os.path.join(RESOURCES_DIR, "packs/pack1")
- sensors = loader.get_content_from_pack(pack_dir=pack_path, content_type='sensors')
- self.assertTrue(sensors.endswith('packs/pack1/sensors'))
+ sensors = loader.get_content_from_pack(
+ pack_dir=pack_path, content_type="sensors"
+ )
+ self.assertTrue(sensors.endswith("packs/pack1/sensors"))
def test_get_content_from_pack_directory_doesnt_exist(self):
loader = ContentPackLoader()
- pack_path = os.path.join(RESOURCES_DIR, 'packs/pack100')
+ pack_path = os.path.join(RESOURCES_DIR, "packs/pack100")
- message_regex = 'Directory .*? doesn\'t exist'
- self.assertRaisesRegexp(ValueError, message_regex, loader.get_content_from_pack,
- pack_dir=pack_path, content_type='sensors')
+ message_regex = "Directory .*? doesn't exist"
+ self.assertRaisesRegexp(
+ ValueError,
+ message_regex,
+ loader.get_content_from_pack,
+ pack_dir=pack_path,
+ content_type="sensors",
+ )
def test_get_content_from_pack_no_sensors(self):
loader = ContentPackLoader()
- pack_path = os.path.join(RESOURCES_DIR, 'packs/pack2')
+ pack_path = os.path.join(RESOURCES_DIR, "packs/pack2")
- result = loader.get_content_from_pack(pack_dir=pack_path, content_type='sensors')
+ result = loader.get_content_from_pack(
+ pack_dir=pack_path, content_type="sensors"
+ )
self.assertEqual(result, None)
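These loader tests assume a scan over ordered base directories in which the first occurrence of a pack name wins and later duplicates are skipped with a warning. A sketch of that scan; the directory layout and warning text are illustrative, not ContentPackLoader's internals:

```python
import logging
import os
import tempfile

LOG = logging.getLogger("content_loader_sketch")


def get_content(base_dirs, content_type="sensors"):
    # First base dir wins for a given pack; later duplicates are skipped
    # with a warning, as the duplicate-pack test above asserts.
    if content_type not in ("sensors", "actions"):
        raise ValueError("Unsupported content type: %s" % content_type)
    content = {}
    for base_dir in base_dirs:
        for pack in sorted(os.listdir(base_dir)):
            type_dir = os.path.join(base_dir, pack, content_type)
            if not os.path.isdir(type_dir):
                continue  # pack has no content of this type
            if pack in content:
                LOG.warning(
                    'Pack "%s" already found in "%s", ignoring content from "%s"',
                    pack,
                    os.path.dirname(content[pack]),
                    base_dir,
                )
                continue
            content[pack] = type_dir
    return content


with tempfile.TemporaryDirectory() as root:
    for rel in ("packs/pack1/sensors", "packs2/pack1/sensors", "packs2/pack3/sensors"):
        os.makedirs(os.path.join(root, rel))
    sensors = get_content([os.path.join(root, "packs"), os.path.join(root, "packs2")])
    assert set(sensors) == {"pack1", "pack3"}
    assert sensors["pack1"].startswith(os.path.join(root, "packs") + os.sep)
```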
diff --git a/st2common/tests/unit/test_content_utils.py b/st2common/tests/unit/test_content_utils.py
index 703c75aa70..523114a613 100644
--- a/st2common/tests/unit/test_content_utils.py
+++ b/st2common/tests/unit/test_content_utils.py
@@ -39,205 +39,260 @@ def setUpClass(cls):
tests_config.parse_args()
def test_get_pack_base_paths(self):
- cfg.CONF.content.system_packs_base_path = ''
- cfg.CONF.content.packs_base_paths = '/opt/path1'
+ cfg.CONF.content.system_packs_base_path = ""
+ cfg.CONF.content.packs_base_paths = "/opt/path1"
result = get_packs_base_paths()
- self.assertEqual(result, ['/opt/path1'])
+ self.assertEqual(result, ["/opt/path1"])
# Multiple paths, no trailing colon
- cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2'
+ cfg.CONF.content.packs_base_paths = "/opt/path1:/opt/path2"
result = get_packs_base_paths()
- self.assertEqual(result, ['/opt/path1', '/opt/path2'])
+ self.assertEqual(result, ["/opt/path1", "/opt/path2"])
# Multiple paths, trailing colon
- cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2:'
+ cfg.CONF.content.packs_base_paths = "/opt/path1:/opt/path2:"
result = get_packs_base_paths()
- self.assertEqual(result, ['/opt/path1', '/opt/path2'])
+ self.assertEqual(result, ["/opt/path1", "/opt/path2"])
# Multiple same paths
- cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2:/opt/path1:/opt/path2'
+ cfg.CONF.content.packs_base_paths = (
+ "/opt/path1:/opt/path2:/opt/path1:/opt/path2"
+ )
result = get_packs_base_paths()
- self.assertEqual(result, ['/opt/path1', '/opt/path2'])
+ self.assertEqual(result, ["/opt/path1", "/opt/path2"])
# Assert system path is always first
- cfg.CONF.content.system_packs_base_path = '/opt/system'
- cfg.CONF.content.packs_base_paths = '/opt/path2:/opt/path1'
+ cfg.CONF.content.system_packs_base_path = "/opt/system"
+ cfg.CONF.content.packs_base_paths = "/opt/path2:/opt/path1"
result = get_packs_base_paths()
- self.assertEqual(result, ['/opt/system', '/opt/path2', '/opt/path1'])
+ self.assertEqual(result, ["/opt/system", "/opt/path2", "/opt/path1"])
# More scenarios
orig_path = cfg.CONF.content.system_packs_base_path
- cfg.CONF.content.system_packs_base_path = '/tests/packs'
+ cfg.CONF.content.system_packs_base_path = "/tests/packs"
- names = [
- 'test_pack_1',
- 'test_pack_2',
- 'ma_pack'
- ]
+ names = ["test_pack_1", "test_pack_2", "ma_pack"]
for name in names:
actual = get_pack_base_path(pack_name=name)
- expected = os.path.join(cfg.CONF.content.system_packs_base_path,
- name)
+ expected = os.path.join(cfg.CONF.content.system_packs_base_path, name)
self.assertEqual(actual, expected)
cfg.CONF.content.system_packs_base_path = orig_path
def test_get_aliases_base_paths(self):
- cfg.CONF.content.aliases_base_paths = '/opt/path1'
+ cfg.CONF.content.aliases_base_paths = "/opt/path1"
result = get_aliases_base_paths()
- self.assertEqual(result, ['/opt/path1'])
+ self.assertEqual(result, ["/opt/path1"])
# Multiple paths, no trailing colon
- cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2'
+ cfg.CONF.content.aliases_base_paths = "/opt/path1:/opt/path2"
result = get_aliases_base_paths()
- self.assertEqual(result, ['/opt/path1', '/opt/path2'])
+ self.assertEqual(result, ["/opt/path1", "/opt/path2"])
# Multiple paths, trailing colon
- cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2:'
+ cfg.CONF.content.aliases_base_paths = "/opt/path1:/opt/path2:"
result = get_aliases_base_paths()
- self.assertEqual(result, ['/opt/path1', '/opt/path2'])
+ self.assertEqual(result, ["/opt/path1", "/opt/path2"])
# Multiple same paths
- cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2:/opt/path1:/opt/path2'
+ cfg.CONF.content.aliases_base_paths = (
+ "/opt/path1:/opt/path2:/opt/path1:/opt/path2"
+ )
result = get_aliases_base_paths()
- self.assertEqual(result, ['/opt/path1', '/opt/path2'])
+ self.assertEqual(result, ["/opt/path1", "/opt/path2"])
def test_get_pack_resource_file_abs_path(self):
# Mock the packs path to point to the fixtures directory
cfg.CONF.content.packs_base_paths = get_fixtures_packs_base_path()
# Invalid resource type
- expected_msg = 'Invalid resource type: fooo'
- self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path,
- pack_ref='dummy_pack_1',
- resource_type='fooo',
- file_path='test.py')
+ expected_msg = "Invalid resource type: fooo"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ get_pack_resource_file_abs_path,
+ pack_ref="dummy_pack_1",
+ resource_type="fooo",
+ file_path="test.py",
+ )
# Invalid paths (directory traversal and absolute paths)
- file_paths = ['/tmp/foo.py', '../foo.py', '/etc/passwd', '../../foo.py',
- '/opt/stackstorm/packs/invalid_pack/actions/my_action.py',
- '../../foo.py']
+ file_paths = [
+ "/tmp/foo.py",
+ "../foo.py",
+ "/etc/passwd",
+ "../../foo.py",
+ "/opt/stackstorm/packs/invalid_pack/actions/my_action.py",
+ "../../foo.py",
+ ]
for file_path in file_paths:
# action resource_type
- expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the '
- r'pack actions directory (.*). For example "my_action.py"\.' %
- (file_path))
- self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path,
- pack_ref='dummy_pack_1',
- resource_type='action',
- file_path=file_path)
+ expected_msg = (
+ r'Invalid file path: ".*%s"\. File path needs to be relative to the '
+ r'pack actions directory (.*). For example "my_action.py"\.'
+ % (file_path)
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ get_pack_resource_file_abs_path,
+ pack_ref="dummy_pack_1",
+ resource_type="action",
+ file_path=file_path,
+ )
# sensor resource_type
- expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the '
- r'pack sensors directory (.*). For example "my_sensor.py"\.' %
- (file_path))
- self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path,
- pack_ref='dummy_pack_1',
- resource_type='sensor',
- file_path=file_path)
+ expected_msg = (
+ r'Invalid file path: ".*%s"\. File path needs to be relative to the '
+ r'pack sensors directory (.*). For example "my_sensor.py"\.'
+ % (file_path)
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ get_pack_resource_file_abs_path,
+ pack_ref="dummy_pack_1",
+ resource_type="sensor",
+ file_path=file_path,
+ )
# no resource type
- expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the '
- r'pack directory (.*). For example "my_action.py"\.' %
- (file_path))
- self.assertRaisesRegexp(ValueError, expected_msg, get_pack_file_abs_path,
- pack_ref='dummy_pack_1',
- file_path=file_path)
+ expected_msg = (
+ r'Invalid file path: ".*%s"\. File path needs to be relative to the '
+ r'pack directory (.*). For example "my_action.py"\.' % (file_path)
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ get_pack_file_abs_path,
+ pack_ref="dummy_pack_1",
+ file_path=file_path,
+ )
# Valid paths
- file_paths = ['foo.py', 'a/foo.py', 'a/b/foo.py']
+ file_paths = ["foo.py", "a/foo.py", "a/b/foo.py"]
for file_path in file_paths:
- expected = os.path.join(get_fixtures_packs_base_path(),
- 'dummy_pack_1/actions', file_path)
- result = get_pack_resource_file_abs_path(pack_ref='dummy_pack_1',
- resource_type='action',
- file_path=file_path)
+ expected = os.path.join(
+ get_fixtures_packs_base_path(), "dummy_pack_1/actions", file_path
+ )
+ result = get_pack_resource_file_abs_path(
+ pack_ref="dummy_pack_1", resource_type="action", file_path=file_path
+ )
self.assertEqual(result, expected)
def test_get_entry_point_absolute_path(self):
orig_path = cfg.CONF.content.system_packs_base_path
- cfg.CONF.content.system_packs_base_path = '/tests/packs'
+ cfg.CONF.content.system_packs_base_path = "/tests/packs"
acutal_path = get_entry_point_abs_path(
- pack='foo',
- entry_point='/tests/packs/foo/bar.py')
- self.assertEqual(acutal_path, '/tests/packs/foo/bar.py', 'Entry point path doesn\'t match.')
+ pack="foo", entry_point="/tests/packs/foo/bar.py"
+ )
+ self.assertEqual(
+ acutal_path, "/tests/packs/foo/bar.py", "Entry point path doesn't match."
+ )
cfg.CONF.content.system_packs_base_path = orig_path
def test_get_entry_point_absolute_path_empty(self):
orig_path = cfg.CONF.content.system_packs_base_path
- cfg.CONF.content.system_packs_base_path = '/tests/packs'
- acutal_path = get_entry_point_abs_path(pack='foo', entry_point=None)
- self.assertEqual(acutal_path, None, 'Entry point path doesn\'t match.')
- acutal_path = get_entry_point_abs_path(pack='foo', entry_point='')
- self.assertEqual(acutal_path, None, 'Entry point path doesn\'t match.')
+ cfg.CONF.content.system_packs_base_path = "/tests/packs"
+ acutal_path = get_entry_point_abs_path(pack="foo", entry_point=None)
+ self.assertEqual(acutal_path, None, "Entry point path doesn't match.")
+ acutal_path = get_entry_point_abs_path(pack="foo", entry_point="")
+ self.assertEqual(acutal_path, None, "Entry point path doesn't match.")
cfg.CONF.content.system_packs_base_path = orig_path
def test_get_entry_point_relative_path(self):
orig_path = cfg.CONF.content.system_packs_base_path
- cfg.CONF.content.system_packs_base_path = '/tests/packs'
- acutal_path = get_entry_point_abs_path(pack='foo', entry_point='foo/bar.py')
- expected_path = os.path.join(cfg.CONF.content.system_packs_base_path, 'foo', 'actions',
- 'foo/bar.py')
- self.assertEqual(acutal_path, expected_path, 'Entry point path doesn\'t match.')
+ cfg.CONF.content.system_packs_base_path = "/tests/packs"
+ acutal_path = get_entry_point_abs_path(pack="foo", entry_point="foo/bar.py")
+ expected_path = os.path.join(
+ cfg.CONF.content.system_packs_base_path, "foo", "actions", "foo/bar.py"
+ )
+ self.assertEqual(acutal_path, expected_path, "Entry point path doesn't match.")
cfg.CONF.content.system_packs_base_path = orig_path
def test_get_action_libs_abs_path(self):
orig_path = cfg.CONF.content.system_packs_base_path
- cfg.CONF.content.system_packs_base_path = '/tests/packs'
+ cfg.CONF.content.system_packs_base_path = "/tests/packs"
# entry point relative.
- acutal_path = get_action_libs_abs_path(pack='foo', entry_point='foo/bar.py')
- expected_path = os.path.join(cfg.CONF.content.system_packs_base_path, 'foo', 'actions',
- os.path.join('foo', ACTION_LIBS_DIR))
- self.assertEqual(acutal_path, expected_path, 'Action libs path doesn\'t match.')
+ acutal_path = get_action_libs_abs_path(pack="foo", entry_point="foo/bar.py")
+ expected_path = os.path.join(
+ cfg.CONF.content.system_packs_base_path,
+ "foo",
+ "actions",
+ os.path.join("foo", ACTION_LIBS_DIR),
+ )
+ self.assertEqual(acutal_path, expected_path, "Action libs path doesn't match.")
# entry point absolute.
acutal_path = get_action_libs_abs_path(
- pack='foo',
- entry_point='/tests/packs/foo/tmp/foo.py')
- expected_path = os.path.join('/tests/packs/foo/tmp', ACTION_LIBS_DIR)
- self.assertEqual(acutal_path, expected_path, 'Action libs path doesn\'t match.')
+ pack="foo", entry_point="/tests/packs/foo/tmp/foo.py"
+ )
+ expected_path = os.path.join("/tests/packs/foo/tmp", ACTION_LIBS_DIR)
+ self.assertEqual(acutal_path, expected_path, "Action libs path doesn't match.")
cfg.CONF.content.system_packs_base_path = orig_path
def test_get_relative_path_to_pack_file(self):
packs_base_paths = get_fixtures_packs_base_path()
- pack_ref = 'dummy_pack_1'
+ pack_ref = "dummy_pack_1"
# 1. Valid paths
- file_path = os.path.join(packs_base_paths, 'dummy_pack_1/pack.yaml')
+ file_path = os.path.join(packs_base_paths, "dummy_pack_1/pack.yaml")
result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path)
- self.assertEqual(result, 'pack.yaml')
+ self.assertEqual(result, "pack.yaml")
- file_path = os.path.join(packs_base_paths, 'dummy_pack_1/actions/action.meta.yaml')
+ file_path = os.path.join(
+ packs_base_paths, "dummy_pack_1/actions/action.meta.yaml"
+ )
result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path)
- self.assertEqual(result, 'actions/action.meta.yaml')
+ self.assertEqual(result, "actions/action.meta.yaml")
- file_path = os.path.join(packs_base_paths, 'dummy_pack_1/actions/lib/foo.py')
+ file_path = os.path.join(packs_base_paths, "dummy_pack_1/actions/lib/foo.py")
result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path)
- self.assertEqual(result, 'actions/lib/foo.py')
+ self.assertEqual(result, "actions/lib/foo.py")
# Already relative
- file_path = 'actions/lib/foo2.py'
+ file_path = "actions/lib/foo2.py"
result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path)
- self.assertEqual(result, 'actions/lib/foo2.py')
+ self.assertEqual(result, "actions/lib/foo2.py")
# 2. Invalid path - outside pack directory
- expected_msg = r'file_path (.*?) is not located inside the pack directory (.*?)'
-
- file_path = os.path.join(packs_base_paths, 'dummy_pack_2/actions/lib/foo.py')
- self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file,
- pack_ref=pack_ref, file_path=file_path)
-
- file_path = '/tmp/foo/bar.py'
- self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file,
- pack_ref=pack_ref, file_path=file_path)
-
- file_path = os.path.join(packs_base_paths, '../dummy_pack_1/pack.yaml')
- self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file,
- pack_ref=pack_ref, file_path=file_path)
-
- file_path = os.path.join(packs_base_paths, '../../dummy_pack_1/pack.yaml')
- self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file,
- pack_ref=pack_ref, file_path=file_path)
+ expected_msg = r"file_path (.*?) is not located inside the pack directory (.*?)"
+
+ file_path = os.path.join(packs_base_paths, "dummy_pack_2/actions/lib/foo.py")
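+        # NOTE: assertRaisesRegexp is a deprecated alias of assertRaisesRegex on
+        # Python 3; the existing spelling is kept as-is by this reformat.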
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ get_relative_path_to_pack_file,
+ pack_ref=pack_ref,
+ file_path=file_path,
+ )
+
+ file_path = "/tmp/foo/bar.py"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ get_relative_path_to_pack_file,
+ pack_ref=pack_ref,
+ file_path=file_path,
+ )
+
+ file_path = os.path.join(packs_base_paths, "../dummy_pack_1/pack.yaml")
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ get_relative_path_to_pack_file,
+ pack_ref=pack_ref,
+ file_path=file_path,
+ )
+
+ file_path = os.path.join(packs_base_paths, "../../dummy_pack_1/pack.yaml")
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ get_relative_path_to_pack_file,
+ pack_ref=pack_ref,
+ file_path=file_path,
+ )
diff --git a/st2common/tests/unit/test_crypto_utils.py b/st2common/tests/unit/test_crypto_utils.py
index 3bd63ecefe..5f8f07fa69 100644
--- a/st2common/tests/unit/test_crypto_utils.py
+++ b/st2common/tests/unit/test_crypto_utils.py
@@ -40,37 +40,32 @@
from st2tests.fixturesloader import get_fixtures_base_path
-__all__ = [
- 'CryptoUtilsTestCase',
- 'CryptoUtilsKeyczarCompatibilityTestCase'
-]
+__all__ = ["CryptoUtilsTestCase", "CryptoUtilsKeyczarCompatibilityTestCase"]
-KEY_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), 'keyczar_keys/')
+KEY_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), "keyczar_keys/")
class CryptoUtilsTestCase(TestCase):
-
@classmethod
def setUpClass(cls):
super(CryptoUtilsTestCase, cls).setUpClass()
CryptoUtilsTestCase.test_crypto_key = AESKey.generate()
def test_symmetric_encrypt_decrypt_short_string_needs_to_be_padded(self):
- original = u'a'
+ original = "a"
crypto = symmetric_encrypt(CryptoUtilsTestCase.test_crypto_key, original)
plain = symmetric_decrypt(CryptoUtilsTestCase.test_crypto_key, crypto)
self.assertEqual(plain, original)
def test_symmetric_encrypt_decrypt_utf8_character(self):
values = [
- u'£',
- u'£££',
- u'££££££',
- u'č š hello đ č p ž Ž',
- u'hello 💩',
- u'💩💩💩💩💩'
- u'💩💩💩',
- u'💩😁'
+ "£",
+ "£££",
+ "££££££",
+ "č š hello đ č p ž Ž",
+ "hello 💩",
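+            # NOTE: the two adjacent literals below are implicitly concatenated
+            # into a single string; a comma appears to be missing in the original
+            # list and black preserves that behavior.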
+ "💩💩💩💩💩" "💩💩💩",
+ "💩😁",
]
for index, original in enumerate(values):
@@ -81,13 +76,13 @@ def test_symmetric_encrypt_decrypt_utf8_character(self):
self.assertEqual(index, (len(values) - 1))
def test_symmetric_encrypt_decrypt(self):
- original = 'secret'
+ original = "secret"
crypto = symmetric_encrypt(CryptoUtilsTestCase.test_crypto_key, original)
plain = symmetric_decrypt(CryptoUtilsTestCase.test_crypto_key, crypto)
self.assertEqual(plain, original)
def test_encrypt_output_is_diff_due_to_diff_IV(self):
- original = 'Kami is a little boy.'
+ original = "Kami is a little boy."
cryptos = set()
for _ in range(0, 10000):
@@ -97,7 +92,7 @@ def test_encrypt_output_is_diff_due_to_diff_IV(self):
def test_decrypt_ciphertext_is_too_short(self):
aes_key = AESKey.generate()
- plaintext = 'hello world ponies 1'
+ plaintext = "hello world ponies 1"
encrypted = cryptography_symmetric_encrypt(aes_key, plaintext)
# Verify the original, non-manipulated value can be decrypted
@@ -117,13 +112,18 @@ def test_decrypt_ciphertext_is_too_short(self):
encrypted_malformed = binascii.hexlify(encrypted_malformed)
# Verify the corrupted value results in an exception
- expected_msg = 'Invalid or malformed ciphertext'
- self.assertRaisesRegexp(ValueError, expected_msg, cryptography_symmetric_decrypt,
- aes_key, encrypted_malformed)
+ expected_msg = "Invalid or malformed ciphertext"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ cryptography_symmetric_decrypt,
+ aes_key,
+ encrypted_malformed,
+ )
def test_exception_is_thrown_on_invalid_hmac_signature(self):
aes_key = AESKey.generate()
- plaintext = 'hello world ponies 2'
+ plaintext = "hello world ponies 2"
encrypted = cryptography_symmetric_encrypt(aes_key, plaintext)
# Verify the original, non-manipulated value can be decrypted
@@ -133,13 +133,18 @@ def test_exception_is_thrown_on_invalid_hmac_signature(self):
# Corrupt the HMAC signature (last part is the HMAC signature)
encrypted_malformed = binascii.unhexlify(encrypted)
encrypted_malformed = encrypted_malformed[:-3]
- encrypted_malformed += b'abc'
+ encrypted_malformed += b"abc"
encrypted_malformed = binascii.hexlify(encrypted_malformed)
# Verify the corrupted value results in an exception
- expected_msg = 'Signature did not match digest'
- self.assertRaisesRegexp(InvalidSignature, expected_msg, cryptography_symmetric_decrypt,
- aes_key, encrypted_malformed)
+ expected_msg = "Signature did not match digest"
+ self.assertRaisesRegexp(
+ InvalidSignature,
+ expected_msg,
+ cryptography_symmetric_decrypt,
+ aes_key,
+ encrypted_malformed,
+ )
class CryptoUtilsKeyczarCompatibilityTestCase(TestCase):
@@ -150,44 +155,69 @@ class CryptoUtilsKeyczarCompatibilityTestCase(TestCase):
def test_aes_key_class(self):
# 1. Unsupported mode
- expected_msg = 'Unsupported mode: EBC'
- self.assertRaisesRegexp(ValueError, expected_msg, AESKey, aes_key_string='a',
- hmac_key_string='b', hmac_key_size=128, mode='EBC')
+ expected_msg = "Unsupported mode: EBC"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ AESKey,
+ aes_key_string="a",
+ hmac_key_string="b",
+ hmac_key_size=128,
+ mode="EBC",
+ )
# 2. AES key is too small
- expected_msg = 'Unsafe key size: 64'
- self.assertRaisesRegexp(ValueError, expected_msg, AESKey, aes_key_string='a',
- hmac_key_string='b', hmac_key_size=128, mode='CBC', size=64)
+ expected_msg = "Unsafe key size: 64"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ AESKey,
+ aes_key_string="a",
+ hmac_key_string="b",
+ hmac_key_size=128,
+ mode="CBC",
+ size=64,
+ )
def test_loading_keys_from_keyczar_formatted_key_files(self):
- key_path = os.path.join(KEY_FIXTURES_PATH, 'one.json')
+ key_path = os.path.join(KEY_FIXTURES_PATH, "one.json")
aes_key = read_crypto_key(key_path=key_path)
- self.assertEqual(aes_key.hmac_key_string, 'lgI9YdOKlIOtPQFdgB0B6zr0AZ6L2QJuFQg4gTu2dxc')
+ self.assertEqual(
+ aes_key.hmac_key_string, "lgI9YdOKlIOtPQFdgB0B6zr0AZ6L2QJuFQg4gTu2dxc"
+ )
self.assertEqual(aes_key.hmac_key_size, 256)
- self.assertEqual(aes_key.aes_key_string, 'vKmBE2YeQ9ATyovel7NDjdnbvOMcoU5uPtUVxWxWm58')
- self.assertEqual(aes_key.mode, 'CBC')
+ self.assertEqual(
+ aes_key.aes_key_string, "vKmBE2YeQ9ATyovel7NDjdnbvOMcoU5uPtUVxWxWm58"
+ )
+ self.assertEqual(aes_key.mode, "CBC")
self.assertEqual(aes_key.size, 256)
- key_path = os.path.join(KEY_FIXTURES_PATH, 'two.json')
+ key_path = os.path.join(KEY_FIXTURES_PATH, "two.json")
aes_key = read_crypto_key(key_path=key_path)
- self.assertEqual(aes_key.hmac_key_string, '92ok9S5extxphADmUhObPSD5wugey8eTffoJ2CEg_2s')
+ self.assertEqual(
+ aes_key.hmac_key_string, "92ok9S5extxphADmUhObPSD5wugey8eTffoJ2CEg_2s"
+ )
self.assertEqual(aes_key.hmac_key_size, 256)
- self.assertEqual(aes_key.aes_key_string, 'fU9hT9pm-b9hu3VyQACLXe2Z7xnaJMZrXiTltyLUzgs')
- self.assertEqual(aes_key.mode, 'CBC')
+ self.assertEqual(
+ aes_key.aes_key_string, "fU9hT9pm-b9hu3VyQACLXe2Z7xnaJMZrXiTltyLUzgs"
+ )
+ self.assertEqual(aes_key.mode, "CBC")
self.assertEqual(aes_key.size, 256)
- key_path = os.path.join(KEY_FIXTURES_PATH, 'five.json')
+ key_path = os.path.join(KEY_FIXTURES_PATH, "five.json")
aes_key = read_crypto_key(key_path=key_path)
- self.assertEqual(aes_key.hmac_key_string, 'GCX2uMfOzp1JXYgqH8piEE4_mJOPXydH_fRHPDw9bkM')
+ self.assertEqual(
+ aes_key.hmac_key_string, "GCX2uMfOzp1JXYgqH8piEE4_mJOPXydH_fRHPDw9bkM"
+ )
self.assertEqual(aes_key.hmac_key_size, 256)
- self.assertEqual(aes_key.aes_key_string, 'EeBcUcbH14tL0w_fF5siEw')
- self.assertEqual(aes_key.mode, 'CBC')
+ self.assertEqual(aes_key.aes_key_string, "EeBcUcbH14tL0w_fF5siEw")
+ self.assertEqual(aes_key.mode, "CBC")
self.assertEqual(aes_key.size, 128)
def test_key_generation_file_format_is_fully_keyczar_compatible(self):
@@ -197,13 +227,13 @@ def test_key_generation_file_format_is_fully_keyczar_compatible(self):
json_parsed = json.loads(key_json)
expected = {
- 'hmacKey': {
- 'hmacKeyString': aes_key.hmac_key_string,
- 'size': aes_key.hmac_key_size
+ "hmacKey": {
+ "hmacKeyString": aes_key.hmac_key_string,
+ "size": aes_key.hmac_key_size,
},
- 'aesKeyString': aes_key.aes_key_string,
- 'mode': aes_key.mode,
- 'size': aes_key.size
+ "aesKeyString": aes_key.aes_key_string,
+ "mode": aes_key.mode,
+ "size": aes_key.size,
}
self.assertEqual(json_parsed, expected)
@@ -211,15 +241,14 @@ def test_key_generation_file_format_is_fully_keyczar_compatible(self):
def test_symmetric_encrypt_decrypt_cryptography(self):
key = AESKey.generate()
plaintexts = [
- 'a b c',
- 'ab',
- 'hello foo',
- 'hell',
- 'bar5'
- 'hello hello bar bar hello',
- 'a',
- '',
- 'c'
+ "a b c",
+ "ab",
+ "hello foo",
+ "hell",
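+            # NOTE: "bar5" and the literal after it form one implicitly
+            # concatenated plaintext (likely a missing comma carried over from
+            # the original list).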
+ "bar5" "hello hello bar bar hello",
+ "a",
+ "",
+ "c",
]
for plaintext in plaintexts:
@@ -228,13 +257,13 @@ def test_symmetric_encrypt_decrypt_cryptography(self):
self.assertEqual(decrypted, plaintext)
- @unittest2.skipIf(six.PY3, 'keyczar doesn\'t work under Python 3')
+ @unittest2.skipIf(six.PY3, "keyczar doesn't work under Python 3")
def test_symmetric_encrypt_decrypt_roundtrips_1(self):
encrypt_keys = [
AESKey.generate(),
AESKey.generate(),
AESKey.generate(),
- AESKey.generate()
+ AESKey.generate(),
]
# Verify all keys are unique
@@ -248,7 +277,7 @@ def test_symmetric_encrypt_decrypt_roundtrips_1(self):
self.assertEqual(len(aes_key_strings), 4)
self.assertEqual(len(hmac_key_strings), 4)
- plaintext = 'hello world test dummy 8 9 5 1 bar2'
+ plaintext = "hello world test dummy 8 9 5 1 bar2"
# Verify that round-trips work and that cryptography-based primitives are fully compatible
# with keyczar format
@@ -261,14 +290,19 @@ def test_symmetric_encrypt_decrypt_roundtrips_1(self):
self.assertNotEqual(data_enc_keyczar, data_enc_cryptography)
data_dec_keyczar_keyczar = keyczar_symmetric_decrypt(key, data_enc_keyczar)
- data_dec_keyczar_cryptography = keyczar_symmetric_decrypt(key, data_enc_cryptography)
+ data_dec_keyczar_cryptography = keyczar_symmetric_decrypt(
+ key, data_enc_cryptography
+ )
self.assertEqual(data_dec_keyczar_keyczar, plaintext)
self.assertEqual(data_dec_keyczar_cryptography, plaintext)
- data_dec_cryptography_cryptography = cryptography_symmetric_decrypt(key,
- data_enc_cryptography)
- data_dec_cryptography_keyczar = cryptography_symmetric_decrypt(key, data_enc_keyczar)
+ data_dec_cryptography_cryptography = cryptography_symmetric_decrypt(
+ key, data_enc_cryptography
+ )
+ data_dec_cryptography_keyczar = cryptography_symmetric_decrypt(
+ key, data_enc_keyczar
+ )
self.assertEqual(data_dec_cryptography_cryptography, plaintext)
self.assertEqual(data_dec_cryptography_keyczar, plaintext)
diff --git a/st2common/tests/unit/test_datastore.py b/st2common/tests/unit/test_datastore.py
index 30d3c7dc76..1e3dc86d30 100644
--- a/st2common/tests/unit/test_datastore.py
+++ b/st2common/tests/unit/test_datastore.py
@@ -28,12 +28,10 @@
from st2tests import DbTestCase
from st2tests import config
-__all__ = [
- 'DatastoreServiceTestCase'
-]
+__all__ = ["DatastoreServiceTestCase"]
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
-RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources'))
+RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources"))
class DatastoreServiceTestCase(DbTestCase):
@@ -41,9 +39,9 @@ def setUp(self):
super(DatastoreServiceTestCase, self).setUp()
config.parse_args()
- self._datastore_service = BaseDatastoreService(logger=mock.Mock(),
- pack_name='core',
- class_name='TestSensor')
+ self._datastore_service = BaseDatastoreService(
+ logger=mock.Mock(), pack_name="core", class_name="TestSensor"
+ )
self._datastore_service.get_api_client = mock.Mock()
def test_datastore_operations_list_values(self):
@@ -53,14 +51,14 @@ def test_datastore_operations_list_values(self):
self._set_mock_api_client(mock_api_client)
self._datastore_service.list_values(local=True, prefix=None)
- mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:')
- self._datastore_service.list_values(local=True, prefix='ponies')
- mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:ponies')
+ mock_api_client.keys.get_all.assert_called_with(prefix="core.TestSensor:")
+ self._datastore_service.list_values(local=True, prefix="ponies")
+ mock_api_client.keys.get_all.assert_called_with(prefix="core.TestSensor:ponies")
self._datastore_service.list_values(local=False, prefix=None)
mock_api_client.keys.get_all.assert_called_with(prefix=None)
- self._datastore_service.list_values(local=False, prefix='ponies')
- mock_api_client.keys.get_all.assert_called_with(prefix='ponies')
+ self._datastore_service.list_values(local=False, prefix="ponies")
+ mock_api_client.keys.get_all.assert_called_with(prefix="ponies")
# No values in the datastore
mock_api_client = mock.Mock()
@@ -74,11 +72,11 @@ def test_datastore_operations_list_values(self):
# Values in the datastore
kvp1 = KeyValuePair()
- kvp1.name = 'test1'
- kvp1.value = 'bar'
+ kvp1.name = "test1"
+ kvp1.value = "bar"
kvp2 = KeyValuePair()
- kvp2.name = 'test2'
- kvp2.value = 'bar'
+ kvp2.name = "test2"
+ kvp2.value = "bar"
mock_return_value = [kvp1, kvp2]
mock_api_client.keys.get_all.return_value = mock_return_value
self._set_mock_api_client(mock_api_client)
@@ -90,12 +88,12 @@ def test_datastore_operations_list_values(self):
def test_datastore_operations_get_value(self):
mock_api_client = mock.Mock()
kvp1 = KeyValuePair()
- kvp1.name = 'test1'
- kvp1.value = 'bar'
+ kvp1.name = "test1"
+ kvp1.value = "bar"
mock_api_client.keys.get_by_id.return_value = kvp1
self._set_mock_api_client(mock_api_client)
- value = self._datastore_service.get_value(name='test1', local=False)
+ value = self._datastore_service.get_value(name="test1", local=False)
self.assertEqual(value, kvp1.value)
def test_datastore_operations_set_value(self):
@@ -103,10 +101,12 @@ def test_datastore_operations_set_value(self):
mock_api_client.keys.update.return_value = True
self._set_mock_api_client(mock_api_client)
- value = self._datastore_service.set_value(name='test1', value='foo', local=False)
+ value = self._datastore_service.set_value(
+ name="test1", value="foo", local=False
+ )
self.assertTrue(value)
- kvp = mock_api_client.keys.update.call_args[1]['instance']
- self.assertEqual(kvp.value, 'foo')
+ kvp = mock_api_client.keys.update.call_args[1]["instance"]
+ self.assertEqual(kvp.value, "foo")
self.assertEqual(kvp.scope, SYSTEM_SCOPE)
def test_datastore_operations_delete_value(self):
@@ -114,53 +114,69 @@ def test_datastore_operations_delete_value(self):
mock_api_client.keys.delete.return_value = True
self._set_mock_api_client(mock_api_client)
- value = self._datastore_service.delete_value(name='test', local=False)
+ value = self._datastore_service.delete_value(name="test", local=False)
self.assertTrue(value)
def test_datastore_operations_set_encrypted_value(self):
mock_api_client = mock.Mock()
mock_api_client.keys.update.return_value = True
self._set_mock_api_client(mock_api_client)
- value = self._datastore_service.set_value(name='test1', value='foo', local=False,
- encrypt=True)
+ value = self._datastore_service.set_value(
+ name="test1", value="foo", local=False, encrypt=True
+ )
self.assertTrue(value)
- kvp = mock_api_client.keys.update.call_args[1]['instance']
- self.assertEqual(kvp.value, 'foo')
+ kvp = mock_api_client.keys.update.call_args[1]["instance"]
+ self.assertEqual(kvp.value, "foo")
self.assertTrue(kvp.secret)
self.assertEqual(kvp.scope, SYSTEM_SCOPE)
def test_datastore_unsupported_scope(self):
- self.assertRaises(ValueError, self._datastore_service.get_value, name='test1',
- scope='NOT_SYSTEM')
- self.assertRaises(ValueError, self._datastore_service.set_value, name='test1',
- value='foo', scope='NOT_SYSTEM')
- self.assertRaises(ValueError, self._datastore_service.delete_value, name='test1',
- scope='NOT_SYSTEM')
+ self.assertRaises(
+ ValueError,
+ self._datastore_service.get_value,
+ name="test1",
+ scope="NOT_SYSTEM",
+ )
+ self.assertRaises(
+ ValueError,
+ self._datastore_service.set_value,
+ name="test1",
+ value="foo",
+ scope="NOT_SYSTEM",
+ )
+ self.assertRaises(
+ ValueError,
+ self._datastore_service.delete_value,
+ name="test1",
+ scope="NOT_SYSTEM",
+ )
def test_datastore_get_exception(self):
mock_api_client = mock.Mock()
mock_api_client.keys.get_by_id.side_effect = ValueError("Exception test")
self._set_mock_api_client(mock_api_client)
- value = self._datastore_service.get_value(name='test1')
+ value = self._datastore_service.get_value(name="test1")
self.assertEqual(value, None)
def test_datastore_delete_exception(self):
mock_api_client = mock.Mock()
mock_api_client.keys.delete.side_effect = ValueError("Exception test")
self._set_mock_api_client(mock_api_client)
- delete_success = self._datastore_service.delete_value(name='test1')
+ delete_success = self._datastore_service.delete_value(name="test1")
self.assertEqual(delete_success, False)
def test_datastore_token_timeout(self):
- datastore_service = SensorDatastoreService(logger=mock.Mock(),
- pack_name='core',
- class_name='TestSensor',
- api_username='sensor_service')
+ datastore_service = SensorDatastoreService(
+ logger=mock.Mock(),
+ pack_name="core",
+ class_name="TestSensor",
+ api_username="sensor_service",
+ )
mock_api_client = mock.Mock()
kvp1 = KeyValuePair()
- kvp1.name = 'test1'
- kvp1.value = 'bar'
+ kvp1.name = "test1"
+ kvp1.value = "bar"
mock_api_client.keys.get_by_id.return_value = kvp1
token_expire_time = get_datetime_utc_now() - timedelta(seconds=5)
@@ -170,10 +186,9 @@ def test_datastore_token_timeout(self):
self._set_mock_api_client(mock_api_client)
with mock.patch(
- 'st2common.services.datastore.Client',
- return_value=mock_api_client
+ "st2common.services.datastore.Client", return_value=mock_api_client
) as datastore_client:
- value = datastore_service.get_value(name='test1', local=False)
+ value = datastore_service.get_value(name="test1", local=False)
self.assertTrue(datastore_client.called)
self.assertEqual(value, kvp1.value)
self.assertGreater(datastore_service._token_expire, token_expire_time)
diff --git a/st2common/tests/unit/test_date_utils.py b/st2common/tests/unit/test_date_utils.py
index 1b1d3b465c..d453edb8f7 100644
--- a/st2common/tests/unit/test_date_utils.py
+++ b/st2common/tests/unit/test_date_utils.py
@@ -25,44 +25,44 @@
class DateUtilsTestCase(unittest2.TestCase):
def test_get_datetime_utc_now(self):
date = date_utils.get_datetime_utc_now()
- self.assertEqual(date.tzinfo.tzname(None), 'UTC')
+ self.assertEqual(date.tzinfo.tzname(None), "UTC")
def test_add_utc_tz(self):
dt = datetime.datetime.utcnow()
self.assertIsNone(dt.tzinfo)
dt = date_utils.add_utc_tz(dt)
self.assertIsNotNone(dt.tzinfo)
- self.assertEqual(dt.tzinfo.tzname(None), 'UTC')
+ self.assertEqual(dt.tzinfo.tzname(None), "UTC")
def test_convert_to_utc(self):
date_without_tz = datetime.datetime.utcnow()
self.assertEqual(date_without_tz.tzinfo, None)
result = date_utils.convert_to_utc(date_without_tz)
- self.assertEqual(result.tzinfo.tzname(None), 'UTC')
+ self.assertEqual(result.tzinfo.tzname(None), "UTC")
date_with_pdt_tz = datetime.datetime(2015, 10, 28, 10, 0, 0, 0)
- date_with_pdt_tz = date_with_pdt_tz.replace(tzinfo=pytz.timezone('US/Pacific'))
- self.assertEqual(date_with_pdt_tz.tzinfo.tzname(None), 'US/Pacific')
+ date_with_pdt_tz = date_with_pdt_tz.replace(tzinfo=pytz.timezone("US/Pacific"))
+ self.assertEqual(date_with_pdt_tz.tzinfo.tzname(None), "US/Pacific")
result = date_utils.convert_to_utc(date_with_pdt_tz)
- self.assertEqual(str(result), '2015-10-28 17:53:00+00:00')
- self.assertEqual(result.tzinfo.tzname(None), 'UTC')
+ self.assertEqual(str(result), "2015-10-28 17:53:00+00:00")
+ self.assertEqual(result.tzinfo.tzname(None), "UTC")
def test_parse(self):
- date_str_without_tz = 'January 1st, 2014 10:00:00'
+ date_str_without_tz = "January 1st, 2014 10:00:00"
result = date_utils.parse(value=date_str_without_tz)
- self.assertEqual(str(result), '2014-01-01 10:00:00+00:00')
- self.assertEqual(result.tzinfo.tzname(None), 'UTC')
+ self.assertEqual(str(result), "2014-01-01 10:00:00+00:00")
+ self.assertEqual(result.tzinfo.tzname(None), "UTC")
# preserve original tz
- date_str_with_tz = 'January 1st, 2014 10:00:00 +07:00'
+ date_str_with_tz = "January 1st, 2014 10:00:00 +07:00"
result = date_utils.parse(value=date_str_with_tz, preserve_original_tz=True)
- self.assertEqual(str(result), '2014-01-01 10:00:00+07:00')
+ self.assertEqual(str(result), "2014-01-01 10:00:00+07:00")
self.assertEqual(result.tzinfo.utcoffset(result), datetime.timedelta(hours=7))
# convert to utc
- date_str_with_tz = 'January 1st, 2014 10:00:00 +07:00'
+ date_str_with_tz = "January 1st, 2014 10:00:00 +07:00"
result = date_utils.parse(value=date_str_with_tz, preserve_original_tz=False)
- self.assertEqual(str(result), '2014-01-01 03:00:00+00:00')
+ self.assertEqual(str(result), "2014-01-01 03:00:00+00:00")
self.assertEqual(result.tzinfo.utcoffset(result), datetime.timedelta(hours=0))
- self.assertEqual(result.tzinfo.tzname(None), 'UTC')
+ self.assertEqual(result.tzinfo.tzname(None), "UTC")
diff --git a/st2common/tests/unit/test_db.py b/st2common/tests/unit/test_db.py
index 756c0a105e..da0157127e 100644
--- a/st2common/tests/unit/test_db.py
+++ b/st2common/tests/unit/test_db.py
@@ -18,6 +18,7 @@
# NOTE: We need to perform the monkeypatch before importing the ssl module, otherwise tests will fail.
# See https://github.com/StackStorm/st2/pull/4834 for details
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import ssl
@@ -52,47 +53,50 @@
__all__ = [
- 'DbConnectionTestCase',
- 'DbConnectionTestCase',
- 'ReactorModelTestCase',
- 'ActionModelTestCase',
- 'KeyValuePairModelTestCase'
+ "DbConnectionTestCase",
+ "ReactorModelTestCase",
+ "ActionModelTestCase",
+ "KeyValuePairModelTestCase",
]
SKIP_DELETE = False
-DUMMY_DESCRIPTION = 'Sample Description.'
+DUMMY_DESCRIPTION = "Sample Description."
class DbIndexNameTestCase(TestCase):
"""
Test which verifies that model index names are not longer than the specified limit.
"""
+
LIMIT = 65
def test_index_name_length(self):
- db_name = 'st2'
+ db_name = "st2"
for model in ALL_MODELS:
collection_name = model._get_collection_name()
- model_indexes = model._meta['index_specs']
+ model_indexes = model._meta["index_specs"]
for index_specs in model_indexes:
- index_name = index_specs.get('name', None)
+ index_name = index_specs.get("name", None)
if index_name:
# Custom index name defined by the developer
index_field_name = index_name
else:
# No explicit index name specified, one is auto-generated using
# .. schema
- index_fields = dict(index_specs['fields']).keys()
- index_field_name = '.'.join(index_fields)
+ index_fields = dict(index_specs["fields"]).keys()
+ index_field_name = ".".join(index_fields)
- index_name = '%s.%s.%s' % (db_name, collection_name, index_field_name)
+ index_name = "%s.%s.%s" % (db_name, collection_name, index_field_name)
if len(index_name) > self.LIMIT:
- self.fail('Index name "%s" for model "%s" is longer than %s characters. '
- 'Please manually define name for this index so it\'s shorter than '
- 'that' % (index_name, model.__name__, self.LIMIT))
+ self.fail(
+ 'Index name "%s" for model "%s" is longer than %s characters. '
+                    "Please manually define a name for this index so it's shorter than "
+ "that" % (index_name, model.__name__, self.LIMIT)
+ )
class DbConnectionTestCase(DbTestCase):
@@ -111,210 +115,293 @@ def test_check_connect(self):
"""
client = mongoengine.connection.get_connection()
- expected_str = "host=['%s:%s']" % (cfg.CONF.database.host, cfg.CONF.database.port)
- self.assertIn(expected_str, str(client), 'Not connected to desired host.')
+ expected_str = "host=['%s:%s']" % (
+ cfg.CONF.database.host,
+ cfg.CONF.database.port,
+ )
+ self.assertIn(expected_str, str(client), "Not connected to desired host.")
def test_get_ssl_kwargs(self):
# 1. No SSL kwargs provided
ssl_kwargs = _get_ssl_kwargs()
- self.assertEqual(ssl_kwargs, {'ssl': False})
+ self.assertEqual(ssl_kwargs, {"ssl": False})
# 2. ssl kwarg provided
ssl_kwargs = _get_ssl_kwargs(ssl=True)
- self.assertEqual(ssl_kwargs, {'ssl': True, 'ssl_match_hostname': True})
+ self.assertEqual(ssl_kwargs, {"ssl": True, "ssl_match_hostname": True})
# 3. authentication_mechanism kwarg provided
- ssl_kwargs = _get_ssl_kwargs(authentication_mechanism='MONGODB-X509')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ssl_match_hostname': True,
- 'authentication_mechanism': 'MONGODB-X509'
- })
+ ssl_kwargs = _get_ssl_kwargs(authentication_mechanism="MONGODB-X509")
+ self.assertEqual(
+ ssl_kwargs,
+ {
+ "ssl": True,
+ "ssl_match_hostname": True,
+ "authentication_mechanism": "MONGODB-X509",
+ },
+ )
# 4. ssl_keyfile provided
- ssl_kwargs = _get_ssl_kwargs(ssl_keyfile='/tmp/keyfile')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ssl_keyfile': '/tmp/keyfile',
- 'ssl_match_hostname': True
- })
+ ssl_kwargs = _get_ssl_kwargs(ssl_keyfile="/tmp/keyfile")
+ self.assertEqual(
+ ssl_kwargs,
+ {"ssl": True, "ssl_keyfile": "/tmp/keyfile", "ssl_match_hostname": True},
+ )
# 5. ssl_certfile provided
- ssl_kwargs = _get_ssl_kwargs(ssl_certfile='/tmp/certfile')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ssl_certfile': '/tmp/certfile',
- 'ssl_match_hostname': True
- })
+ ssl_kwargs = _get_ssl_kwargs(ssl_certfile="/tmp/certfile")
+ self.assertEqual(
+ ssl_kwargs,
+ {"ssl": True, "ssl_certfile": "/tmp/certfile", "ssl_match_hostname": True},
+ )
# 6. ssl_ca_certs provided
- ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ssl_ca_certs': '/tmp/ca_certs',
- 'ssl_match_hostname': True
- })
+ ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs")
+ self.assertEqual(
+ ssl_kwargs,
+ {"ssl": True, "ssl_ca_certs": "/tmp/ca_certs", "ssl_match_hostname": True},
+ )
# 7. ssl_ca_certs and ssl_cert_reqs combinations
- ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='none')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ssl_ca_certs': '/tmp/ca_certs',
- 'ssl_cert_reqs': ssl.CERT_NONE,
- 'ssl_match_hostname': True
- })
-
- ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='optional')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ssl_ca_certs': '/tmp/ca_certs',
- 'ssl_cert_reqs': ssl.CERT_OPTIONAL,
- 'ssl_match_hostname': True
- })
-
- ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='required')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ssl_ca_certs': '/tmp/ca_certs',
- 'ssl_cert_reqs': ssl.CERT_REQUIRED,
- 'ssl_match_hostname': True
- })
-
- @mock.patch('st2common.models.db.mongoengine')
+ ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="none")
+ self.assertEqual(
+ ssl_kwargs,
+ {
+ "ssl": True,
+ "ssl_ca_certs": "/tmp/ca_certs",
+ "ssl_cert_reqs": ssl.CERT_NONE,
+ "ssl_match_hostname": True,
+ },
+ )
+
+ ssl_kwargs = _get_ssl_kwargs(
+ ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="optional"
+ )
+ self.assertEqual(
+ ssl_kwargs,
+ {
+ "ssl": True,
+ "ssl_ca_certs": "/tmp/ca_certs",
+ "ssl_cert_reqs": ssl.CERT_OPTIONAL,
+ "ssl_match_hostname": True,
+ },
+ )
+
+ ssl_kwargs = _get_ssl_kwargs(
+ ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="required"
+ )
+ self.assertEqual(
+ ssl_kwargs,
+ {
+ "ssl": True,
+ "ssl_ca_certs": "/tmp/ca_certs",
+ "ssl_cert_reqs": ssl.CERT_REQUIRED,
+ "ssl_match_hostname": True,
+ },
+ )
+
+ @mock.patch("st2common.models.db.mongoengine")
def test_db_setup(self, mock_mongoengine):
- db_setup(db_name='name', db_host='host', db_port=12345, username='username',
- password='password', authentication_mechanism='MONGODB-X509')
+ db_setup(
+ db_name="name",
+ db_host="host",
+ db_port=12345,
+ username="username",
+ password="password",
+ authentication_mechanism="MONGODB-X509",
+ )
call_args = mock_mongoengine.connection.connect.call_args_list[0][0]
call_kwargs = mock_mongoengine.connection.connect.call_args_list[0][1]
- self.assertEqual(call_args, ('name',))
- self.assertEqual(call_kwargs, {
- 'host': 'host',
- 'port': 12345,
- 'username': 'username',
- 'password': 'password',
- 'tz_aware': True,
- 'authentication_mechanism': 'MONGODB-X509',
- 'ssl': True,
- 'ssl_match_hostname': True,
- 'connectTimeoutMS': 3000,
- 'serverSelectionTimeoutMS': 3000
- })
-
- @mock.patch('st2common.models.db.mongoengine')
- @mock.patch('st2common.models.db.LOG')
+ self.assertEqual(call_args, ("name",))
+ self.assertEqual(
+ call_kwargs,
+ {
+ "host": "host",
+ "port": 12345,
+ "username": "username",
+ "password": "password",
+ "tz_aware": True,
+ "authentication_mechanism": "MONGODB-X509",
+ "ssl": True,
+ "ssl_match_hostname": True,
+ "connectTimeoutMS": 3000,
+ "serverSelectionTimeoutMS": 3000,
+ },
+ )
+
+ @mock.patch("st2common.models.db.mongoengine")
+ @mock.patch("st2common.models.db.LOG")
def test_db_setup_connecting_info_logging(self, mock_log, mock_mongoengine):
# Verify that password is not included in the log message
- db_name = 'st2'
- db_port = '27017'
- username = 'user_st2'
- password = 'pass_st2'
+ db_name = "st2"
+ db_port = "27017"
+ username = "user_st2"
+ password = "pass_st2"
# 1. Password provided as separate argument
- db_host = 'localhost'
- username = 'user_st2'
- password = 'pass_st2'
- db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username,
- password=password)
-
- expected_message = 'Connecting to database "st2" @ "localhost:27017" as user "user_st2".'
+ db_host = "localhost"
+ username = "user_st2"
+ password = "pass_st2"
+ db_setup(
+ db_name=db_name,
+ db_host=db_host,
+ db_port=db_port,
+ username=username,
+ password=password,
+ )
+
+ expected_message = (
+ 'Connecting to database "st2" @ "localhost:27017" as user "user_st2".'
+ )
actual_message = mock_log.info.call_args_list[0][0][0]
self.assertEqual(expected_message, actual_message)
# Check for a helpful log message when the connection is successful
- expected_log_message = ('Successfully connected to database "st2" @ "localhost:27017" as '
- 'user "user_st2".')
+ expected_log_message = (
+ 'Successfully connected to database "st2" @ "localhost:27017" as '
+ 'user "user_st2".'
+ )
actual_log_message = mock_log.info.call_args_list[1][0][0]
self.assertEqual(expected_log_message, actual_log_message)
# 2. Password provided as part of uri string (single host)
- db_host = 'mongodb://user_st22:pass_st22@127.0.0.2:5555'
+ db_host = "mongodb://user_st22:pass_st22@127.0.0.2:5555"
username = None
password = None
- db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username,
- password=password)
-
- expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st22".'
+ db_setup(
+ db_name=db_name,
+ db_host=db_host,
+ db_port=db_port,
+ username=username,
+ password=password,
+ )
+
+ expected_message = (
+ 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st22".'
+ )
actual_message = mock_log.info.call_args_list[2][0][0]
self.assertEqual(expected_message, actual_message)
- expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as '
- 'user "user_st22".')
+ expected_log_message = (
+ 'Successfully connected to database "st2" @ "127.0.0.2:5555" as '
+ 'user "user_st22".'
+ )
actual_log_message = mock_log.info.call_args_list[3][0][0]
self.assertEqual(expected_log_message, actual_log_message)
# 3. Password provided as part of uri string (single host) - the username
# provided as an argument takes precedence
- db_host = 'mongodb://user_st210:pass_st23@127.0.0.2:5555'
- username = 'user_st23'
+ db_host = "mongodb://user_st210:pass_st23@127.0.0.2:5555"
+ username = "user_st23"
password = None
- db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username,
- password=password)
-
- expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st23".'
+ db_setup(
+ db_name=db_name,
+ db_host=db_host,
+ db_port=db_port,
+ username=username,
+ password=password,
+ )
+
+ expected_message = (
+ 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st23".'
+ )
actual_message = mock_log.info.call_args_list[4][0][0]
self.assertEqual(expected_message, actual_message)
- expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as '
- 'user "user_st23".')
+ expected_log_message = (
+ 'Successfully connected to database "st2" @ "127.0.0.2:5555" as '
+ 'user "user_st23".'
+ )
actual_log_message = mock_log.info.call_args_list[5][0][0]
self.assertEqual(expected_log_message, actual_log_message)
# 4. Just host provided in the url string
- db_host = 'mongodb://127.0.0.2:5555'
- username = 'user_st24'
- password = 'foobar'
- db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username,
- password=password)
-
- expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st24".'
+ db_host = "mongodb://127.0.0.2:5555"
+ username = "user_st24"
+ password = "foobar"
+ db_setup(
+ db_name=db_name,
+ db_host=db_host,
+ db_port=db_port,
+ username=username,
+ password=password,
+ )
+
+ expected_message = (
+ 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st24".'
+ )
actual_message = mock_log.info.call_args_list[6][0][0]
self.assertEqual(expected_message, actual_message)
- expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as '
- 'user "user_st24".')
+ expected_log_message = (
+ 'Successfully connected to database "st2" @ "127.0.0.2:5555" as '
+ 'user "user_st24".'
+ )
actual_log_message = mock_log.info.call_args_list[7][0][0]
self.assertEqual(expected_log_message, actual_log_message)
# 5. Multiple hosts specified as part of connection uri
- db_host = 'mongodb://user6:pass6@host1,host2,host3'
+ db_host = "mongodb://user6:pass6@host1,host2,host3"
username = None
- password = 'foobar'
- db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username,
- password=password)
-
- expected_message = ('Connecting to database "st2" @ "host1:27017,host2:27017,host3:27017 '
- '(replica set)" as user "user6".')
+ password = "foobar"
+ db_setup(
+ db_name=db_name,
+ db_host=db_host,
+ db_port=db_port,
+ username=username,
+ password=password,
+ )
+
+ expected_message = (
+ 'Connecting to database "st2" @ "host1:27017,host2:27017,host3:27017 '
+ '(replica set)" as user "user6".'
+ )
actual_message = mock_log.info.call_args_list[8][0][0]
self.assertEqual(expected_message, actual_message)
- expected_log_message = ('Successfully connected to database "st2" @ '
- '"host1:27017,host2:27017,host3:27017 '
- '(replica set)" as user "user6".')
+ expected_log_message = (
+ 'Successfully connected to database "st2" @ '
+ '"host1:27017,host2:27017,host3:27017 '
+ '(replica set)" as user "user6".'
+ )
actual_log_message = mock_log.info.call_args_list[9][0][0]
self.assertEqual(expected_log_message, actual_log_message)
# 6. Check for error message when failing to establish a connection
mock_connect = mock.Mock()
- mock_connect.admin.command = mock.Mock(side_effect=ConnectionFailure('Failed to connect'))
+ mock_connect.admin.command = mock.Mock(
+ side_effect=ConnectionFailure("Failed to connect")
+ )
mock_mongoengine.connection.connect.return_value = mock_connect
- db_host = 'mongodb://localhost:9797'
- username = 'user_st2'
- password = 'pass_st2'
-
- expected_msg = 'Failed to connect'
- self.assertRaisesRegexp(ConnectionFailure, expected_msg, db_setup,
- db_name=db_name, db_host=db_host, db_port=db_port,
- username=username, password=password)
-
- expected_message = 'Connecting to database "st2" @ "localhost:9797" as user "user_st2".'
+ db_host = "mongodb://localhost:9797"
+ username = "user_st2"
+ password = "pass_st2"
+
+ expected_msg = "Failed to connect"
+ self.assertRaisesRegexp(
+ ConnectionFailure,
+ expected_msg,
+ db_setup,
+ db_name=db_name,
+ db_host=db_host,
+ db_port=db_port,
+ username=username,
+ password=password,
+ )
+
+ expected_message = (
+ 'Connecting to database "st2" @ "localhost:9797" as user "user_st2".'
+ )
actual_message = mock_log.info.call_args_list[10][0][0]
self.assertEqual(expected_message, actual_message)
- expected_message = ('Failed to connect to database "st2" @ "localhost:9797" as user '
- '"user_st2": Failed to connect')
+ expected_message = (
+ 'Failed to connect to database "st2" @ "localhost:9797" as user '
+ '"user_st2": Failed to connect'
+ )
actual_message = mock_log.error.call_args_list[0][0][0]
self.assertEqual(expected_message, actual_message)
@@ -323,29 +410,43 @@ def test_db_connect_server_selection_timeout_ssl_on_non_ssl_listener(self):
# and propagating the error
disconnect()
- db_name = 'st2'
- db_host = 'localhost'
+ db_name = "st2"
+ db_host = "localhost"
db_port = 27017
- cfg.CONF.set_override(name='connection_timeout', group='database', override=1000)
+ cfg.CONF.set_override(
+ name="connection_timeout", group="database", override=1000
+ )
start = time.time()
- self.assertRaises(ServerSelectionTimeoutError, db_setup, db_name=db_name, db_host=db_host,
- db_port=db_port, ssl=True)
+ self.assertRaises(
+ ServerSelectionTimeoutError,
+ db_setup,
+ db_name=db_name,
+ db_host=db_host,
+ db_port=db_port,
+ ssl=True,
+ )
end = time.time()
- diff = (end - start)
+ diff = end - start
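+        # the connection_timeout override above is 1000 ms, so the failed server
+        # selection should take at least ~1 second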
self.assertTrue(diff >= 1)
disconnect()
- cfg.CONF.set_override(name='connection_timeout', group='database', override=400)
+ cfg.CONF.set_override(name="connection_timeout", group="database", override=400)
start = time.time()
- self.assertRaises(ServerSelectionTimeoutError, db_setup, db_name=db_name, db_host=db_host,
- db_port=db_port, ssl=True)
+ self.assertRaises(
+ ServerSelectionTimeoutError,
+ db_setup,
+ db_name=db_name,
+ db_host=db_host,
+ db_port=db_port,
+ ssl=True,
+ )
end = time.time()
- diff = (end - start)
+ diff = end - start
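+        # with the 400 ms override, the failure should take at least ~0.4 seconds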
self.assertTrue(diff >= 0.4)
@@ -364,60 +465,63 @@ def test_cleanup(self):
self.assertNotIn(cfg.CONF.database.db_name, connection.database_names())
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class ReactorModelTestCase(DbTestCase):
-
def test_triggertype_crud(self):
saved = ReactorModelTestCase._create_save_triggertype()
retrieved = TriggerType.get_by_id(saved.id)
- self.assertEqual(saved.name, retrieved.name,
- 'Same triggertype was not returned.')
+ self.assertEqual(
+ saved.name, retrieved.name, "Same triggertype was not returned."
+ )
# test update
- self.assertEqual(retrieved.description, '')
+ self.assertEqual(retrieved.description, "")
retrieved.description = DUMMY_DESCRIPTION
saved = TriggerType.add_or_update(retrieved)
retrieved = TriggerType.get_by_id(saved.id)
- self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to trigger failed.')
+ self.assertEqual(
+ retrieved.description, DUMMY_DESCRIPTION, "Update to trigger failed."
+ )
# cleanup
ReactorModelTestCase._delete([retrieved])
try:
retrieved = TriggerType.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_trigger_crud(self):
triggertype = ReactorModelTestCase._create_save_triggertype()
saved = ReactorModelTestCase._create_save_trigger(triggertype)
retrieved = Trigger.get_by_id(saved.id)
- self.assertEqual(saved.name, retrieved.name,
- 'Same trigger was not returned.')
+ self.assertEqual(saved.name, retrieved.name, "Same trigger was not returned.")
# test update
- self.assertEqual(retrieved.description, '')
+ self.assertEqual(retrieved.description, "")
retrieved.description = DUMMY_DESCRIPTION
saved = Trigger.add_or_update(retrieved)
retrieved = Trigger.get_by_id(saved.id)
- self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to trigger failed.')
+ self.assertEqual(
+ retrieved.description, DUMMY_DESCRIPTION, "Update to trigger failed."
+ )
# cleanup
ReactorModelTestCase._delete([retrieved, triggertype])
try:
retrieved = Trigger.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_triggerinstance_crud(self):
triggertype = ReactorModelTestCase._create_save_triggertype()
trigger = ReactorModelTestCase._create_save_trigger(triggertype)
saved = ReactorModelTestCase._create_save_triggerinstance(trigger)
retrieved = TriggerInstance.get_by_id(saved.id)
- self.assertIsNotNone(retrieved, 'No triggerinstance created.')
+ self.assertIsNotNone(retrieved, "No triggerinstance created.")
ReactorModelTestCase._delete([retrieved, trigger, triggertype])
try:
retrieved = TriggerInstance.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_rule_crud(self):
triggertype = ReactorModelTestCase._create_save_triggertype()
@@ -426,20 +530,22 @@ def test_rule_crud(self):
action = ActionModelTestCase._create_save_action(runnertype)
saved = ReactorModelTestCase._create_save_rule(trigger, action)
retrieved = Rule.get_by_id(saved.id)
- self.assertEqual(saved.name, retrieved.name, 'Same rule was not returned.')
+ self.assertEqual(saved.name, retrieved.name, "Same rule was not returned.")
# test update
self.assertEqual(retrieved.enabled, True)
retrieved.enabled = False
saved = Rule.add_or_update(retrieved)
retrieved = Rule.get_by_id(saved.id)
- self.assertEqual(retrieved.enabled, False, 'Update to rule failed.')
+ self.assertEqual(retrieved.enabled, False, "Update to rule failed.")
# cleanup
- ReactorModelTestCase._delete([retrieved, trigger, action, runnertype, triggertype])
+ ReactorModelTestCase._delete(
+ [retrieved, trigger, action, runnertype, triggertype]
+ )
try:
retrieved = Rule.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_rule_lookup(self):
triggertype = ReactorModelTestCase._create_save_triggertype()
@@ -447,10 +553,12 @@ def test_rule_lookup(self):
runnertype = ActionModelTestCase._create_save_runnertype()
action = ActionModelTestCase._create_save_action(runnertype)
saved = ReactorModelTestCase._create_save_rule(trigger, action)
- retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger))
- self.assertEqual(1, len(retrievedrules), 'No rules found.')
+ retrievedrules = Rule.query(
+ trigger=reference.get_str_resource_ref_from_model(trigger)
+ )
+ self.assertEqual(1, len(retrievedrules), "No rules found.")
for retrievedrule in retrievedrules:
- self.assertEqual(saved.id, retrievedrule.id, 'Incorrect rule returned.')
+ self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.")
ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype])
def test_rule_lookup_enabled(self):
@@ -459,12 +567,12 @@ def test_rule_lookup_enabled(self):
runnertype = ActionModelTestCase._create_save_runnertype()
action = ActionModelTestCase._create_save_action(runnertype)
saved = ReactorModelTestCase._create_save_rule(trigger, action)
- retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger),
- enabled=True)
- self.assertEqual(1, len(retrievedrules), 'Error looking up enabled rules.')
+ retrievedrules = Rule.query(
+ trigger=reference.get_str_resource_ref_from_model(trigger), enabled=True
+ )
+ self.assertEqual(1, len(retrievedrules), "Error looking up enabled rules.")
for retrievedrule in retrievedrules:
- self.assertEqual(saved.id, retrievedrule.id,
- 'Incorrect rule returned.')
+ self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.")
ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype])
def test_rule_lookup_disabled(self):
@@ -473,49 +581,64 @@ def test_rule_lookup_disabled(self):
runnertype = ActionModelTestCase._create_save_runnertype()
action = ActionModelTestCase._create_save_action(runnertype)
saved = ReactorModelTestCase._create_save_rule(trigger, action, False)
- retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger),
- enabled=False)
- self.assertEqual(1, len(retrievedrules), 'Error looking up enabled rules.')
+ retrievedrules = Rule.query(
+ trigger=reference.get_str_resource_ref_from_model(trigger), enabled=False
+ )
+        self.assertEqual(1, len(retrievedrules), "Error looking up disabled rules.")
for retrievedrule in retrievedrules:
- self.assertEqual(saved.id, retrievedrule.id, 'Incorrect rule returned.')
+ self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.")
ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype])
def test_trigger_lookup(self):
triggertype = ReactorModelTestCase._create_save_triggertype()
saved = ReactorModelTestCase._create_save_trigger(triggertype)
retrievedtriggers = Trigger.query(name=saved.name)
- self.assertEqual(1, len(retrievedtriggers), 'No triggers found.')
+ self.assertEqual(1, len(retrievedtriggers), "No triggers found.")
for retrievedtrigger in retrievedtriggers:
- self.assertEqual(saved.id, retrievedtrigger.id,
- 'Incorrect trigger returned.')
+ self.assertEqual(
+ saved.id, retrievedtrigger.id, "Incorrect trigger returned."
+ )
ReactorModelTestCase._delete([saved, triggertype])
@staticmethod
def _create_save_triggertype():
- created = TriggerTypeDB(pack='dummy_pack_1', name='triggertype-1', description='',
- payload_schema={}, parameters_schema={})
+ created = TriggerTypeDB(
+ pack="dummy_pack_1",
+ name="triggertype-1",
+ description="",
+ payload_schema={},
+ parameters_schema={},
+ )
return Trigger.add_or_update(created)
@staticmethod
def _create_save_trigger(triggertype):
- created = TriggerDB(pack='dummy_pack_1', name='trigger-1', description='',
- type=triggertype.get_reference().ref, parameters={})
+ created = TriggerDB(
+ pack="dummy_pack_1",
+ name="trigger-1",
+ description="",
+ type=triggertype.get_reference().ref,
+ parameters={},
+ )
return Trigger.add_or_update(created)
@staticmethod
def _create_save_triggerinstance(trigger):
- created = TriggerInstanceDB(trigger=trigger.get_reference().ref, payload={},
- occurrence_time=date_utils.get_datetime_utc_now(),
- status=TRIGGER_INSTANCE_PROCESSED)
+ created = TriggerInstanceDB(
+ trigger=trigger.get_reference().ref,
+ payload={},
+ occurrence_time=date_utils.get_datetime_utc_now(),
+ status=TRIGGER_INSTANCE_PROCESSED,
+ )
return TriggerInstance.add_or_update(created)
@staticmethod
def _create_save_rule(trigger, action=None, enabled=True):
- name = 'rule-1'
- pack = 'default'
+ name = "rule-1"
+ pack = "default"
ref = ResourceReference.to_string_reference(name=name, pack=pack)
created = RuleDB(name=name, pack=pack, ref=ref)
- created.description = ''
+ created.description = ""
created.enabled = enabled
created.trigger = reference.get_str_resource_ref_from_model(trigger)
created.criteria = {}
@@ -547,44 +670,21 @@ def _delete(model_objects):
"description": "awesomeness",
"type": "object",
"properties": {
- "r1": {
- "type": "object",
- "properties": {
- "r1a": {
- "type": "string"
- }
- }
- },
- "r2": {
- "type": "string",
- "required": True
- },
- "p1": {
- "type": "string",
- "required": True
- },
- "p2": {
- "type": "number",
- "default": 2868
- },
- "p3": {
- "type": "boolean",
- "default": False
- },
- "p4": {
- "type": "string",
- "secret": True
- }
+ "r1": {"type": "object", "properties": {"r1a": {"type": "string"}}},
+ "r2": {"type": "string", "required": True},
+ "p1": {"type": "string", "required": True},
+ "p2": {"type": "number", "default": 2868},
+ "p3": {"type": "boolean", "default": False},
+ "p4": {"type": "string", "secret": True},
},
- "additionalProperties": False
+ "additionalProperties": False,
}
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class ActionModelTestCase(DbTestCase):
-
def tearDown(self):
- runnertype = RunnerType.get_by_name('python')
+ runnertype = RunnerType.get_by_name("python")
self._delete([runnertype])
super(ActionModelTestCase, self).tearDown()
@@ -592,15 +692,16 @@ def test_action_crud(self):
runnertype = self._create_save_runnertype(metadata=False)
saved = self._create_save_action(runnertype, metadata=False)
retrieved = Action.get_by_id(saved.id)
- self.assertEqual(saved.name, retrieved.name,
- 'Same Action was not returned.')
+ self.assertEqual(saved.name, retrieved.name, "Same Action was not returned.")
# test update
- self.assertEqual(retrieved.description, 'awesomeness')
+ self.assertEqual(retrieved.description, "awesomeness")
retrieved.description = DUMMY_DESCRIPTION
saved = Action.add_or_update(retrieved)
retrieved = Action.get_by_id(saved.id)
- self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to action failed.')
+ self.assertEqual(
+ retrieved.description, DUMMY_DESCRIPTION, "Update to action failed."
+ )
# cleanup
self._delete([retrieved])
@@ -608,14 +709,14 @@ def test_action_crud(self):
retrieved = Action.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_action_with_notify_crud(self):
runnertype = self._create_save_runnertype(metadata=False)
saved = self._create_save_action(runnertype, metadata=False)
# Update action with notification settings
- on_complete = NotificationSubSchema(message='Action complete.')
+ on_complete = NotificationSubSchema(message="Action complete.")
saved.notify = NotificationSchema(on_complete=on_complete)
saved = Action.add_or_update(saved)
@@ -635,7 +736,7 @@ def test_action_with_notify_crud(self):
retrieved = Action.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_parameter_schema(self):
runnertype = self._create_save_runnertype(metadata=True)
@@ -650,13 +751,30 @@ def test_parameter_schema(self):
# use schema to validate parameters
jsonschema.validate({"r2": "abc", "p1": "def"}, schema, validator)
- jsonschema.validate({"r2": "abc", "p1": "def", "r1": {"r1a": "ghi"}}, schema, validator)
- self.assertRaises(jsonschema.ValidationError, jsonschema.validate,
- '{"r2": "abc", "p1": "def"}', schema, validator)
- self.assertRaises(jsonschema.ValidationError, jsonschema.validate,
- {"r2": "abc"}, schema, validator)
- self.assertRaises(jsonschema.ValidationError, jsonschema.validate,
- {"r2": "abc", "p1": "def", "r1": 123}, schema, validator)
+ jsonschema.validate(
+ {"r2": "abc", "p1": "def", "r1": {"r1a": "ghi"}}, schema, validator
+ )
+ self.assertRaises(
+ jsonschema.ValidationError,
+ jsonschema.validate,
+ '{"r2": "abc", "p1": "def"}',
+ schema,
+ validator,
+ )
+ self.assertRaises(
+ jsonschema.ValidationError,
+ jsonschema.validate,
+ {"r2": "abc"},
+ schema,
+ validator,
+ )
+ self.assertRaises(
+ jsonschema.ValidationError,
+ jsonschema.validate,
+ {"r2": "abc", "p1": "def", "r1": 123},
+ schema,
+ validator,
+ )
# cleanup
self._delete([retrieved])
@@ -664,7 +782,7 @@ def test_parameter_schema(self):
retrieved = Action.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_parameters_schema_runner_and_action_parameters_are_correctly_merged(self):
# Test that the runner and action parameters are correctly deep merged when building
@@ -673,54 +791,55 @@ def test_parameters_schema_runner_and_action_parameters_are_correctly_merged(sel
self._create_save_runnertype(metadata=True)
action_db = mock.Mock()
- action_db.runner_type = {'name': 'python'}
- action_db.parameters = {'r1': {'immutable': True}}
+ action_db.runner_type = {"name": "python"}
+ action_db.parameters = {"r1": {"immutable": True}}
schema = util_schema.get_schema_for_action_parameters(action_db=action_db)
expected = {
- u'type': u'object',
- u'properties': {
- u'r1a': {
- u'type': u'string'
- }
- },
- 'immutable': True
+ "type": "object",
+ "properties": {"r1a": {"type": "string"}},
+ "immutable": True,
}
- self.assertEqual(schema['properties']['r1'], expected)
+ self.assertEqual(schema["properties"]["r1"], expected)
@staticmethod
def _create_save_runnertype(metadata=False):
- created = RunnerTypeDB(name='python')
- created.description = ''
+ created = RunnerTypeDB(name="python")
+ created.description = ""
created.enabled = True
if not metadata:
- created.runner_parameters = {'r1': None, 'r2': None}
+ created.runner_parameters = {"r1": None, "r2": None}
else:
created.runner_parameters = {
- 'r1': {'type': 'object', 'properties': {'r1a': {'type': 'string'}}},
- 'r2': {'type': 'string', 'required': True}
+ "r1": {"type": "object", "properties": {"r1a": {"type": "string"}}},
+ "r2": {"type": "string", "required": True},
}
- created.runner_module = 'nomodule'
+ created.runner_module = "nomodule"
return RunnerType.add_or_update(created)
@staticmethod
def _create_save_action(runnertype, metadata=False):
- name = 'action-1'
- pack = 'wolfpack'
+ name = "action-1"
+ pack = "wolfpack"
ref = ResourceReference(pack=pack, name=name).ref
- created = ActionDB(name=name, description='awesomeness', enabled=True,
- entry_point='/tmp/action.py', pack=pack,
- ref=ref,
- runner_type={'name': runnertype.name})
+ created = ActionDB(
+ name=name,
+ description="awesomeness",
+ enabled=True,
+ entry_point="/tmp/action.py",
+ pack=pack,
+ ref=ref,
+ runner_type={"name": runnertype.name},
+ )
if not metadata:
- created.parameters = {'p1': None, 'p2': None, 'p3': None, 'p4': None}
+ created.parameters = {"p1": None, "p2": None, "p3": None, "p4": None}
else:
created.parameters = {
- 'p1': {'type': 'string', 'required': True},
- 'p2': {'type': 'number', 'default': 2868},
- 'p3': {'type': 'boolean', 'default': False},
- 'p4': {'type': 'string', 'secret': True}
+ "p1": {"type": "string", "required": True},
+ "p2": {"type": "number", "default": 2868},
+ "p3": {"type": "boolean", "default": False},
+ "p4": {"type": "string", "secret": True},
}
return Action.add_or_update(created)
@@ -738,20 +857,19 @@ def _delete(model_objects):
class KeyValuePairModelTestCase(DbTestCase):
-
def test_kvp_crud(self):
saved = KeyValuePairModelTestCase._create_save_kvp()
retrieved = KeyValuePair.get_by_name(saved.name)
- self.assertEqual(saved.id, retrieved.id,
- 'Same KeyValuePair was not returned.')
+ self.assertEqual(saved.id, retrieved.id, "Same KeyValuePair was not returned.")
# test update
- self.assertEqual(retrieved.value, '0123456789ABCDEF')
- retrieved.value = 'ABCDEF0123456789'
+ self.assertEqual(retrieved.value, "0123456789ABCDEF")
+ retrieved.value = "ABCDEF0123456789"
saved = KeyValuePair.add_or_update(retrieved)
retrieved = KeyValuePair.get_by_name(saved.name)
- self.assertEqual(retrieved.value, 'ABCDEF0123456789',
- 'Update of key value failed')
+ self.assertEqual(
+ retrieved.value, "ABCDEF0123456789", "Update of key value failed"
+ )
# cleanup
KeyValuePairModelTestCase._delete([retrieved])
@@ -759,11 +877,11 @@ def test_kvp_crud(self):
retrieved = KeyValuePair.get_by_name(saved.name)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
@staticmethod
def _create_save_kvp():
- created = KeyValuePairDB(name='token', value='0123456789ABCDEF')
+ created = KeyValuePairDB(name="token", value="0123456789ABCDEF")
return KeyValuePair.add_or_update(created)
@staticmethod
diff --git a/st2common/tests/unit/test_db_action_state.py b/st2common/tests/unit/test_db_action_state.py
index 3251898e29..47b9d170bd 100644
--- a/st2common/tests/unit/test_db_action_state.py
+++ b/st2common/tests/unit/test_db_action_state.py
@@ -34,13 +34,13 @@ def test_state_crud(self):
retrieved = ActionExecutionState.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
@staticmethod
def _create_save_actionstate():
created = ActionExecutionStateDB()
- created.query_context = {'id': 'some_external_service_id'}
- created.query_module = 'dummy.modules.query1'
+ created.query_context = {"id": "some_external_service_id"}
+ created.query_module = "dummy.modules.query1"
created.execution_id = bson.ObjectId()
return ActionExecutionState.add_or_update(created)
diff --git a/st2common/tests/unit/test_db_auth.py b/st2common/tests/unit/test_db_auth.py
index 9cf35bc737..b159580505 100644
--- a/st2common/tests/unit/test_db_auth.py
+++ b/st2common/tests/unit/test_db_auth.py
@@ -26,44 +26,35 @@
from tests.unit.base import BaseDBModelCRUDTestCase
-__all__ = [
- 'UserDBModelCRUDTestCase'
-]
+__all__ = ["UserDBModelCRUDTestCase"]
class UserDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase):
model_class = UserDB
persistance_class = User
model_class_kwargs = {
- 'name': 'pony',
- 'is_service': False,
- 'nicknames': {
- 'pony1': 'ponyA'
- }
+ "name": "pony",
+ "is_service": False,
+ "nicknames": {"pony1": "ponyA"},
}
- update_attribute_name = 'name'
+ update_attribute_name = "name"
class TokenDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase):
model_class = TokenDB
persistance_class = Token
model_class_kwargs = {
- 'user': 'pony',
- 'token': 'token-token-token-token',
- 'expiry': get_datetime_utc_now(),
- 'metadata': {
- 'service': 'action-runner'
- }
+ "user": "pony",
+ "token": "token-token-token-token",
+ "expiry": get_datetime_utc_now(),
+ "metadata": {"service": "action-runner"},
}
- skip_check_attribute_names = ['expiry']
- update_attribute_name = 'user'
+ skip_check_attribute_names = ["expiry"]
+ update_attribute_name = "user"
class ApiKeyDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase):
model_class = ApiKeyDB
persistance_class = ApiKey
- model_class_kwargs = {
- 'user': 'pony',
- 'key_hash': 'token-token-token-token'
- }
- update_attribute_name = 'user'
+ model_class_kwargs = {"user": "pony", "key_hash": "token-token-token-token"}
+ update_attribute_name = "user"
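
For readers who want to reproduce these mechanical rewrites locally, here is a minimal sketch using black's Python API; black is assumed to be installed, and the project's actual pyproject.toml options may differ from the library defaults used here:

    # Illustrative sketch only: shows the quote normalization black applies
    # throughout this patch.
    import black

    src = "model_class_kwargs = {'user': 'pony', 'key_hash': 'token-token-token-token'}\n"
    print(black.format_str(src, mode=black.Mode()))
    # model_class_kwargs = {"user": "pony", "key_hash": "token-token-token-token"}
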
diff --git a/st2common/tests/unit/test_db_base.py b/st2common/tests/unit/test_db_base.py
index 0c77c336bf..6849643243 100644
--- a/st2common/tests/unit/test_db_base.py
+++ b/st2common/tests/unit/test_db_base.py
@@ -27,11 +27,11 @@ class FakeRuleSpecDB(mongoengine.EmbeddedDocument):
def __str__(self):
result = []
- result.append('ActionExecutionSpecDB@')
- result.append('test')
+ result.append("ActionExecutionSpecDB@")
+ result.append("test")
result.append('(ref="%s", ' % self.ref)
result.append('parameters="%s")' % self.parameters)
- return ''.join(result)
+ return "".join(result)
class FakeModel(stormbase.StormBaseDB):
@@ -52,30 +52,43 @@ class FakeRuleModel(stormbase.StormBaseDB):
class TestBaseModel(DbTestCase):
-
def test_print(self):
- instance = FakeModel(name='seesaw', boolean_field=True,
- datetime_field=date_utils.get_datetime_utc_now(),
- description=u'fun!', dict_field={'a': 1},
- integer_field=68, list_field=['abc'])
-
- expected = ('FakeModel(boolean_field=True, datetime_field="%s", description="fun!", '
- 'dict_field={\'a\': 1}, id=None, integer_field=68, list_field=[\'abc\'], '
- 'name="seesaw")' % str(instance.datetime_field))
+ instance = FakeModel(
+ name="seesaw",
+ boolean_field=True,
+ datetime_field=date_utils.get_datetime_utc_now(),
+ description="fun!",
+ dict_field={"a": 1},
+ integer_field=68,
+ list_field=["abc"],
+ )
+
+ expected = (
+ 'FakeModel(boolean_field=True, datetime_field="%s", description="fun!", '
+ "dict_field={'a': 1}, id=None, integer_field=68, list_field=['abc'], "
+ 'name="seesaw")' % str(instance.datetime_field)
+ )
self.assertEqual(str(instance), expected)
def test_rule_print(self):
- instance = FakeRuleModel(name='seesaw', boolean_field=True,
- datetime_field=date_utils.get_datetime_utc_now(),
- description=u'fun!', dict_field={'a': 1},
- integer_field=68, list_field=['abc'],
- embedded_doc_field={'ref': '1234', 'parameters': {'b': 2}})
-
- expected = ('FakeRuleModel(boolean_field=True, datetime_field="%s", description="fun!", '
- 'dict_field={\'a\': 1}, embedded_doc_field=ActionExecutionSpecDB@test('
- 'ref="1234", parameters="{\'b\': 2}"), id=None, integer_field=68, '
- 'list_field=[\'abc\'], '
- 'name="seesaw")' % str(instance.datetime_field))
+ instance = FakeRuleModel(
+ name="seesaw",
+ boolean_field=True,
+ datetime_field=date_utils.get_datetime_utc_now(),
+ description="fun!",
+ dict_field={"a": 1},
+ integer_field=68,
+ list_field=["abc"],
+ embedded_doc_field={"ref": "1234", "parameters": {"b": 2}},
+ )
+
+ expected = (
+ 'FakeRuleModel(boolean_field=True, datetime_field="%s", description="fun!", '
+ "dict_field={'a': 1}, embedded_doc_field=ActionExecutionSpecDB@test("
+ 'ref="1234", parameters="{\'b\': 2}"), id=None, integer_field=68, '
+ "list_field=['abc'], "
+ 'name="seesaw")' % str(instance.datetime_field)
+ )
self.assertEqual(str(instance), expected)
diff --git a/st2common/tests/unit/test_db_execution.py b/st2common/tests/unit/test_db_execution.py
index 62478ee13e..e94ccb3d94 100644
--- a/st2common/tests/unit/test_db_execution.py
+++ b/st2common/tests/unit/test_db_execution.py
@@ -27,79 +27,71 @@
INQUIRY_RESULT = {
- 'users': [],
- 'roles': [],
- 'route': 'developers',
- 'ttl': 1440,
- 'response': {
- 'secondfactor': 'supersecretvalue'
- },
- 'schema': {
- 'type': 'object',
- 'properties': {
- 'secondfactor': {
- 'secret': True,
- 'required': True,
- 'type': 'string',
- 'description': 'Please enter second factor for authenticating to "foo" service'
+ "users": [],
+ "roles": [],
+ "route": "developers",
+ "ttl": 1440,
+ "response": {"secondfactor": "supersecretvalue"},
+ "schema": {
+ "type": "object",
+ "properties": {
+ "secondfactor": {
+ "secret": True,
+ "required": True,
+ "type": "string",
+ "description": 'Please enter second factor for authenticating to "foo" service',
}
- }
- }
+ },
+ },
}
INQUIRY_LIVEACTION = {
- 'parameters': {
- 'route': 'developers',
- 'schema': {
- 'type': 'object',
- 'properties': {
- 'secondfactor': {
- 'secret': True,
- 'required': True,
- 'type': u'string',
- 'description': 'Please enter second factor for authenticating to "foo" service'
+ "parameters": {
+ "route": "developers",
+ "schema": {
+ "type": "object",
+ "properties": {
+ "secondfactor": {
+ "secret": True,
+ "required": True,
+ "type": "string",
+ "description": 'Please enter second factor for authenticating to "foo" service',
}
- }
- }
+ },
+ },
},
- 'action': 'core.ask'
+ "action": "core.ask",
}
RESPOND_LIVEACTION = {
- 'parameters': {
- 'response': {
- 'secondfactor': 'omgsupersecret',
+ "parameters": {
+ "response": {
+ "secondfactor": "omgsupersecret",
}
},
- 'action': 'st2.inquiry.respond'
+ "action": "st2.inquiry.respond",
}
ACTIONEXECUTIONS = {
"execution_1": {
- 'action': {'uid': 'action:core:ask'},
- 'status': 'succeeded',
- 'runner': {'name': 'inquirer'},
- 'liveaction': INQUIRY_LIVEACTION,
- 'result': INQUIRY_RESULT
+ "action": {"uid": "action:core:ask"},
+ "status": "succeeded",
+ "runner": {"name": "inquirer"},
+ "liveaction": INQUIRY_LIVEACTION,
+ "result": INQUIRY_RESULT,
},
"execution_2": {
- 'action': {'uid': 'action:st2:inquiry.respond'},
- 'status': 'succeeded',
- 'runner': {'name': 'python-script'},
- 'liveaction': RESPOND_LIVEACTION,
- 'result': {
- 'exit_code': 0,
- 'result': None,
- 'stderr': '',
- 'stdout': ''
- }
- }
+ "action": {"uid": "action:st2:inquiry.respond"},
+ "status": "succeeded",
+ "runner": {"name": "python-script"},
+ "liveaction": RESPOND_LIVEACTION,
+ "result": {"exit_code": 0, "result": None, "stderr": "", "stdout": ""},
+ },
}
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class ActionExecutionModelTest(DbTestCase):
-
def setUp(self):
self.executions = {}
@@ -107,16 +99,17 @@ def setUp(self):
for name, execution in ACTIONEXECUTIONS.items():
created = ActionExecutionDB()
- created.action = execution['action']
- created.status = execution['status']
- created.runner = execution['runner']
- created.liveaction = execution['liveaction']
- created.result = execution['result']
+ created.action = execution["action"]
+ created.status = execution["status"]
+ created.runner = execution["runner"]
+ created.liveaction = execution["liveaction"]
+ created.result = execution["result"]
saved = ActionExecutionModelTest._save_execution(created)
retrieved = ActionExecution.get_by_id(saved.id)
- self.assertEqual(saved.action, retrieved.action,
- 'Same action was not returned.')
+ self.assertEqual(
+ saved.action, retrieved.action, "Same action was not returned."
+ )
self.executions[name] = retrieved
@@ -128,15 +121,16 @@ def tearDown(self):
retrieved = ActionExecution.get_by_id(execution.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_update_execution(self):
- """Test ActionExecutionDb update
- """
- self.assertIsNone(self.executions['execution_1'].end_timestamp)
- self.executions['execution_1'].end_timestamp = date_utils.get_datetime_utc_now()
- updated = ActionExecution.add_or_update(self.executions['execution_1'])
- self.assertTrue(updated.end_timestamp == self.executions['execution_1'].end_timestamp)
+ """Test ActionExecutionDb update"""
+ self.assertIsNone(self.executions["execution_1"].end_timestamp)
+ self.executions["execution_1"].end_timestamp = date_utils.get_datetime_utc_now()
+ updated = ActionExecution.add_or_update(self.executions["execution_1"])
+ self.assertTrue(
+ updated.end_timestamp == self.executions["execution_1"].end_timestamp
+ )
def test_execution_inquiry_secrets(self):
"""Corner case test for Inquiry responses that contain secrets.
@@ -148,13 +142,15 @@ def test_execution_inquiry_secrets(self):
"""
# Test Inquiry response masking is done properly within this model
- masked = self.executions['execution_1'].mask_secrets(
- self.executions['execution_1'].to_serializable_dict()
+ masked = self.executions["execution_1"].mask_secrets(
+ self.executions["execution_1"].to_serializable_dict()
+ )
+ self.assertEqual(
+ masked["result"]["response"]["secondfactor"], MASKED_ATTRIBUTE_VALUE
)
- self.assertEqual(masked['result']['response']['secondfactor'], MASKED_ATTRIBUTE_VALUE)
self.assertEqual(
- self.executions['execution_1'].result['response']['secondfactor'],
- "supersecretvalue"
+ self.executions["execution_1"].result["response"]["secondfactor"],
+ "supersecretvalue",
)
def test_execution_inquiry_response_action(self):
@@ -164,10 +160,10 @@ def test_execution_inquiry_response_action(self):
so we mask all response values. This test ensures this happens.
"""
- masked = self.executions['execution_2'].mask_secrets(
- self.executions['execution_2'].to_serializable_dict()
+ masked = self.executions["execution_2"].mask_secrets(
+ self.executions["execution_2"].to_serializable_dict()
)
- for value in masked['parameters']['response'].values():
+ for value in masked["parameters"]["response"].values():
self.assertEqual(value, MASKED_ATTRIBUTE_VALUE)
@staticmethod
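
The two inquiry tests above assert that any response field whose schema entry is marked secret: True comes back as MASKED_ATTRIBUTE_VALUE after mask_secrets() runs. A rough, self-contained sketch of that masking idea follows; the function body and the mask value are illustrative assumptions, not StackStorm's actual implementation:

    # Hypothetical schema-driven secret masking; st2common's real mask_secrets()
    # also handles nested results, runner types, and parameter schemas.
    MASKED_ATTRIBUTE_VALUE = "********"  # placeholder value for this demo

    def mask_secrets(schema, response):
        masked = dict(response)
        for key, spec in schema.get("properties", {}).items():
            if spec.get("secret") and key in masked:
                masked[key] = MASKED_ATTRIBUTE_VALUE
        return masked

    schema = {"properties": {"secondfactor": {"secret": True, "type": "string"}}}
    print(mask_secrets(schema, {"secondfactor": "supersecretvalue"}))
    # -> {'secondfactor': '********'}
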
diff --git a/st2common/tests/unit/test_db_fields.py b/st2common/tests/unit/test_db_fields.py
index eceb70d4c0..86fd3bc6fb 100644
--- a/st2common/tests/unit/test_db_fields.py
+++ b/st2common/tests/unit/test_db_fields.py
@@ -37,12 +37,12 @@ def test_round_trip_conversion(self):
datetime_values = [
datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=500),
datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=0),
- datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=999999)
+ datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=999999),
]
datetime_values = [
date_utils.add_utc_tz(datetime_values[0]),
date_utils.add_utc_tz(datetime_values[1]),
- date_utils.add_utc_tz(datetime_values[2])
+ date_utils.add_utc_tz(datetime_values[2]),
]
microsecond_values = []
@@ -69,7 +69,7 @@ def test_round_trip_conversion(self):
expected_value = datetime_values[index]
self.assertEqual(actual_value, expected_value)
- @mock.patch('st2common.fields.LongField.__get__')
+ @mock.patch("st2common.fields.LongField.__get__")
def test_get_(self, mock_get):
field = ComplexDateTimeField()
@@ -79,7 +79,9 @@ def test_get_(self, mock_get):
# Already a datetime
mock_get.return_value = date_utils.get_datetime_utc_now()
- self.assertEqual(field.__get__(instance=None, owner=None), mock_get.return_value)
+ self.assertEqual(
+ field.__get__(instance=None, owner=None), mock_get.return_value
+ )
# Microseconds
dt = datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=500)
diff --git a/st2common/tests/unit/test_db_liveaction.py b/st2common/tests/unit/test_db_liveaction.py
index 7c8b6aa35f..605aa759f6 100644
--- a/st2common/tests/unit/test_db_liveaction.py
+++ b/st2common/tests/unit/test_db_liveaction.py
@@ -26,19 +26,19 @@
from st2tests import DbTestCase
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class LiveActionModelTest(DbTestCase):
-
def test_liveaction_crud_no_notify(self):
created = LiveActionDB()
- created.action = 'core.local'
- created.description = ''
- created.status = 'running'
+ created.action = "core.local"
+ created.description = ""
+ created.status = "running"
created.parameters = {}
saved = LiveActionModelTest._save_liveaction(created)
retrieved = LiveAction.get_by_id(saved.id)
- self.assertEqual(saved.action, retrieved.action,
- 'Same triggertype was not returned.')
+ self.assertEqual(
+ saved.action, retrieved.action, "Same triggertype was not returned."
+ )
self.assertEqual(retrieved.notify, None)
# Test update
@@ -52,80 +52,81 @@ def test_liveaction_crud_no_notify(self):
retrieved = LiveAction.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
def test_liveaction_create_with_notify_on_complete_only(self):
created = LiveActionDB()
- created.action = 'core.local'
- created.description = ''
- created.status = 'running'
+ created.action = "core.local"
+ created.description = ""
+ created.status = "running"
created.parameters = {}
notify_db = NotificationSchema()
notify_sub_schema = NotificationSubSchema()
- notify_sub_schema.message = 'Action complete.'
- notify_sub_schema.data = {
- 'foo': 'bar',
- 'bar': 1,
- 'baz': {'k1': 'v1'}
- }
+ notify_sub_schema.message = "Action complete."
+ notify_sub_schema.data = {"foo": "bar", "bar": 1, "baz": {"k1": "v1"}}
notify_db.on_complete = notify_sub_schema
created.notify = notify_db
saved = LiveActionModelTest._save_liveaction(created)
retrieved = LiveAction.get_by_id(saved.id)
- self.assertEqual(saved.action, retrieved.action,
- 'Same triggertype was not returned.')
+ self.assertEqual(
+ saved.action, retrieved.action, "Same triggertype was not returned."
+ )
# Assert the saved notify settings are correct.
- self.assertEqual(notify_sub_schema.message, retrieved.notify.on_complete.message)
+ self.assertEqual(
+ notify_sub_schema.message, retrieved.notify.on_complete.message
+ )
self.assertDictEqual(notify_sub_schema.data, retrieved.notify.on_complete.data)
- self.assertListEqual(notify_sub_schema.routes, retrieved.notify.on_complete.routes)
+ self.assertListEqual(
+ notify_sub_schema.routes, retrieved.notify.on_complete.routes
+ )
self.assertEqual(retrieved.notify.on_success, None)
self.assertEqual(retrieved.notify.on_failure, None)
def test_liveaction_create_with_notify_on_success_only(self):
created = LiveActionDB()
- created.action = 'core.local'
- created.description = ''
- created.status = 'running'
+ created.action = "core.local"
+ created.description = ""
+ created.status = "running"
created.parameters = {}
notify_db = NotificationSchema()
notify_sub_schema = NotificationSubSchema()
- notify_sub_schema.message = 'Action succeeded.'
- notify_sub_schema.data = {
- 'foo': 'bar',
- 'bar': 1,
- 'baz': {'k1': 'v1'}
- }
+ notify_sub_schema.message = "Action succeeded."
+ notify_sub_schema.data = {"foo": "bar", "bar": 1, "baz": {"k1": "v1"}}
notify_db.on_success = notify_sub_schema
created.notify = notify_db
saved = LiveActionModelTest._save_liveaction(created)
retrieved = LiveAction.get_by_id(saved.id)
- self.assertEqual(saved.action, retrieved.action,
- 'Same triggertype was not returned.')
+ self.assertEqual(
+ saved.action, retrieved.action, "Same triggertype was not returned."
+ )
# Assert the saved notify settings are correct.
- self.assertEqual(notify_sub_schema.message,
- retrieved.notify.on_success.message)
+ self.assertEqual(notify_sub_schema.message, retrieved.notify.on_success.message)
self.assertDictEqual(notify_sub_schema.data, retrieved.notify.on_success.data)
- self.assertListEqual(notify_sub_schema.routes, retrieved.notify.on_success.routes)
+ self.assertListEqual(
+ notify_sub_schema.routes, retrieved.notify.on_success.routes
+ )
self.assertEqual(retrieved.notify.on_failure, None)
self.assertEqual(retrieved.notify.on_complete, None)
def test_liveaction_create_with_notify_both_on_success_and_on_error(self):
created = LiveActionDB()
- created.action = 'core.local'
- created.description = ''
- created.status = 'running'
+ created.action = "core.local"
+ created.description = ""
+ created.status = "running"
created.parameters = {}
- on_success = NotificationSubSchema(message='Action succeeded.')
- on_failure = NotificationSubSchema(message='Action failed.')
- created.notify = NotificationSchema(on_success=on_success,
- on_failure=on_failure)
+ on_success = NotificationSubSchema(message="Action succeeded.")
+ on_failure = NotificationSubSchema(message="Action failed.")
+ created.notify = NotificationSchema(
+ on_success=on_success, on_failure=on_failure
+ )
saved = LiveActionModelTest._save_liveaction(created)
retrieved = LiveAction.get_by_id(saved.id)
- self.assertEqual(saved.action, retrieved.action,
- 'Same triggertype was not returned.')
+ self.assertEqual(
+ saved.action, retrieved.action, "Same triggertype was not returned."
+ )
# Assert the saved notify settings are correct.
self.assertEqual(on_success.message, retrieved.notify.on_success.message)
self.assertEqual(on_failure.message, retrieved.notify.on_failure.message)
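
Each of these DB model test classes is wrapped in @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()), which replaces the publisher for every test method on the class so no message-bus traffic happens during the tests. A self-contained sketch of the same pattern, using an invented stand-in class rather than the real PoolPublisher (the tests themselves import the standalone mock package; unittest.mock behaves the same way):

    from unittest import mock

    class PoolPublisherStandin:  # invented for this demo
        def publish(self, message):
            raise RuntimeError("would hit the real message bus")

    @mock.patch.object(PoolPublisherStandin, "publish", mock.MagicMock())
    def run_demo():
        # While patched, the MagicMock silently records the call instead.
        PoolPublisherStandin().publish("fake event")
        print("publish() was intercepted")

    run_demo()
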
diff --git a/st2common/tests/unit/test_db_marker.py b/st2common/tests/unit/test_db_marker.py
index 72dc879697..b9cd879ea3 100644
--- a/st2common/tests/unit/test_db_marker.py
+++ b/st2common/tests/unit/test_db_marker.py
@@ -26,26 +26,27 @@ class DumperMarkerModelTest(DbTestCase):
def test_dumper_marker_crud(self):
saved = DumperMarkerModelTest._create_save_dumper_marker()
retrieved = DumperMarker.get_by_id(saved.id)
- self.assertEqual(saved.marker, retrieved.marker,
- 'Same marker was not returned.')
+ self.assertEqual(
+ saved.marker, retrieved.marker, "Same marker was not returned."
+ )
# test update
time_now = date_utils.get_datetime_utc_now()
retrieved.updated_at = time_now
saved = DumperMarker.add_or_update(retrieved)
retrieved = DumperMarker.get_by_id(saved.id)
- self.assertEqual(retrieved.updated_at, time_now, 'Update to marker failed.')
+ self.assertEqual(retrieved.updated_at, time_now, "Update to marker failed.")
# cleanup
DumperMarkerModelTest._delete([retrieved])
try:
retrieved = DumperMarker.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after failure.')
+ self.assertIsNone(retrieved, "managed to retrieve after failure.")
@staticmethod
def _create_save_dumper_marker():
created = DumperMarkerDB()
- created.marker = '2015-06-11T00:35:15.260439Z'
+ created.marker = "2015-06-11T00:35:15.260439Z"
created.updated_at = date_utils.get_datetime_utc_now()
return DumperMarker.add_or_update(created)
diff --git a/st2common/tests/unit/test_db_model_uids.py b/st2common/tests/unit/test_db_model_uids.py
index 3f5ec1ca6c..2dd3bfb87d 100644
--- a/st2common/tests/unit/test_db_model_uids.py
+++ b/st2common/tests/unit/test_db_model_uids.py
@@ -30,72 +30,80 @@
from st2common.models.db.policy import PolicyDB
from st2common.models.db.auth import ApiKeyDB
-__all__ = [
- 'DBModelUIDFieldTestCase'
-]
+__all__ = ["DBModelUIDFieldTestCase"]
class DBModelUIDFieldTestCase(unittest2.TestCase):
def test_get_uid(self):
- pack_db = PackDB(ref='ma_pack')
- self.assertEqual(pack_db.get_uid(), 'pack:ma_pack')
+ pack_db = PackDB(ref="ma_pack")
+ self.assertEqual(pack_db.get_uid(), "pack:ma_pack")
self.assertTrue(pack_db.has_valid_uid())
- sensor_type_db = SensorTypeDB(name='sname', pack='spack')
- self.assertEqual(sensor_type_db.get_uid(), 'sensor_type:spack:sname')
+ sensor_type_db = SensorTypeDB(name="sname", pack="spack")
+ self.assertEqual(sensor_type_db.get_uid(), "sensor_type:spack:sname")
self.assertTrue(sensor_type_db.has_valid_uid())
- action_db = ActionDB(name='aname', pack='apack', runner_type={})
- self.assertEqual(action_db.get_uid(), 'action:apack:aname')
+ action_db = ActionDB(name="aname", pack="apack", runner_type={})
+ self.assertEqual(action_db.get_uid(), "action:apack:aname")
self.assertTrue(action_db.has_valid_uid())
- rule_db = RuleDB(name='rname', pack='rpack')
- self.assertEqual(rule_db.get_uid(), 'rule:rpack:rname')
+ rule_db = RuleDB(name="rname", pack="rpack")
+ self.assertEqual(rule_db.get_uid(), "rule:rpack:rname")
self.assertTrue(rule_db.has_valid_uid())
- trigger_type_db = TriggerTypeDB(name='ttname', pack='ttpack')
- self.assertEqual(trigger_type_db.get_uid(), 'trigger_type:ttpack:ttname')
+ trigger_type_db = TriggerTypeDB(name="ttname", pack="ttpack")
+ self.assertEqual(trigger_type_db.get_uid(), "trigger_type:ttpack:ttname")
self.assertTrue(trigger_type_db.has_valid_uid())
- trigger_db = TriggerDB(name='tname', pack='tpack')
- self.assertTrue(trigger_db.get_uid().startswith('trigger:tpack:tname:'))
+ trigger_db = TriggerDB(name="tname", pack="tpack")
+ self.assertTrue(trigger_db.get_uid().startswith("trigger:tpack:tname:"))
# Verify that the same set of parameters always results in the same hash
- parameters = {'a': 1, 'b': 'unicode', 'c': [1, 2, 3], 'd': {'g': 1, 'h': 2}}
+ parameters = {"a": 1, "b": "unicode", "c": [1, 2, 3], "d": {"g": 1, "h": 2}}
paramers_hash = json.dumps(parameters, sort_keys=True)
paramers_hash = hashlib.md5(paramers_hash.encode()).hexdigest()
- parameters = {'a': 1, 'b': 'unicode', 'c': [1, 2, 3], 'd': {'g': 1, 'h': 2}}
- trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters)
- self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash))
+ parameters = {"a": 1, "b": "unicode", "c": [1, 2, 3], "d": {"g": 1, "h": 2}}
+ trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters)
+ self.assertEqual(
+ trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash)
+ )
self.assertTrue(trigger_db.has_valid_uid())
- parameters = {'c': [1, 2, 3], 'b': u'unicode', 'd': {'h': 2, 'g': 1}, 'a': 1}
- trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters)
- self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash))
+ parameters = {"c": [1, 2, 3], "b": "unicode", "d": {"h": 2, "g": 1}, "a": 1}
+ trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters)
+ self.assertEqual(
+ trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash)
+ )
self.assertTrue(trigger_db.has_valid_uid())
- parameters = {'b': u'unicode', 'c': [1, 2, 3], 'd': {'h': 2, 'g': 1}, 'a': 1}
- trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters)
- self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash))
+ parameters = {"b": "unicode", "c": [1, 2, 3], "d": {"h": 2, "g": 1}, "a": 1}
+ trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters)
+ self.assertEqual(
+ trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash)
+ )
self.assertTrue(trigger_db.has_valid_uid())
- parameters = OrderedDict({'c': [1, 2, 3], 'b': u'unicode', 'd': {'h': 2, 'g': 1}, 'a': 1})
- trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters)
- self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash))
+ parameters = OrderedDict(
+ {"c": [1, 2, 3], "b": "unicode", "d": {"h": 2, "g": 1}, "a": 1}
+ )
+ trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters)
+ self.assertEqual(
+ trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash)
+ )
self.assertTrue(trigger_db.has_valid_uid())
- policy_type_db = PolicyTypeDB(resource_type='action', name='concurrency')
- self.assertEqual(policy_type_db.get_uid(), 'policy_type:action:concurrency')
+ policy_type_db = PolicyTypeDB(resource_type="action", name="concurrency")
+ self.assertEqual(policy_type_db.get_uid(), "policy_type:action:concurrency")
self.assertTrue(policy_type_db.has_valid_uid())
- policy_db = PolicyDB(pack='dummy', name='policy1')
- self.assertEqual(policy_db.get_uid(), 'policy:dummy:policy1')
+ policy_db = PolicyDB(pack="dummy", name="policy1")
+ self.assertEqual(policy_db.get_uid(), "policy:dummy:policy1")
- api_key_db = ApiKeyDB(key_hash='valid')
- self.assertEqual(api_key_db.get_uid(), 'api_key:valid')
+ api_key_db = ApiKeyDB(key_hash="valid")
+ self.assertEqual(api_key_db.get_uid(), "api_key:valid")
self.assertTrue(api_key_db.has_valid_uid())
api_key_db = ApiKeyDB()
- self.assertEqual(api_key_db.get_uid(), 'api_key:')
+ self.assertEqual(api_key_db.get_uid(), "api_key:")
self.assertFalse(api_key_db.has_valid_uid())
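
The trigger UID assertions above all reduce to one property: json.dumps(parameters, sort_keys=True) yields a single canonical string regardless of key insertion order, so the md5 suffix of the trigger UID stays stable. A standalone illustration mirroring the test data:

    import hashlib
    import json

    p1 = {"a": 1, "b": "unicode", "c": [1, 2, 3], "d": {"g": 1, "h": 2}}
    p2 = {"c": [1, 2, 3], "b": "unicode", "d": {"h": 2, "g": 1}, "a": 1}

    def uid_suffix(parameters):
        canonical = json.dumps(parameters, sort_keys=True)
        return hashlib.md5(canonical.encode()).hexdigest()

    assert uid_suffix(p1) == uid_suffix(p2)  # key order never matters
    print("trigger:tpack:tname:%s" % uid_suffix(p1))
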
diff --git a/st2common/tests/unit/test_db_pack.py b/st2common/tests/unit/test_db_pack.py
index c8df8b5a28..d5b5af00f4 100644
--- a/st2common/tests/unit/test_db_pack.py
+++ b/st2common/tests/unit/test_db_pack.py
@@ -26,21 +26,21 @@ class PackDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase):
model_class = PackDB
persistance_class = Pack
model_class_kwargs = {
- 'name': 'Yolo CI',
- 'ref': 'yolo_ci',
- 'description': 'YOLO CI pack',
- 'version': '0.1.0',
- 'author': 'Volkswagen',
- 'path': '/opt/stackstorm/packs/yolo_ci/'
+ "name": "Yolo CI",
+ "ref": "yolo_ci",
+ "description": "YOLO CI pack",
+ "version": "0.1.0",
+ "author": "Volkswagen",
+ "path": "/opt/stackstorm/packs/yolo_ci/",
}
- update_attribute_name = 'author'
+ update_attribute_name = "author"
def test_path_none(self):
PackDBModelCRUDTestCase.model_class_kwargs = {
- 'name': 'Yolo CI',
- 'ref': 'yolo_ci',
- 'description': 'YOLO CI pack',
- 'version': '0.1.0',
- 'author': 'Volkswagen'
+ "name": "Yolo CI",
+ "ref": "yolo_ci",
+ "description": "YOLO CI pack",
+ "version": "0.1.0",
+ "author": "Volkswagen",
}
super(PackDBModelCRUDTestCase, self).test_crud_operations()
diff --git a/st2common/tests/unit/test_db_policy.py b/st2common/tests/unit/test_db_policy.py
index 9364c61074..95b682e4a4 100644
--- a/st2common/tests/unit/test_db_policy.py
+++ b/st2common/tests/unit/test_db_policy.py
@@ -24,64 +24,113 @@
class PolicyTypeReferenceTest(unittest2.TestCase):
-
def test_is_reference(self):
- self.assertTrue(PolicyTypeReference.is_reference('action.concurrency'))
- self.assertFalse(PolicyTypeReference.is_reference('concurrency'))
- self.assertFalse(PolicyTypeReference.is_reference(''))
+ self.assertTrue(PolicyTypeReference.is_reference("action.concurrency"))
+ self.assertFalse(PolicyTypeReference.is_reference("concurrency"))
+ self.assertFalse(PolicyTypeReference.is_reference(""))
self.assertFalse(PolicyTypeReference.is_reference(None))
def test_validate_resource_type(self):
- self.assertEqual(PolicyTypeReference.validate_resource_type('action'), 'action')
- self.assertRaises(ValueError, PolicyTypeReference.validate_resource_type, 'action.test')
+ self.assertEqual(PolicyTypeReference.validate_resource_type("action"), "action")
+ self.assertRaises(
+ ValueError, PolicyTypeReference.validate_resource_type, "action.test"
+ )
def test_get_resource_type(self):
- self.assertEqual(PolicyTypeReference.get_resource_type('action.concurrency'), 'action')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, '.abc')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, 'abc')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, '')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, None)
+ self.assertEqual(
+ PolicyTypeReference.get_resource_type("action.concurrency"), "action"
+ )
+ self.assertRaises(
+ InvalidReferenceError, PolicyTypeReference.get_resource_type, ".abc"
+ )
+ self.assertRaises(
+ InvalidReferenceError, PolicyTypeReference.get_resource_type, "abc"
+ )
+ self.assertRaises(
+ InvalidReferenceError, PolicyTypeReference.get_resource_type, ""
+ )
+ self.assertRaises(
+ InvalidReferenceError, PolicyTypeReference.get_resource_type, None
+ )
def test_get_name(self):
- self.assertEqual(PolicyTypeReference.get_name('action.concurrency'), 'concurrency')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, '.abc')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, 'abc')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, '')
+ self.assertEqual(
+ PolicyTypeReference.get_name("action.concurrency"), "concurrency"
+ )
+ self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, ".abc")
+ self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, "abc")
+ self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, "")
self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, None)
def test_to_string_reference(self):
- ref = PolicyTypeReference.to_string_reference(resource_type='action', name='concurrency')
- self.assertEqual(ref, 'action.concurrency')
-
- self.assertRaises(ValueError, PolicyTypeReference.to_string_reference,
- resource_type='action.test', name='concurrency')
- self.assertRaises(ValueError, PolicyTypeReference.to_string_reference,
- resource_type=None, name='concurrency')
- self.assertRaises(ValueError, PolicyTypeReference.to_string_reference,
- resource_type='', name='concurrency')
- self.assertRaises(ValueError, PolicyTypeReference.to_string_reference,
- resource_type='action', name=None)
- self.assertRaises(ValueError, PolicyTypeReference.to_string_reference,
- resource_type='action', name='')
- self.assertRaises(ValueError, PolicyTypeReference.to_string_reference,
- resource_type=None, name=None)
- self.assertRaises(ValueError, PolicyTypeReference.to_string_reference,
- resource_type='', name='')
+ ref = PolicyTypeReference.to_string_reference(
+ resource_type="action", name="concurrency"
+ )
+ self.assertEqual(ref, "action.concurrency")
+
+ self.assertRaises(
+ ValueError,
+ PolicyTypeReference.to_string_reference,
+ resource_type="action.test",
+ name="concurrency",
+ )
+ self.assertRaises(
+ ValueError,
+ PolicyTypeReference.to_string_reference,
+ resource_type=None,
+ name="concurrency",
+ )
+ self.assertRaises(
+ ValueError,
+ PolicyTypeReference.to_string_reference,
+ resource_type="",
+ name="concurrency",
+ )
+ self.assertRaises(
+ ValueError,
+ PolicyTypeReference.to_string_reference,
+ resource_type="action",
+ name=None,
+ )
+ self.assertRaises(
+ ValueError,
+ PolicyTypeReference.to_string_reference,
+ resource_type="action",
+ name="",
+ )
+ self.assertRaises(
+ ValueError,
+ PolicyTypeReference.to_string_reference,
+ resource_type=None,
+ name=None,
+ )
+ self.assertRaises(
+ ValueError,
+ PolicyTypeReference.to_string_reference,
+ resource_type="",
+ name="",
+ )
def test_from_string_reference(self):
- ref = PolicyTypeReference.from_string_reference('action.concurrency')
- self.assertEqual(ref.resource_type, 'action')
- self.assertEqual(ref.name, 'concurrency')
- self.assertEqual(ref.ref, 'action.concurrency')
-
- ref = PolicyTypeReference.from_string_reference('action.concurrency.targeted')
- self.assertEqual(ref.resource_type, 'action')
- self.assertEqual(ref.name, 'concurrency.targeted')
- self.assertEqual(ref.ref, 'action.concurrency.targeted')
-
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, '.test')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, '')
- self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, None)
+ ref = PolicyTypeReference.from_string_reference("action.concurrency")
+ self.assertEqual(ref.resource_type, "action")
+ self.assertEqual(ref.name, "concurrency")
+ self.assertEqual(ref.ref, "action.concurrency")
+
+ ref = PolicyTypeReference.from_string_reference("action.concurrency.targeted")
+ self.assertEqual(ref.resource_type, "action")
+ self.assertEqual(ref.name, "concurrency.targeted")
+ self.assertEqual(ref.ref, "action.concurrency.targeted")
+
+ self.assertRaises(
+ InvalidReferenceError, PolicyTypeReference.from_string_reference, ".test"
+ )
+ self.assertRaises(
+ InvalidReferenceError, PolicyTypeReference.from_string_reference, ""
+ )
+ self.assertRaises(
+ InvalidReferenceError, PolicyTypeReference.from_string_reference, None
+ )
class PolicyTypeTest(DbModelTestCase):
@@ -89,34 +138,26 @@ class PolicyTypeTest(DbModelTestCase):
@staticmethod
def _create_instance():
- parameters = {
- 'threshold': {
- 'type': 'integer',
- 'required': True
- }
- }
-
- instance = PolicyTypeDB(name='concurrency',
- description='TBD',
- enabled=None,
- ref=None,
- resource_type='action',
- module='st2action.policies.concurrency',
- parameters=parameters)
+ parameters = {"threshold": {"type": "integer", "required": True}}
+
+ instance = PolicyTypeDB(
+ name="concurrency",
+ description="TBD",
+ enabled=None,
+ ref=None,
+ resource_type="action",
+ module="st2action.policies.concurrency",
+ parameters=parameters,
+ )
return instance
def test_crud(self):
instance = self._create_instance()
- defaults = {
- 'ref': 'action.concurrency',
- 'enabled': True
- }
+ defaults = {"ref": "action.concurrency", "enabled": True}
- updates = {
- 'description': 'Limits the concurrent executions for the action.'
- }
+ updates = {"description": "Limits the concurrent executions for the action."}
self._assert_crud(instance, defaults=defaults, updates=updates)
@@ -130,16 +171,16 @@ class PolicyTest(DbModelTestCase):
@staticmethod
def _create_instance():
- instance = PolicyDB(pack=None,
- name='local.concurrency',
- description='TBD',
- enabled=None,
- ref=None,
- resource_ref='core.local',
- policy_type='action.concurrency',
- parameters={
- 'threshold': 25
- })
+ instance = PolicyDB(
+ pack=None,
+ name="local.concurrency",
+ description="TBD",
+ enabled=None,
+ ref=None,
+ resource_ref="core.local",
+ policy_type="action.concurrency",
+ parameters={"threshold": 25},
+ )
return instance
@@ -147,13 +188,13 @@ def test_crud(self):
instance = self._create_instance()
defaults = {
- 'pack': pack_constants.DEFAULT_PACK_NAME,
- 'ref': '%s.local.concurrency' % pack_constants.DEFAULT_PACK_NAME,
- 'enabled': True
+ "pack": pack_constants.DEFAULT_PACK_NAME,
+ "ref": "%s.local.concurrency" % pack_constants.DEFAULT_PACK_NAME,
+ "enabled": True,
}
updates = {
- 'description': 'Limits the concurrent executions for the action "core.local".'
+ "description": 'Limits the concurrent executions for the action "core.local".'
}
self._assert_crud(instance, defaults=defaults, updates=updates)
@@ -164,7 +205,7 @@ def test_ref(self):
self.assertIsNotNone(ref)
self.assertEqual(ref.pack, instance.pack)
self.assertEqual(ref.name, instance.name)
- self.assertEqual(ref.ref, instance.pack + '.' + instance.name)
+ self.assertEqual(ref.ref, instance.pack + "." + instance.name)
self.assertEqual(ref.ref, instance.ref)
def test_unique_key(self):
diff --git a/st2common/tests/unit/test_db_rbac.py b/st2common/tests/unit/test_db_rbac.py
index 62b9763272..d9c3fcc958 100644
--- a/st2common/tests/unit/test_db_rbac.py
+++ b/st2common/tests/unit/test_db_rbac.py
@@ -28,10 +28,10 @@
__all__ = [
- 'RoleDBModelCRUDTestCase',
- 'UserRoleAssignmentDBModelCRUDTestCase',
- 'PermissionGrantDBModelCRUDTestCase',
- 'GroupToRoleMappingDBModelCRUDTestCase'
+ "RoleDBModelCRUDTestCase",
+ "UserRoleAssignmentDBModelCRUDTestCase",
+ "PermissionGrantDBModelCRUDTestCase",
+ "GroupToRoleMappingDBModelCRUDTestCase",
]
@@ -39,44 +39,44 @@ class RoleDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase):
model_class = RoleDB
persistance_class = Role
model_class_kwargs = {
- 'name': 'role_one',
- 'description': None,
- 'system': False,
- 'permission_grants': []
+ "name": "role_one",
+ "description": None,
+ "system": False,
+ "permission_grants": [],
}
- update_attribute_name = 'name'
+ update_attribute_name = "name"
class UserRoleAssignmentDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase):
model_class = UserRoleAssignmentDB
persistance_class = UserRoleAssignment
model_class_kwargs = {
- 'user': 'user_one',
- 'role': 'role_one',
- 'source': 'source_one',
- 'is_remote': True
+ "user": "user_one",
+ "role": "role_one",
+ "source": "source_one",
+ "is_remote": True,
}
- update_attribute_name = 'role'
+ update_attribute_name = "role"
class PermissionGrantDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase):
model_class = PermissionGrantDB
persistance_class = PermissionGrant
model_class_kwargs = {
- 'resource_uid': 'pack:core',
- 'resource_type': 'pack',
- 'permission_types': []
+ "resource_uid": "pack:core",
+ "resource_type": "pack",
+ "permission_types": [],
}
- update_attribute_name = 'resource_uid'
+ update_attribute_name = "resource_uid"
class GroupToRoleMappingDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase):
model_class = GroupToRoleMappingDB
persistance_class = GroupToRoleMapping
model_class_kwargs = {
- 'group': 'some group',
- 'roles': ['role_one', 'role_two'],
- 'description': 'desc',
- 'enabled': True
+ "group": "some group",
+ "roles": ["role_one", "role_two"],
+ "description": "desc",
+ "enabled": True,
}
- update_attribute_name = 'group'
+ update_attribute_name = "group"
diff --git a/st2common/tests/unit/test_db_rule_enforcement.py b/st2common/tests/unit/test_db_rule_enforcement.py
index 734a34ffc3..5cececffa0 100644
--- a/st2common/tests/unit/test_db_rule_enforcement.py
+++ b/st2common/tests/unit/test_db_rule_enforcement.py
@@ -28,19 +28,19 @@
SKIP_DELETE = False
-__all__ = [
- 'RuleEnforcementModelTest'
-]
+__all__ = ["RuleEnforcementModelTest"]
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class RuleEnforcementModelTest(DbTestCase):
-
def test_ruleenforcment_crud(self):
saved = RuleEnforcementModelTest._create_save_rule_enforcement()
retrieved = RuleEnforcement.get_by_id(saved.id)
- self.assertEqual(saved.rule.ref, retrieved.rule.ref,
- 'Same rule enforcement was not returned.')
+ self.assertEqual(
+ saved.rule.ref,
+ retrieved.rule.ref,
+ "Same rule enforcement was not returned.",
+ )
self.assertIsNotNone(retrieved.enforced_at)
# test update
RULE_ID = str(bson.ObjectId())
@@ -48,73 +48,82 @@ def test_ruleenforcment_crud(self):
retrieved.rule.id = RULE_ID
saved = RuleEnforcement.add_or_update(retrieved)
retrieved = RuleEnforcement.get_by_id(saved.id)
- self.assertEqual(retrieved.rule.id, RULE_ID,
- 'Update to rule enforcement failed.')
+ self.assertEqual(
+ retrieved.rule.id, RULE_ID, "Update to rule enforcement failed."
+ )
# cleanup
RuleEnforcementModelTest._delete([retrieved])
try:
retrieved = RuleEnforcement.get_by_id(saved.id)
except StackStormDBObjectNotFoundError:
retrieved = None
- self.assertIsNone(retrieved, 'managed to retrieve after delete.')
+ self.assertIsNone(retrieved, "managed to retrieve after delete.")
def test_status_set_to_failed_for_objects_which_predate_status_field(self):
- rule = {
- 'ref': 'foo_pack.foo_rule',
- 'uid': 'rule:foo_pack:foo_rule'
- }
+ rule = {"ref": "foo_pack.foo_rule", "uid": "rule:foo_pack:foo_rule"}
# 1. No status field explicitly set and no failure reason
- enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()),
- rule=rule,
- execution_id=str(bson.ObjectId()))
+ enforcement_db = RuleEnforcementDB(
+ trigger_instance_id=str(bson.ObjectId()),
+ rule=rule,
+ execution_id=str(bson.ObjectId()),
+ )
enforcement_db = RuleEnforcement.add_or_update(enforcement_db)
self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_SUCCEEDED)
# 2. No status field, with failure reason, status should be set to failed
- enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()),
- rule=rule,
- execution_id=str(bson.ObjectId()),
- failure_reason='so much fail')
+ enforcement_db = RuleEnforcementDB(
+ trigger_instance_id=str(bson.ObjectId()),
+ rule=rule,
+ execution_id=str(bson.ObjectId()),
+ failure_reason="so much fail",
+ )
enforcement_db = RuleEnforcement.add_or_update(enforcement_db)
self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED)
# 3. Explicit status field - succeeded + failure reason
- enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()),
- rule=rule,
- execution_id=str(bson.ObjectId()),
- status=RULE_ENFORCEMENT_STATUS_SUCCEEDED,
- failure_reason='so much fail')
+ enforcement_db = RuleEnforcementDB(
+ trigger_instance_id=str(bson.ObjectId()),
+ rule=rule,
+ execution_id=str(bson.ObjectId()),
+ status=RULE_ENFORCEMENT_STATUS_SUCCEEDED,
+ failure_reason="so much fail",
+ )
enforcement_db = RuleEnforcement.add_or_update(enforcement_db)
self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED)
# 4. Explicit status field - succeeded + no failure reason
- enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()),
- rule=rule,
- execution_id=str(bson.ObjectId()),
- status=RULE_ENFORCEMENT_STATUS_SUCCEEDED)
+ enforcement_db = RuleEnforcementDB(
+ trigger_instance_id=str(bson.ObjectId()),
+ rule=rule,
+ execution_id=str(bson.ObjectId()),
+ status=RULE_ENFORCEMENT_STATUS_SUCCEEDED,
+ )
enforcement_db = RuleEnforcement.add_or_update(enforcement_db)
self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_SUCCEEDED)
# 5. Explicit status field - failed + no failure reason
- enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()),
- rule=rule,
- execution_id=str(bson.ObjectId()),
- status=RULE_ENFORCEMENT_STATUS_FAILED)
+ enforcement_db = RuleEnforcementDB(
+ trigger_instance_id=str(bson.ObjectId()),
+ rule=rule,
+ execution_id=str(bson.ObjectId()),
+ status=RULE_ENFORCEMENT_STATUS_FAILED,
+ )
enforcement_db = RuleEnforcement.add_or_update(enforcement_db)
self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED)
@staticmethod
def _create_save_rule_enforcement():
- created = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()),
- rule={'ref': 'foo_pack.foo_rule',
- 'uid': 'rule:foo_pack:foo_rule'},
- execution_id=str(bson.ObjectId()))
+ created = RuleEnforcementDB(
+ trigger_instance_id=str(bson.ObjectId()),
+ rule={"ref": "foo_pack.foo_rule", "uid": "rule:foo_pack:foo_rule"},
+ execution_id=str(bson.ObjectId()),
+ )
return RuleEnforcement.add_or_update(created)
@staticmethod
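
Cases 1-5 above pin down how RuleEnforcementDB backfills the status field for documents that predate it: any failure_reason forces failed, an explicit status wins otherwise, and the default is succeeded. A toy restatement of that decision table (illustrative only, not the model's actual code):

    RULE_ENFORCEMENT_STATUS_SUCCEEDED = "succeeded"
    RULE_ENFORCEMENT_STATUS_FAILED = "failed"

    def backfill_status(explicit_status=None, failure_reason=None):
        # A failure reason always implies a failed enforcement (cases 2 and 3).
        if failure_reason:
            return RULE_ENFORCEMENT_STATUS_FAILED
        # Otherwise honor an explicit status (cases 4 and 5), defaulting to
        # succeeded for old documents with neither field set (case 1).
        return explicit_status or RULE_ENFORCEMENT_STATUS_SUCCEEDED

    assert backfill_status() == "succeeded"
    assert backfill_status(failure_reason="so much fail") == "failed"
    assert backfill_status("succeeded", "so much fail") == "failed"
    assert backfill_status("failed") == "failed"
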
diff --git a/st2common/tests/unit/test_db_task.py b/st2common/tests/unit/test_db_task.py
index 60285f1366..bc0d3e2382 100644
--- a/st2common/tests/unit/test_db_task.py
+++ b/st2common/tests/unit/test_db_task.py
@@ -27,19 +27,18 @@
from st2common.util import date as date_utils
-@mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock())
class TaskExecutionModelTest(st2tests.DbTestCase):
-
def test_task_execution_crud(self):
initial = wf_db_models.TaskExecutionDB()
initial.workflow_execution = uuid.uuid4().hex
- initial.task_name = 't1'
- initial.task_id = 't1'
+ initial.task_name = "t1"
+ initial.task_id = "t1"
initial.task_route = 0
- initial.task_spec = {'tasks': {'t1': 'some task'}}
+ initial.task_spec = {"tasks": {"t1": "some task"}}
initial.delay = 180
- initial.status = 'requested'
- initial.context = {'var1': 'foobar'}
+ initial.status = "requested"
+ initial.context = {"var1": "foobar"}
# Test create
created = wf_db_access.TaskExecution.add_or_update(initial)
@@ -61,7 +60,7 @@ def test_task_execution_crud(self):
self.assertDictEqual(created.context, retrieved.context)
# Test update
- status = 'running'
+ status = "running"
retrieved = wf_db_access.TaskExecution.update(retrieved, status=status)
updated = wf_db_access.TaskExecution.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
@@ -79,8 +78,8 @@ def test_task_execution_crud(self):
self.assertDictEqual(updated.context, retrieved.context)
# Test add or update
- retrieved.result = {'output': 'fubar'}
- retrieved.status = 'succeeded'
+ retrieved.result = {"output": "fubar"}
+ retrieved.status = "succeeded"
retrieved.end_timestamp = date_utils.get_datetime_utc_now()
retrieved = wf_db_access.TaskExecution.add_or_update(retrieved)
updated = wf_db_access.TaskExecution.get_by_id(doc_id)
@@ -105,20 +104,20 @@ def test_task_execution_crud(self):
self.assertRaises(
db_exc.StackStormDBObjectNotFoundError,
wf_db_access.TaskExecution.get_by_id,
- doc_id
+ doc_id,
)
def test_task_execution_crud_set_itemized_true(self):
initial = wf_db_models.TaskExecutionDB()
initial.workflow_execution = uuid.uuid4().hex
- initial.task_name = 't1'
- initial.task_id = 't1'
+ initial.task_name = "t1"
+ initial.task_id = "t1"
initial.task_route = 0
- initial.task_spec = {'tasks': {'t1': 'some task'}}
+ initial.task_spec = {"tasks": {"t1": "some task"}}
initial.delay = 180
initial.itemized = True
- initial.status = 'requested'
- initial.context = {'var1': 'foobar'}
+ initial.status = "requested"
+ initial.context = {"var1": "foobar"}
# Test create
created = wf_db_access.TaskExecution.add_or_update(initial)
@@ -140,7 +139,7 @@ def test_task_execution_crud_set_itemized_true(self):
self.assertDictEqual(created.context, retrieved.context)
# Test update
- status = 'running'
+ status = "running"
retrieved = wf_db_access.TaskExecution.update(retrieved, status=status)
updated = wf_db_access.TaskExecution.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
@@ -158,8 +157,8 @@ def test_task_execution_crud_set_itemized_true(self):
self.assertDictEqual(updated.context, retrieved.context)
# Test add or update
- retrieved.result = {'output': 'fubar'}
- retrieved.status = 'succeeded'
+ retrieved.result = {"output": "fubar"}
+ retrieved.status = "succeeded"
retrieved.end_timestamp = date_utils.get_datetime_utc_now()
retrieved = wf_db_access.TaskExecution.add_or_update(retrieved)
updated = wf_db_access.TaskExecution.get_by_id(doc_id)
@@ -184,19 +183,19 @@ def test_task_execution_crud_set_itemized_true(self):
self.assertRaises(
db_exc.StackStormDBObjectNotFoundError,
wf_db_access.TaskExecution.get_by_id,
- doc_id
+ doc_id,
)
def test_task_execution_write_conflict(self):
initial = wf_db_models.TaskExecutionDB()
initial.workflow_execution = uuid.uuid4().hex
- initial.task_name = 't1'
- initial.task_id = 't1'
+ initial.task_name = "t1"
+ initial.task_id = "t1"
initial.task_route = 0
- initial.task_spec = {'tasks': {'t1': 'some task'}}
+ initial.task_spec = {"tasks": {"t1": "some task"}}
initial.delay = 180
- initial.status = 'requested'
- initial.context = {'var1': 'foobar'}
+ initial.status = "requested"
+ initial.context = {"var1": "foobar"}
# Prep record
created = wf_db_access.TaskExecution.add_or_update(initial)
@@ -208,7 +207,7 @@ def test_task_execution_write_conflict(self):
retrieved2 = wf_db_access.TaskExecution.get_by_id(doc_id)
# Test update on instance 1, expect success
- status = 'running'
+ status = "running"
retrieved1 = wf_db_access.TaskExecution.update(retrieved1, status=status)
updated = wf_db_access.TaskExecution.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
@@ -230,7 +229,7 @@ def test_task_execution_write_conflict(self):
db_exc.StackStormDBObjectWriteConflictError,
wf_db_access.TaskExecution.update,
retrieved2,
- status='pausing'
+ status="pausing",
)
# Test delete
@@ -239,5 +238,5 @@ def test_task_execution_write_conflict(self):
self.assertRaises(
db_exc.StackStormDBObjectNotFoundError,
wf_db_access.TaskExecution.get_by_id,
- doc_id
+ doc_id,
)
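
test_task_execution_write_conflict above exercises revision-based optimistic locking: two copies of the same document are fetched, the first update bumps rev, and the now-stale second copy is rejected with StackStormDBObjectWriteConflictError. A toy in-memory sketch of the idea; the store and names here are invented, while StackStorm enforces this against MongoDB:

    # Minimal optimistic-concurrency sketch, illustrative only.
    class WriteConflictError(Exception):
        pass

    store = {"t1": {"rev": 1, "status": "requested"}}

    def update(doc_id, working_copy, **changes):
        stored = store[doc_id]
        if working_copy["rev"] != stored["rev"]:
            raise WriteConflictError("stale revision %s" % working_copy["rev"])
        stored.update(changes)
        stored["rev"] += 1  # every successful write bumps the revision
        return dict(stored)

    copy1 = dict(store["t1"])
    copy2 = dict(store["t1"])
    update("t1", copy1, status="running")       # succeeds, rev -> 2
    try:
        update("t1", copy2, status="pausing")   # copy2 still carries rev 1
    except WriteConflictError as exc:
        print("conflict detected:", exc)
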
diff --git a/st2common/tests/unit/test_db_trace.py b/st2common/tests/unit/test_db_trace.py
index b9e2ec9c8a..1e0f884472 100644
--- a/st2common/tests/unit/test_db_trace.py
+++ b/st2common/tests/unit/test_db_trace.py
@@ -24,85 +24,103 @@
class TraceDBTest(CleanDbTestCase):
-
def test_get(self):
saved = TraceDBTest._create_save_trace(
- trace_tag='test_trace',
+ trace_tag="test_trace",
action_executions=[str(bson.ObjectId()) for _ in range(4)],
rules=[str(bson.ObjectId()) for _ in range(4)],
- trigger_instances=[str(bson.ObjectId()) for _ in range(5)])
+ trigger_instances=[str(bson.ObjectId()) for _ in range(5)],
+ )
retrieved = Trace.get(id=saved.id)
- self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.')
+ self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.")
def test_query(self):
saved = TraceDBTest._create_save_trace(
- trace_tag='test_trace',
+ trace_tag="test_trace",
action_executions=[str(bson.ObjectId()) for _ in range(4)],
rules=[str(bson.ObjectId()) for _ in range(4)],
- trigger_instances=[str(bson.ObjectId()) for _ in range(5)])
+ trigger_instances=[str(bson.ObjectId()) for _ in range(5)],
+ )
retrieved = Trace.query(trace_tag=saved.trace_tag)
- self.assertEqual(len(retrieved), 1, 'Should have 1 trace.')
- self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.')
+ self.assertEqual(len(retrieved), 1, "Should have 1 trace.")
+ self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.")
# Add another trace with the same trace_tag and confirm that we support it.
# This is most likely an anti-pattern for the trace_tag, but it is an unknown.
saved = TraceDBTest._create_save_trace(
- trace_tag='test_trace',
+ trace_tag="test_trace",
action_executions=[str(bson.ObjectId()) for _ in range(2)],
rules=[str(bson.ObjectId()) for _ in range(4)],
- trigger_instances=[str(bson.ObjectId()) for _ in range(3)])
+ trigger_instances=[str(bson.ObjectId()) for _ in range(3)],
+ )
retrieved = Trace.query(trace_tag=saved.trace_tag)
- self.assertEqual(len(retrieved), 2, 'Should have 2 traces.')
+ self.assertEqual(len(retrieved), 2, "Should have 2 traces.")
def test_update(self):
saved = TraceDBTest._create_save_trace(
- trace_tag='test_trace',
- action_executions=[],
- rules=[],
- trigger_instances=[])
+ trace_tag="test_trace", action_executions=[], rules=[], trigger_instances=[]
+ )
retrieved = Trace.query(trace_tag=saved.trace_tag)
- self.assertEqual(len(retrieved), 1, 'Should have 1 trace.')
- self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.')
+ self.assertEqual(len(retrieved), 1, "Should have 1 trace.")
+ self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.")
no_action_executions = 4
no_rules = 4
no_trigger_instances = 5
saved = TraceDBTest._create_save_trace(
- trace_tag='test_trace',
+ trace_tag="test_trace",
id_=retrieved[0].id,
- action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)],
+ action_executions=[
+ str(bson.ObjectId()) for _ in range(no_action_executions)
+ ],
rules=[str(bson.ObjectId()) for _ in range(no_rules)],
- trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)])
+ trigger_instances=[
+ str(bson.ObjectId()) for _ in range(no_trigger_instances)
+ ],
+ )
retrieved = Trace.query(trace_tag=saved.trace_tag)
- self.assertEqual(len(retrieved), 1, 'Should have 1 trace.')
- self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.')
+ self.assertEqual(len(retrieved), 1, "Should have 1 trace.")
+ self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.")
# validate update
- self.assertEqual(len(retrieved[0].action_executions), no_action_executions,
- 'Failed to update action_executions.')
- self.assertEqual(len(retrieved[0].rules), no_rules, 'Failed to update rules.')
- self.assertEqual(len(retrieved[0].trigger_instances), no_trigger_instances,
- 'Failed to update trigger_instances.')
+ self.assertEqual(
+ len(retrieved[0].action_executions),
+ no_action_executions,
+ "Failed to update action_executions.",
+ )
+ self.assertEqual(len(retrieved[0].rules), no_rules, "Failed to update rules.")
+ self.assertEqual(
+ len(retrieved[0].trigger_instances),
+ no_trigger_instances,
+ "Failed to update trigger_instances.",
+ )
def test_update_via_list_push(self):
no_action_executions = 4
no_rules = 4
no_trigger_instances = 5
saved = TraceDBTest._create_save_trace(
- trace_tag='test_trace',
- action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)],
+ trace_tag="test_trace",
+ action_executions=[
+ str(bson.ObjectId()) for _ in range(no_action_executions)
+ ],
rules=[str(bson.ObjectId()) for _ in range(no_rules)],
- trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)])
+ trigger_instances=[
+ str(bson.ObjectId()) for _ in range(no_trigger_instances)
+ ],
+ )
# push updates
Trace.push_action_execution(
- saved, action_execution=TraceComponentDB(object_id=str(bson.ObjectId())))
+ saved, action_execution=TraceComponentDB(object_id=str(bson.ObjectId()))
+ )
Trace.push_rule(saved, rule=TraceComponentDB(object_id=str(bson.ObjectId())))
Trace.push_trigger_instance(
- saved, trigger_instance=TraceComponentDB(object_id=str(bson.ObjectId())))
+ saved, trigger_instance=TraceComponentDB(object_id=str(bson.ObjectId()))
+ )
retrieved = Trace.get(id=saved.id)
- self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.')
+ self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.")
self.assertEqual(len(retrieved.action_executions), no_action_executions + 1)
self.assertEqual(len(retrieved.rules), no_rules + 1)
self.assertEqual(len(retrieved.trigger_instances), no_trigger_instances + 1)
@@ -112,33 +130,48 @@ def test_update_via_list_push_components(self):
no_rules = 4
no_trigger_instances = 5
saved = TraceDBTest._create_save_trace(
- trace_tag='test_trace',
- action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)],
+ trace_tag="test_trace",
+ action_executions=[
+ str(bson.ObjectId()) for _ in range(no_action_executions)
+ ],
rules=[str(bson.ObjectId()) for _ in range(no_rules)],
- trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)])
+ trigger_instances=[
+ str(bson.ObjectId()) for _ in range(no_trigger_instances)
+ ],
+ )
retrieved = Trace.push_components(
saved,
- action_executions=[TraceComponentDB(object_id=str(bson.ObjectId()))
- for _ in range(no_action_executions)],
- rules=[TraceComponentDB(object_id=str(bson.ObjectId()))
- for _ in range(no_rules)],
- trigger_instances=[TraceComponentDB(object_id=str(bson.ObjectId()))
- for _ in range(no_trigger_instances)])
-
- self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.')
+ action_executions=[
+ TraceComponentDB(object_id=str(bson.ObjectId()))
+ for _ in range(no_action_executions)
+ ],
+ rules=[
+ TraceComponentDB(object_id=str(bson.ObjectId()))
+ for _ in range(no_rules)
+ ],
+ trigger_instances=[
+ TraceComponentDB(object_id=str(bson.ObjectId()))
+ for _ in range(no_trigger_instances)
+ ],
+ )
+
+ self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.")
self.assertEqual(len(retrieved.action_executions), no_action_executions * 2)
self.assertEqual(len(retrieved.rules), no_rules * 2)
self.assertEqual(len(retrieved.trigger_instances), no_trigger_instances * 2)
@staticmethod
- def _create_save_trace(trace_tag, id_=None, action_executions=None, rules=None,
- trigger_instances=None):
+ def _create_save_trace(
+ trace_tag, id_=None, action_executions=None, rules=None, trigger_instances=None
+ ):
if action_executions is None:
action_executions = []
- action_executions = [TraceComponentDB(object_id=action_execution)
- for action_execution in action_executions]
+ action_executions = [
+ TraceComponentDB(object_id=action_execution)
+ for action_execution in action_executions
+ ]
if rules is None:
rules = []
@@ -146,12 +179,16 @@ def _create_save_trace(trace_tag, id_=None, action_executions=None, rules=None,
if trigger_instances is None:
trigger_instances = []
- trigger_instances = [TraceComponentDB(object_id=trigger_instance)
- for trigger_instance in trigger_instances]
-
- created = TraceDB(id=id_,
- trace_tag=trace_tag,
- trigger_instances=trigger_instances,
- rules=rules,
- action_executions=action_executions)
+ trigger_instances = [
+ TraceComponentDB(object_id=trigger_instance)
+ for trigger_instance in trigger_instances
+ ]
+
+ created = TraceDB(
+ id=id_,
+ trace_tag=trace_tag,
+ trigger_instances=trigger_instances,
+ rules=rules,
+ action_executions=action_executions,
+ )
return Trace.add_or_update(created)
diff --git a/st2common/tests/unit/test_db_uid_mixin.py b/st2common/tests/unit/test_db_uid_mixin.py
index e3283e6f91..b7a6a25108 100644
--- a/st2common/tests/unit/test_db_uid_mixin.py
+++ b/st2common/tests/unit/test_db_uid_mixin.py
@@ -23,28 +23,41 @@
class UIDMixinTestCase(CleanDbTestCase):
def test_get_uid(self):
- pack_1_db = PackDB(ref='test_pack')
- pack_2_db = PackDB(ref='examples')
+ pack_1_db = PackDB(ref="test_pack")
+ pack_2_db = PackDB(ref="examples")
- self.assertEqual(pack_1_db.get_uid(), 'pack:test_pack')
- self.assertEqual(pack_2_db.get_uid(), 'pack:examples')
+ self.assertEqual(pack_1_db.get_uid(), "pack:test_pack")
+ self.assertEqual(pack_2_db.get_uid(), "pack:examples")
- action_1_db = ActionDB(pack='examples', name='my_action', ref='examples.my_action')
- action_2_db = ActionDB(pack='core', name='local', ref='core.local')
- self.assertEqual(action_1_db.get_uid(), 'action:examples:my_action')
- self.assertEqual(action_2_db.get_uid(), 'action:core:local')
+ action_1_db = ActionDB(
+ pack="examples", name="my_action", ref="examples.my_action"
+ )
+ action_2_db = ActionDB(pack="core", name="local", ref="core.local")
+ self.assertEqual(action_1_db.get_uid(), "action:examples:my_action")
+ self.assertEqual(action_2_db.get_uid(), "action:core:local")
def test_uid_is_populated_on_save(self):
- pack_1_db = PackDB(ref='test_pack', name='test', description='foo', version='1.0.0',
- author='dev', email='test@example.com')
+ pack_1_db = PackDB(
+ ref="test_pack",
+ name="test",
+ description="foo",
+ version="1.0.0",
+ author="dev",
+ email="test@example.com",
+ )
pack_1_db = Pack.add_or_update(pack_1_db)
pack_1_db.reload()
- self.assertEqual(pack_1_db.uid, 'pack:test_pack')
+ self.assertEqual(pack_1_db.uid, "pack:test_pack")
- action_1_db = ActionDB(name='local', pack='core', ref='core.local', entry_point='',
- runner_type={'name': 'local-shell-cmd'})
+ action_1_db = ActionDB(
+ name="local",
+ pack="core",
+ ref="core.local",
+ entry_point="",
+ runner_type={"name": "local-shell-cmd"},
+ )
action_1_db = Action.add_or_update(action_1_db)
action_1_db.reload()
- self.assertEqual(action_1_db.uid, 'action:core:local')
+ self.assertEqual(action_1_db.uid, "action:core:local")
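
The UID values asserted above follow a simple scheme: the resource type plus its identifying parts, joined with ":". A toy illustration of that contract (not the actual st2common UIDFieldMixin implementation):

    def pack_uid(ref):
        # Packs are identified by their ref alone.
        return ":".join(["pack", ref])

    def action_uid(pack, name):
        # Actions are scoped by pack, then name.
        return ":".join(["action", pack, name])

    assert pack_uid("test_pack") == "pack:test_pack"
    assert action_uid("core", "local") == "action:core:local"
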
diff --git a/st2common/tests/unit/test_db_workflow.py b/st2common/tests/unit/test_db_workflow.py
index 1f7ce38a4a..e434d0f9d6 100644
--- a/st2common/tests/unit/test_db_workflow.py
+++ b/st2common/tests/unit/test_db_workflow.py
@@ -26,14 +26,13 @@
from st2common.exceptions import db as db_exc
-@mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock())
class WorkflowExecutionModelTest(st2tests.DbTestCase):
-
def test_workflow_execution_crud(self):
initial = wf_db_models.WorkflowExecutionDB()
initial.action_execution = uuid.uuid4().hex
- initial.graph = {'var1': 'foobar'}
- initial.status = 'requested'
+ initial.graph = {"var1": "foobar"}
+ initial.status = "requested"
# Test create
created = wf_db_access.WorkflowExecution.add_or_update(initial)
@@ -47,9 +46,11 @@ def test_workflow_execution_crud(self):
self.assertEqual(created.status, retrieved.status)
# Test update
- graph = {'var1': 'fubar'}
- status = 'running'
- retrieved = wf_db_access.WorkflowExecution.update(retrieved, graph=graph, status=status)
+ graph = {"var1": "fubar"}
+ status = "running"
+ retrieved = wf_db_access.WorkflowExecution.update(
+ retrieved, graph=graph, status=status
+ )
updated = wf_db_access.WorkflowExecution.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
self.assertEqual(retrieved.rev, updated.rev)
@@ -58,7 +59,7 @@ def test_workflow_execution_crud(self):
self.assertEqual(retrieved.status, updated.status)
# Test add or update
- retrieved.graph = {'var2': 'fubar'}
+ retrieved.graph = {"var2": "fubar"}
retrieved = wf_db_access.WorkflowExecution.add_or_update(retrieved)
updated = wf_db_access.WorkflowExecution.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
@@ -73,14 +74,14 @@ def test_workflow_execution_crud(self):
self.assertRaises(
db_exc.StackStormDBObjectNotFoundError,
wf_db_access.WorkflowExecution.get_by_id,
- doc_id
+ doc_id,
)
def test_workflow_execution_write_conflict(self):
initial = wf_db_models.WorkflowExecutionDB()
initial.action_execution = uuid.uuid4().hex
- initial.graph = {'var1': 'foobar'}
- initial.status = 'requested'
+ initial.graph = {"var1": "foobar"}
+ initial.status = "requested"
# Prep record
created = wf_db_access.WorkflowExecution.add_or_update(initial)
@@ -92,9 +93,11 @@ def test_workflow_execution_write_conflict(self):
retrieved2 = wf_db_access.WorkflowExecution.get_by_id(doc_id)
# Test update on instance 1, expect success
- graph = {'var1': 'fubar'}
- status = 'running'
- retrieved1 = wf_db_access.WorkflowExecution.update(retrieved1, graph=graph, status=status)
+ graph = {"var1": "fubar"}
+ status = "running"
+ retrieved1 = wf_db_access.WorkflowExecution.update(
+ retrieved1, graph=graph, status=status
+ )
updated = wf_db_access.WorkflowExecution.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
self.assertEqual(retrieved1.rev, updated.rev)
@@ -107,7 +110,7 @@ def test_workflow_execution_write_conflict(self):
db_exc.StackStormDBObjectWriteConflictError,
wf_db_access.WorkflowExecution.update,
retrieved2,
- graph={'var2': 'fubar'}
+ graph={"var2": "fubar"},
)
# Test delete
@@ -116,5 +119,5 @@ def test_workflow_execution_write_conflict(self):
self.assertRaises(
db_exc.StackStormDBObjectNotFoundError,
wf_db_access.WorkflowExecution.get_by_id,
- doc_id
+ doc_id,
)
diff --git a/st2common/tests/unit/test_dist_utils.py b/st2common/tests/unit/test_dist_utils.py
index 901f8abd44..1b01d4ff48 100644
--- a/st2common/tests/unit/test_dist_utils.py
+++ b/st2common/tests/unit/test_dist_utils.py
@@ -21,7 +21,7 @@
import unittest2
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-SCRIPTS_PATH = os.path.join(BASE_DIR, '../../../scripts/')
+SCRIPTS_PATH = os.path.join(BASE_DIR, "../../../scripts/")
 # Add scripts/, which contains the main dist_utils.py, to PYTHONPATH
sys.path.insert(0, SCRIPTS_PATH)
@@ -32,21 +32,21 @@
from dist_utils import apply_vagrant_workaround
from dist_utils import get_version_string
-__all__ = [
- 'DistUtilsTestCase'
-]
+__all__ = ["DistUtilsTestCase"]
-REQUIREMENTS_PATH_1 = os.path.join(BASE_DIR, '../fixtures/requirements-used-for-tests.txt')
-REQUIREMENTS_PATH_2 = os.path.join(BASE_DIR, '../../../requirements.txt')
-VERSION_FILE_PATH = os.path.join(BASE_DIR, '../fixtures/version_file.py')
+REQUIREMENTS_PATH_1 = os.path.join(
+ BASE_DIR, "../fixtures/requirements-used-for-tests.txt"
+)
+REQUIREMENTS_PATH_2 = os.path.join(BASE_DIR, "../../../requirements.txt")
+VERSION_FILE_PATH = os.path.join(BASE_DIR, "../fixtures/version_file.py")
class DistUtilsTestCase(unittest2.TestCase):
def setUp(self):
super(DistUtilsTestCase, self).setUp()
- if 'pip' in sys.modules:
- del sys.modules['pip']
+ if "pip" in sys.modules:
+ del sys.modules["pip"]
def tearDown(self):
super(DistUtilsTestCase, self).tearDown()
@@ -54,15 +54,15 @@ def tearDown(self):
def test_check_pip_is_installed_success(self):
self.assertTrue(check_pip_is_installed())
- @mock.patch('sys.exit')
+ @mock.patch("sys.exit")
def test_check_pip_is_installed_failure(self, mock_sys_exit):
if six.PY3:
- module_name = 'builtins.__import__'
+ module_name = "builtins.__import__"
else:
- module_name = '__builtin__.__import__'
+ module_name = "__builtin__.__import__"
with mock.patch(module_name) as mock_import:
- mock_import.side_effect = ImportError('not found')
+ mock_import.side_effect = ImportError("not found")
self.assertEqual(mock_sys_exit.call_count, 0)
check_pip_is_installed()
@@ -72,12 +72,12 @@ def test_check_pip_is_installed_failure(self, mock_sys_exit):
def test_check_pip_version_success(self):
self.assertTrue(check_pip_version())
- @mock.patch('sys.exit')
+ @mock.patch("sys.exit")
def test_check_pip_version_failure(self, mock_sys_exit):
mock_pip = mock.Mock()
- mock_pip.__version__ = '0.0.0'
- sys.modules['pip'] = mock_pip
+ mock_pip.__version__ = "0.0.0"
+ sys.modules["pip"] = mock_pip
self.assertEqual(mock_sys_exit.call_count, 0)
check_pip_version()
@@ -86,50 +86,50 @@ def test_check_pip_version_failure(self, mock_sys_exit):
def test_get_version_string(self):
version = get_version_string(VERSION_FILE_PATH)
- self.assertEqual(version, '1.2.3')
+ self.assertEqual(version, "1.2.3")
def test_apply_vagrant_workaround(self):
- with mock.patch('os.link') as _:
- os.environ['USER'] = 'stanley'
+ with mock.patch("os.link") as _:
+ os.environ["USER"] = "stanley"
apply_vagrant_workaround()
self.assertTrue(os.link)
- with mock.patch('os.link') as _:
- os.environ['USER'] = 'vagrant'
+ with mock.patch("os.link") as _:
+ os.environ["USER"] = "vagrant"
apply_vagrant_workaround()
- self.assertFalse(getattr(os, 'link', None))
+ self.assertFalse(getattr(os, "link", None))
def test_fetch_requirements(self):
expected_reqs = [
- 'RandomWords',
- 'amqp==2.5.1',
- 'argcomplete',
- 'bcrypt==3.1.6',
- 'flex==6.14.0',
- 'logshipper',
- 'orquesta',
- 'st2-auth-backend-flat-file',
- 'logshipper-editable',
- 'python_runner',
- 'SomePackageHq',
- 'SomePackageSvn',
- 'gitpython==2.1.11',
- 'ose-timer==0.7.5',
- 'oslo.config<1.13,>=1.12.1',
- 'requests[security]<2.22.0,>=2.21.0',
- 'retrying==1.3.3',
- 'zake==0.2.2'
+ "RandomWords",
+ "amqp==2.5.1",
+ "argcomplete",
+ "bcrypt==3.1.6",
+ "flex==6.14.0",
+ "logshipper",
+ "orquesta",
+ "st2-auth-backend-flat-file",
+ "logshipper-editable",
+ "python_runner",
+ "SomePackageHq",
+ "SomePackageSvn",
+ "gitpython==2.1.11",
+ "ose-timer==0.7.5",
+ "oslo.config<1.13,>=1.12.1",
+ "requests[security]<2.22.0,>=2.21.0",
+ "retrying==1.3.3",
+ "zake==0.2.2",
]
expected_links = [
- 'git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper',
- 'git+https://github.com/StackStorm/orquesta.git@224c1a589a6007eb0598a62ee99d674e7836d369#egg=orquesta', # NOQA
- 'git+https://github.com/StackStorm/st2-auth-backend-flat-file.git@master#egg=st2-auth-backend-flat-file', # NOQA
- 'git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper-editable',
- 'git+https://github.com/StackStorm/st2.git#egg=python_runner&subdirectory=contrib/runners/python_runner', # NOQA
- 'hg+https://hg.repo/some_pkg.git#egg=SomePackageHq',
- 'svn+svn://svn.repo/some_pkg/trunk/@ma-branch#egg=SomePackageSvn'
+ "git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper",
+ "git+https://github.com/StackStorm/orquesta.git@224c1a589a6007eb0598a62ee99d674e7836d369#egg=orquesta", # NOQA
+ "git+https://github.com/StackStorm/st2-auth-backend-flat-file.git@master#egg=st2-auth-backend-flat-file", # NOQA
+ "git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper-editable",
+ "git+https://github.com/StackStorm/st2.git#egg=python_runner&subdirectory=contrib/runners/python_runner", # NOQA
+ "hg+https://hg.repo/some_pkg.git#egg=SomePackageHq",
+ "svn+svn://svn.repo/some_pkg/trunk/@ma-branch#egg=SomePackageSvn",
]
reqs, links = fetch_requirements(REQUIREMENTS_PATH_1)
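
As the expected_reqs/expected_links fixtures suggest, fetch_requirements() splits a requirements file into plain requirement specifiers plus VCS links, registering each link's #egg= name as a requirement. A simplified standalone sketch of that split — an illustration, not the scripts/dist_utils.py implementation:

    import re

    def split_requirements(lines):
        reqs, links = [], []
        for line in lines:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            if re.match(r"^(git|hg|svn|bzr)\+", line) or line.startswith("-e "):
                # VCS/editable entry: record the link, and the egg name as a req.
                link = line[3:] if line.startswith("-e ") else line
                links.append(link)
                egg = re.search(r"#egg=([^&]+)", link)
                if egg:
                    reqs.append(egg.group(1))
            else:
                reqs.append(line)
        return reqs, links

    reqs, links = split_requirements([
        "amqp==2.5.1",
        "git+https://example.com/some_pkg.git#egg=SomePackage",
    ])
    assert "SomePackage" in reqs and len(links) == 1
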
diff --git a/st2common/tests/unit/test_exceptions_workflow.py b/st2common/tests/unit/test_exceptions_workflow.py
index 9e37f6c5d9..a9fbcc549f 100644
--- a/st2common/tests/unit/test_exceptions_workflow.py
+++ b/st2common/tests/unit/test_exceptions_workflow.py
@@ -26,7 +26,6 @@
class WorkflowExceptionTest(unittest2.TestCase):
-
def test_retry_on_transient_db_errors(self):
instance = wf_db_models.WorkflowExecutionDB()
exc = db_exc.StackStormDBObjectWriteConflictError(instance)
@@ -34,13 +33,13 @@ def test_retry_on_transient_db_errors(self):
def test_do_not_retry_on_transient_db_errors(self):
instance = wf_db_models.WorkflowExecutionDB()
- exc = db_exc.StackStormDBObjectConflictError('foobar', '1234', instance)
+ exc = db_exc.StackStormDBObjectConflictError("foobar", "1234", instance)
self.assertFalse(wf_exc.retry_on_transient_db_errors(exc))
self.assertFalse(wf_exc.retry_on_transient_db_errors(NotImplementedError()))
self.assertFalse(wf_exc.retry_on_transient_db_errors(Exception()))
def test_retry_on_connection_errors(self):
- exc = coordination.ToozConnectionError('foobar')
+ exc = coordination.ToozConnectionError("foobar")
self.assertTrue(wf_exc.retry_on_connection_errors(exc))
exc = mongoengine.connection.MongoEngineConnectionError()
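
These predicates are typically wired into a retry decorator (the requirements fixture above pins retrying==1.3.3) so that only transient failures trigger a retry. A hedged sketch of that wiring, with a stand-in exception type in place of StackStormDBObjectWriteConflictError:

    import retrying  # assumption: the `retrying` package pinned in requirements above

    class WriteConflictError(Exception):
        """Stand-in for StackStormDBObjectWriteConflictError."""

    def retry_on_transient_db_errors(exc):
        # Retry only transient write conflicts, never arbitrary errors.
        return isinstance(exc, WriteConflictError)

    attempts = {"n": 0}

    @retrying.retry(retry_on_exception=retry_on_transient_db_errors,
                    stop_max_attempt_number=3)
    def update_record():
        attempts["n"] += 1
        if attempts["n"] < 2:
            raise WriteConflictError("write conflict")
        return "ok"

    assert update_record() == "ok" and attempts["n"] == 2
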
diff --git a/st2common/tests/unit/test_executions.py b/st2common/tests/unit/test_executions.py
index 59353379ac..0be1ca7c9d 100644
--- a/st2common/tests/unit/test_executions.py
+++ b/st2common/tests/unit/test_executions.py
@@ -29,94 +29,117 @@
class TestActionExecutionHistoryModel(DbTestCase):
-
def setUp(self):
super(TestActionExecutionHistoryModel, self).setUp()
# Fake execution record for action liveactions triggered by workflow runner.
self.fake_history_subtasks = [
{
- 'id': str(bson.ObjectId()),
- 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']),
- 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']),
- 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task1']),
- 'status': fixture.ARTIFACTS['liveactions']['task1']['status'],
- 'start_timestamp': fixture.ARTIFACTS['liveactions']['task1']['start_timestamp'],
- 'end_timestamp': fixture.ARTIFACTS['liveactions']['task1']['end_timestamp']
+ "id": str(bson.ObjectId()),
+ "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]),
+ "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]),
+ "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task1"]),
+ "status": fixture.ARTIFACTS["liveactions"]["task1"]["status"],
+ "start_timestamp": fixture.ARTIFACTS["liveactions"]["task1"][
+ "start_timestamp"
+ ],
+ "end_timestamp": fixture.ARTIFACTS["liveactions"]["task1"][
+ "end_timestamp"
+ ],
},
{
- 'id': str(bson.ObjectId()),
- 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']),
- 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']),
- 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task2']),
- 'status': fixture.ARTIFACTS['liveactions']['task2']['status'],
- 'start_timestamp': fixture.ARTIFACTS['liveactions']['task2']['start_timestamp'],
- 'end_timestamp': fixture.ARTIFACTS['liveactions']['task2']['end_timestamp']
- }
+ "id": str(bson.ObjectId()),
+ "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]),
+ "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]),
+ "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task2"]),
+ "status": fixture.ARTIFACTS["liveactions"]["task2"]["status"],
+ "start_timestamp": fixture.ARTIFACTS["liveactions"]["task2"][
+ "start_timestamp"
+ ],
+ "end_timestamp": fixture.ARTIFACTS["liveactions"]["task2"][
+ "end_timestamp"
+ ],
+ },
]
# Fake execution record for a workflow action execution triggered by rule.
self.fake_history_workflow = {
- 'id': str(bson.ObjectId()),
- 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']),
- 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']),
- 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance']),
- 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']),
- 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['chain']),
- 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['action-chain']),
- 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['workflow']),
- 'children': [task['id'] for task in self.fake_history_subtasks],
- 'status': fixture.ARTIFACTS['liveactions']['workflow']['status'],
- 'start_timestamp': fixture.ARTIFACTS['liveactions']['workflow']['start_timestamp'],
- 'end_timestamp': fixture.ARTIFACTS['liveactions']['workflow']['end_timestamp']
+ "id": str(bson.ObjectId()),
+ "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]),
+ "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]),
+ "trigger_instance": copy.deepcopy(fixture.ARTIFACTS["trigger_instance"]),
+ "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]),
+ "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"]),
+ "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["action-chain"]),
+ "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["workflow"]),
+ "children": [task["id"] for task in self.fake_history_subtasks],
+ "status": fixture.ARTIFACTS["liveactions"]["workflow"]["status"],
+ "start_timestamp": fixture.ARTIFACTS["liveactions"]["workflow"][
+ "start_timestamp"
+ ],
+ "end_timestamp": fixture.ARTIFACTS["liveactions"]["workflow"][
+ "end_timestamp"
+ ],
}
# Assign parent to the execution records for the subtasks.
for task in self.fake_history_subtasks:
- task['parent'] = self.fake_history_workflow['id']
+ task["parent"] = self.fake_history_workflow["id"]
def test_model_complete(self):
# Create API object.
obj = ActionExecutionAPI(**copy.deepcopy(self.fake_history_workflow))
- self.assertDictEqual(obj.trigger, self.fake_history_workflow['trigger'])
- self.assertDictEqual(obj.trigger_type, self.fake_history_workflow['trigger_type'])
- self.assertDictEqual(obj.trigger_instance, self.fake_history_workflow['trigger_instance'])
- self.assertDictEqual(obj.rule, self.fake_history_workflow['rule'])
- self.assertDictEqual(obj.action, self.fake_history_workflow['action'])
- self.assertDictEqual(obj.runner, self.fake_history_workflow['runner'])
- self.assertEqual(obj.liveaction, self.fake_history_workflow['liveaction'])
- self.assertIsNone(getattr(obj, 'parent', None))
- self.assertListEqual(obj.children, self.fake_history_workflow['children'])
+ self.assertDictEqual(obj.trigger, self.fake_history_workflow["trigger"])
+ self.assertDictEqual(
+ obj.trigger_type, self.fake_history_workflow["trigger_type"]
+ )
+ self.assertDictEqual(
+ obj.trigger_instance, self.fake_history_workflow["trigger_instance"]
+ )
+ self.assertDictEqual(obj.rule, self.fake_history_workflow["rule"])
+ self.assertDictEqual(obj.action, self.fake_history_workflow["action"])
+ self.assertDictEqual(obj.runner, self.fake_history_workflow["runner"])
+ self.assertEqual(obj.liveaction, self.fake_history_workflow["liveaction"])
+ self.assertIsNone(getattr(obj, "parent", None))
+ self.assertListEqual(obj.children, self.fake_history_workflow["children"])
# Convert API object to DB model.
model = ActionExecutionAPI.to_model(obj)
self.assertEqual(str(model.id), obj.id)
- self.assertDictEqual(model.trigger, self.fake_history_workflow['trigger'])
- self.assertDictEqual(model.trigger_type, self.fake_history_workflow['trigger_type'])
- self.assertDictEqual(model.trigger_instance, self.fake_history_workflow['trigger_instance'])
- self.assertDictEqual(model.rule, self.fake_history_workflow['rule'])
- self.assertDictEqual(model.action, self.fake_history_workflow['action'])
- self.assertDictEqual(model.runner, self.fake_history_workflow['runner'])
- doc = copy.deepcopy(self.fake_history_workflow['liveaction'])
- doc['start_timestamp'] = doc['start_timestamp']
- doc['end_timestamp'] = doc['end_timestamp']
+ self.assertDictEqual(model.trigger, self.fake_history_workflow["trigger"])
+ self.assertDictEqual(
+ model.trigger_type, self.fake_history_workflow["trigger_type"]
+ )
+ self.assertDictEqual(
+ model.trigger_instance, self.fake_history_workflow["trigger_instance"]
+ )
+ self.assertDictEqual(model.rule, self.fake_history_workflow["rule"])
+ self.assertDictEqual(model.action, self.fake_history_workflow["action"])
+ self.assertDictEqual(model.runner, self.fake_history_workflow["runner"])
+ doc = copy.deepcopy(self.fake_history_workflow["liveaction"])
+ doc["start_timestamp"] = doc["start_timestamp"]
+ doc["end_timestamp"] = doc["end_timestamp"]
self.assertDictEqual(model.liveaction, doc)
- self.assertIsNone(getattr(model, 'parent', None))
- self.assertListEqual(model.children, self.fake_history_workflow['children'])
+ self.assertIsNone(getattr(model, "parent", None))
+ self.assertListEqual(model.children, self.fake_history_workflow["children"])
# Convert DB model to API object.
obj = ActionExecutionAPI.from_model(model)
self.assertEqual(str(model.id), obj.id)
- self.assertDictEqual(obj.trigger, self.fake_history_workflow['trigger'])
- self.assertDictEqual(obj.trigger_type, self.fake_history_workflow['trigger_type'])
- self.assertDictEqual(obj.trigger_instance, self.fake_history_workflow['trigger_instance'])
- self.assertDictEqual(obj.rule, self.fake_history_workflow['rule'])
- self.assertDictEqual(obj.action, self.fake_history_workflow['action'])
- self.assertDictEqual(obj.runner, self.fake_history_workflow['runner'])
- self.assertDictEqual(obj.liveaction, self.fake_history_workflow['liveaction'])
- self.assertIsNone(getattr(obj, 'parent', None))
- self.assertListEqual(obj.children, self.fake_history_workflow['children'])
+ self.assertDictEqual(obj.trigger, self.fake_history_workflow["trigger"])
+ self.assertDictEqual(
+ obj.trigger_type, self.fake_history_workflow["trigger_type"]
+ )
+ self.assertDictEqual(
+ obj.trigger_instance, self.fake_history_workflow["trigger_instance"]
+ )
+ self.assertDictEqual(obj.rule, self.fake_history_workflow["rule"])
+ self.assertDictEqual(obj.action, self.fake_history_workflow["action"])
+ self.assertDictEqual(obj.runner, self.fake_history_workflow["runner"])
+ self.assertDictEqual(obj.liveaction, self.fake_history_workflow["liveaction"])
+ self.assertIsNone(getattr(obj, "parent", None))
+ self.assertListEqual(obj.children, self.fake_history_workflow["children"])
def test_crud_complete(self):
# Create the DB record.
@@ -124,18 +147,22 @@ def test_crud_complete(self):
ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj))
model = ActionExecution.get_by_id(obj.id)
self.assertEqual(str(model.id), obj.id)
- self.assertDictEqual(model.trigger, self.fake_history_workflow['trigger'])
- self.assertDictEqual(model.trigger_type, self.fake_history_workflow['trigger_type'])
- self.assertDictEqual(model.trigger_instance, self.fake_history_workflow['trigger_instance'])
- self.assertDictEqual(model.rule, self.fake_history_workflow['rule'])
- self.assertDictEqual(model.action, self.fake_history_workflow['action'])
- self.assertDictEqual(model.runner, self.fake_history_workflow['runner'])
- doc = copy.deepcopy(self.fake_history_workflow['liveaction'])
- doc['start_timestamp'] = doc['start_timestamp']
- doc['end_timestamp'] = doc['end_timestamp']
+ self.assertDictEqual(model.trigger, self.fake_history_workflow["trigger"])
+ self.assertDictEqual(
+ model.trigger_type, self.fake_history_workflow["trigger_type"]
+ )
+ self.assertDictEqual(
+ model.trigger_instance, self.fake_history_workflow["trigger_instance"]
+ )
+ self.assertDictEqual(model.rule, self.fake_history_workflow["rule"])
+ self.assertDictEqual(model.action, self.fake_history_workflow["action"])
+ self.assertDictEqual(model.runner, self.fake_history_workflow["runner"])
+ doc = copy.deepcopy(self.fake_history_workflow["liveaction"])
+ doc["start_timestamp"] = doc["start_timestamp"]
+ doc["end_timestamp"] = doc["end_timestamp"]
self.assertDictEqual(model.liveaction, doc)
- self.assertIsNone(getattr(model, 'parent', None))
- self.assertListEqual(model.children, self.fake_history_workflow['children'])
+ self.assertIsNone(getattr(model, "parent", None))
+ self.assertListEqual(model.children, self.fake_history_workflow["children"])
# Update the DB record.
children = [str(bson.ObjectId()), str(bson.ObjectId())]
@@ -146,20 +173,24 @@ def test_crud_complete(self):
# Delete the DB record.
ActionExecution.delete(model)
- self.assertRaises(StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id)
+ self.assertRaises(
+ StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id
+ )
def test_model_partial(self):
# Create API object.
obj = ActionExecutionAPI(**copy.deepcopy(self.fake_history_subtasks[0]))
- self.assertIsNone(getattr(obj, 'trigger', None))
- self.assertIsNone(getattr(obj, 'trigger_type', None))
- self.assertIsNone(getattr(obj, 'trigger_instance', None))
- self.assertIsNone(getattr(obj, 'rule', None))
- self.assertDictEqual(obj.action, self.fake_history_subtasks[0]['action'])
- self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]['runner'])
- self.assertDictEqual(obj.liveaction, self.fake_history_subtasks[0]['liveaction'])
- self.assertEqual(obj.parent, self.fake_history_subtasks[0]['parent'])
- self.assertIsNone(getattr(obj, 'children', None))
+ self.assertIsNone(getattr(obj, "trigger", None))
+ self.assertIsNone(getattr(obj, "trigger_type", None))
+ self.assertIsNone(getattr(obj, "trigger_instance", None))
+ self.assertIsNone(getattr(obj, "rule", None))
+ self.assertDictEqual(obj.action, self.fake_history_subtasks[0]["action"])
+ self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]["runner"])
+ self.assertDictEqual(
+ obj.liveaction, self.fake_history_subtasks[0]["liveaction"]
+ )
+ self.assertEqual(obj.parent, self.fake_history_subtasks[0]["parent"])
+ self.assertIsNone(getattr(obj, "children", None))
# Convert API object to DB model.
model = ActionExecutionAPI.to_model(obj)
@@ -168,28 +199,30 @@ def test_model_partial(self):
self.assertDictEqual(model.trigger_type, {})
self.assertDictEqual(model.trigger_instance, {})
self.assertDictEqual(model.rule, {})
- self.assertDictEqual(model.action, self.fake_history_subtasks[0]['action'])
- self.assertDictEqual(model.runner, self.fake_history_subtasks[0]['runner'])
- doc = copy.deepcopy(self.fake_history_subtasks[0]['liveaction'])
- doc['start_timestamp'] = doc['start_timestamp']
- doc['end_timestamp'] = doc['end_timestamp']
+ self.assertDictEqual(model.action, self.fake_history_subtasks[0]["action"])
+ self.assertDictEqual(model.runner, self.fake_history_subtasks[0]["runner"])
+ doc = copy.deepcopy(self.fake_history_subtasks[0]["liveaction"])
+ doc["start_timestamp"] = doc["start_timestamp"]
+ doc["end_timestamp"] = doc["end_timestamp"]
self.assertDictEqual(model.liveaction, doc)
- self.assertEqual(model.parent, self.fake_history_subtasks[0]['parent'])
+ self.assertEqual(model.parent, self.fake_history_subtasks[0]["parent"])
self.assertListEqual(model.children, [])
# Convert DB model to API object.
obj = ActionExecutionAPI.from_model(model)
self.assertEqual(str(model.id), obj.id)
- self.assertIsNone(getattr(obj, 'trigger', None))
- self.assertIsNone(getattr(obj, 'trigger_type', None))
- self.assertIsNone(getattr(obj, 'trigger_instance', None))
- self.assertIsNone(getattr(obj, 'rule', None))
- self.assertDictEqual(obj.action, self.fake_history_subtasks[0]['action'])
- self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]['runner'])
- self.assertDictEqual(obj.liveaction, self.fake_history_subtasks[0]['liveaction'])
- self.assertEqual(obj.parent, self.fake_history_subtasks[0]['parent'])
- self.assertIsNone(getattr(obj, 'children', None))
+ self.assertIsNone(getattr(obj, "trigger", None))
+ self.assertIsNone(getattr(obj, "trigger_type", None))
+ self.assertIsNone(getattr(obj, "trigger_instance", None))
+ self.assertIsNone(getattr(obj, "rule", None))
+ self.assertDictEqual(obj.action, self.fake_history_subtasks[0]["action"])
+ self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]["runner"])
+ self.assertDictEqual(
+ obj.liveaction, self.fake_history_subtasks[0]["liveaction"]
+ )
+ self.assertEqual(obj.parent, self.fake_history_subtasks[0]["parent"])
+ self.assertIsNone(getattr(obj, "children", None))
def test_crud_partial(self):
# Create the DB record.
@@ -201,13 +234,13 @@ def test_crud_partial(self):
self.assertDictEqual(model.trigger_type, {})
self.assertDictEqual(model.trigger_instance, {})
self.assertDictEqual(model.rule, {})
- self.assertDictEqual(model.action, self.fake_history_subtasks[0]['action'])
- self.assertDictEqual(model.runner, self.fake_history_subtasks[0]['runner'])
- doc = copy.deepcopy(self.fake_history_subtasks[0]['liveaction'])
- doc['start_timestamp'] = doc['start_timestamp']
- doc['end_timestamp'] = doc['end_timestamp']
+ self.assertDictEqual(model.action, self.fake_history_subtasks[0]["action"])
+ self.assertDictEqual(model.runner, self.fake_history_subtasks[0]["runner"])
+ doc = copy.deepcopy(self.fake_history_subtasks[0]["liveaction"])
+ doc["start_timestamp"] = doc["start_timestamp"]
+ doc["end_timestamp"] = doc["end_timestamp"]
self.assertDictEqual(model.liveaction, doc)
- self.assertEqual(model.parent, self.fake_history_subtasks[0]['parent'])
+ self.assertEqual(model.parent, self.fake_history_subtasks[0]["parent"])
self.assertListEqual(model.children, [])
# Update the DB record.
@@ -219,23 +252,25 @@ def test_crud_partial(self):
# Delete the DB record.
ActionExecution.delete(model)
- self.assertRaises(StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id)
+ self.assertRaises(
+ StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id
+ )
def test_datetime_range(self):
base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0))
for i in range(60):
timestamp = base + datetime.timedelta(seconds=i)
doc = copy.deepcopy(self.fake_history_subtasks[0])
- doc['id'] = str(bson.ObjectId())
- doc['start_timestamp'] = isotime.format(timestamp)
+ doc["id"] = str(bson.ObjectId())
+ doc["start_timestamp"] = isotime.format(timestamp)
obj = ActionExecutionAPI(**doc)
ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj))
- dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z'
+ dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z"
objs = ActionExecution.query(start_timestamp=dt_range)
self.assertEqual(len(objs), 10)
- dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z'
+ dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z"
objs = ActionExecution.query(start_timestamp=dt_range)
self.assertEqual(len(objs), 10)
@@ -244,19 +279,19 @@ def test_sort_by_start_timestamp(self):
for i in range(60):
timestamp = base + datetime.timedelta(seconds=i)
doc = copy.deepcopy(self.fake_history_subtasks[0])
- doc['id'] = str(bson.ObjectId())
- doc['start_timestamp'] = isotime.format(timestamp)
+ doc["id"] = str(bson.ObjectId())
+ doc["start_timestamp"] = isotime.format(timestamp)
obj = ActionExecutionAPI(**doc)
ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj))
- dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z'
- objs = ActionExecution.query(start_timestamp=dt_range,
- order_by=['start_timestamp'])
- self.assertLess(objs[0]['start_timestamp'],
- objs[9]['start_timestamp'])
+ dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z"
+ objs = ActionExecution.query(
+ start_timestamp=dt_range, order_by=["start_timestamp"]
+ )
+ self.assertLess(objs[0]["start_timestamp"], objs[9]["start_timestamp"])
- dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z'
- objs = ActionExecution.query(start_timestamp=dt_range,
- order_by=['-start_timestamp'])
- self.assertLess(objs[9]['start_timestamp'],
- objs[0]['start_timestamp'])
+ dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z"
+ objs = ActionExecution.query(
+ start_timestamp=dt_range, order_by=["-start_timestamp"]
+ )
+ self.assertLess(objs[9]["start_timestamp"], objs[0]["start_timestamp"])
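
The "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" strings above use a start..end range syntax, and the tests assert the window is inclusive and order-insensitive (both orderings return 10 records). A datetime-only sketch of that parsing, not the actual ActionExecution.query implementation:

    from datetime import datetime

    def parse_range(dt_range):
        # Split on "..", parse both ends, and normalize the ordering.
        start, end = (datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ")
                      for s in dt_range.split(".."))
        return (start, end) if start <= end else (end, start)

    lo, hi = parse_range("2014-12-25T00:00:19Z..2014-12-25T00:00:10Z")
    assert lo < hi  # reversed input still yields an ordered window
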
diff --git a/st2common/tests/unit/test_executions_util.py b/st2common/tests/unit/test_executions_util.py
index 9177493188..f7702614d3 100644
--- a/st2common/tests/unit/test_executions_util.py
+++ b/st2common/tests/unit/test_executions_util.py
@@ -35,25 +35,28 @@
import st2tests.config as tests_config
from six.moves import range
+
tests_config.parse_args()
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
TEST_FIXTURES = {
- 'liveactions': ['liveaction1.yaml', 'parentliveaction.yaml', 'childliveaction.yaml',
- 'successful_liveaction.yaml'],
- 'actions': ['local.yaml'],
- 'executions': ['execution1.yaml'],
- 'runners': ['run-local.yaml'],
- 'triggertypes': ['triggertype2.yaml'],
- 'rules': ['rule3.yaml'],
- 'triggers': ['trigger2.yaml'],
- 'triggerinstances': ['trigger_instance_1.yaml']
+ "liveactions": [
+ "liveaction1.yaml",
+ "parentliveaction.yaml",
+ "childliveaction.yaml",
+ "successful_liveaction.yaml",
+ ],
+ "actions": ["local.yaml"],
+ "executions": ["execution1.yaml"],
+ "runners": ["run-local.yaml"],
+ "triggertypes": ["triggertype2.yaml"],
+ "rules": ["rule3.yaml"],
+ "triggers": ["trigger2.yaml"],
+ "triggerinstances": ["trigger_instance_1.yaml"],
}
-DYNAMIC_FIXTURES = {
- 'liveactions': ['liveaction3.yaml']
-}
+DYNAMIC_FIXTURES = {"liveactions": ["liveaction3.yaml"]}
class ExecutionsUtilTestCase(CleanDbTestCase):
@@ -63,118 +66,144 @@ def __init__(self, *args, **kwargs):
def setUp(self):
super(ExecutionsUtilTestCase, self).setUp()
- self.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_FIXTURES)
- self.FIXTURES = FixturesLoader().load_fixtures(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=DYNAMIC_FIXTURES)
+ self.MODELS = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
+ self.FIXTURES = FixturesLoader().load_fixtures(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=DYNAMIC_FIXTURES
+ )
def test_execution_creation_manual_action_run(self):
- liveaction = self.MODELS['liveactions']['liveaction1.yaml']
+ liveaction = self.MODELS["liveactions"]["liveaction1.yaml"]
pre_creation_timestamp = date_utils.get_datetime_utc_now()
executions_util.create_execution_object(liveaction)
post_creation_timestamp = date_utils.get_datetime_utc_now()
- execution = self._get_action_execution(liveaction__id=str(liveaction.id),
- raise_exception=True)
+ execution = self._get_action_execution(
+ liveaction__id=str(liveaction.id), raise_exception=True
+ )
self.assertDictEqual(execution.trigger, {})
self.assertDictEqual(execution.trigger_type, {})
self.assertDictEqual(execution.trigger_instance, {})
self.assertDictEqual(execution.rule, {})
- action = action_utils.get_action_by_ref('core.local')
+ action = action_utils.get_action_by_ref("core.local")
self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action)))
- runner = RunnerType.get_by_name(action.runner_type['name'])
+ runner = RunnerType.get_by_name(action.runner_type["name"])
self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner)))
liveaction = LiveAction.get_by_id(str(liveaction.id))
- self.assertEqual(execution.liveaction['id'], str(liveaction.id))
+ self.assertEqual(execution.liveaction["id"], str(liveaction.id))
self.assertEqual(len(execution.log), 1)
- self.assertEqual(execution.log[0]['status'], liveaction.status)
- self.assertGreater(execution.log[0]['timestamp'], pre_creation_timestamp)
- self.assertLess(execution.log[0]['timestamp'], post_creation_timestamp)
+ self.assertEqual(execution.log[0]["status"], liveaction.status)
+ self.assertGreater(execution.log[0]["timestamp"], pre_creation_timestamp)
+ self.assertLess(execution.log[0]["timestamp"], post_creation_timestamp)
def test_execution_creation_action_triggered_by_rule(self):
# Wait for the action execution to complete and then confirm outcome.
- trigger_type = self.MODELS['triggertypes']['triggertype2.yaml']
- trigger = self.MODELS['triggers']['trigger2.yaml']
- trigger_instance = self.MODELS['triggerinstances']['trigger_instance_1.yaml']
- test_liveaction = self.FIXTURES['liveactions']['liveaction3.yaml']
- rule = self.MODELS['rules']['rule3.yaml']
+ trigger_type = self.MODELS["triggertypes"]["triggertype2.yaml"]
+ trigger = self.MODELS["triggers"]["trigger2.yaml"]
+ trigger_instance = self.MODELS["triggerinstances"]["trigger_instance_1.yaml"]
+ test_liveaction = self.FIXTURES["liveactions"]["liveaction3.yaml"]
+ rule = self.MODELS["rules"]["rule3.yaml"]
         # Set up LiveAction to point to the right rule and trigger_instance.
# XXX: We need support for dynamic fixtures.
- test_liveaction['context']['rule']['id'] = str(rule.id)
- test_liveaction['context']['trigger_instance']['id'] = str(trigger_instance.id)
+ test_liveaction["context"]["rule"]["id"] = str(rule.id)
+ test_liveaction["context"]["trigger_instance"]["id"] = str(trigger_instance.id)
test_liveaction_api = LiveActionAPI(**test_liveaction)
- test_liveaction = LiveAction.add_or_update(LiveActionAPI.to_model(test_liveaction_api))
- liveaction = LiveAction.get(context__trigger_instance__id=str(trigger_instance.id))
+ test_liveaction = LiveAction.add_or_update(
+ LiveActionAPI.to_model(test_liveaction_api)
+ )
+ liveaction = LiveAction.get(
+ context__trigger_instance__id=str(trigger_instance.id)
+ )
self.assertIsNotNone(liveaction)
- self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED)
+ self.assertEqual(
+ liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED
+ )
executions_util.create_execution_object(liveaction)
- execution = self._get_action_execution(liveaction__id=str(liveaction.id),
- raise_exception=True)
+ execution = self._get_action_execution(
+ liveaction__id=str(liveaction.id), raise_exception=True
+ )
self.assertDictEqual(execution.trigger, vars(TriggerAPI.from_model(trigger)))
- self.assertDictEqual(execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type)))
- self.assertDictEqual(execution.trigger_instance,
- vars(TriggerInstanceAPI.from_model(trigger_instance)))
+ self.assertDictEqual(
+ execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type))
+ )
+ self.assertDictEqual(
+ execution.trigger_instance,
+ vars(TriggerInstanceAPI.from_model(trigger_instance)),
+ )
self.assertDictEqual(execution.rule, vars(RuleAPI.from_model(rule)))
action = action_utils.get_action_by_ref(liveaction.action)
self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action)))
- runner = RunnerType.get_by_name(action.runner_type['name'])
+ runner = RunnerType.get_by_name(action.runner_type["name"])
self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner)))
liveaction = LiveAction.get_by_id(str(liveaction.id))
- self.assertEqual(execution.liveaction['id'], str(liveaction.id))
+ self.assertEqual(execution.liveaction["id"], str(liveaction.id))
def test_execution_creation_with_web_url(self):
- liveaction = self.MODELS['liveactions']['liveaction1.yaml']
+ liveaction = self.MODELS["liveactions"]["liveaction1.yaml"]
executions_util.create_execution_object(liveaction)
- execution = self._get_action_execution(liveaction__id=str(liveaction.id),
- raise_exception=True)
+ execution = self._get_action_execution(
+ liveaction__id=str(liveaction.id), raise_exception=True
+ )
self.assertIsNotNone(execution.web_url)
execution_id = str(execution.id)
- self.assertIn(('history/%s/general' % execution_id), execution.web_url)
+ self.assertIn(("history/%s/general" % execution_id), execution.web_url)
def test_execution_creation_chains(self):
- childliveaction = self.MODELS['liveactions']['childliveaction.yaml']
+ childliveaction = self.MODELS["liveactions"]["childliveaction.yaml"]
child_exec = executions_util.create_execution_object(childliveaction)
- parent_execution_id = childliveaction.context['parent']['execution_id']
+ parent_execution_id = childliveaction.context["parent"]["execution_id"]
parent_execution = ActionExecution.get_by_id(parent_execution_id)
child_execs = parent_execution.children
self.assertIn(str(child_exec.id), child_execs)
def test_execution_update(self):
- liveaction = self.MODELS['liveactions']['liveaction1.yaml']
+ liveaction = self.MODELS["liveactions"]["liveaction1.yaml"]
executions_util.create_execution_object(liveaction)
- liveaction.status = 'running'
+ liveaction.status = "running"
pre_update_timestamp = date_utils.get_datetime_utc_now()
executions_util.update_execution(liveaction)
post_update_timestamp = date_utils.get_datetime_utc_now()
- execution = self._get_action_execution(liveaction__id=str(liveaction.id),
- raise_exception=True)
+ execution = self._get_action_execution(
+ liveaction__id=str(liveaction.id), raise_exception=True
+ )
self.assertEqual(len(execution.log), 2)
- self.assertEqual(execution.log[1]['status'], liveaction.status)
- self.assertGreater(execution.log[1]['timestamp'], pre_update_timestamp)
- self.assertLess(execution.log[1]['timestamp'], post_update_timestamp)
-
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
- @mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None))
+ self.assertEqual(execution.log[1]["status"], liveaction.status)
+ self.assertGreater(execution.log[1]["timestamp"], pre_update_timestamp)
+ self.assertLess(execution.log[1]["timestamp"], post_update_timestamp)
+
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
+ @mock.patch.object(
+ runners_utils, "invoke_post_run", mock.MagicMock(return_value=None)
+ )
def test_abandon_executions(self):
- liveaction_db = self.MODELS['liveactions']['liveaction1.yaml']
+ liveaction_db = self.MODELS["liveactions"]["liveaction1.yaml"]
executions_util.create_execution_object(liveaction_db)
execution_db = executions_util.abandon_execution_if_incomplete(
- liveaction_id=str(liveaction_db.id))
+ liveaction_id=str(liveaction_db.id)
+ )
- self.assertEqual(execution_db.status, 'abandoned')
+ self.assertEqual(execution_db.status, "abandoned")
runners_utils.invoke_post_run.assert_called_once()
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
- @mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None))
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
+ @mock.patch.object(
+ runners_utils, "invoke_post_run", mock.MagicMock(return_value=None)
+ )
def test_abandon_executions_on_complete(self):
- liveaction_db = self.MODELS['liveactions']['successful_liveaction.yaml']
+ liveaction_db = self.MODELS["liveactions"]["successful_liveaction.yaml"]
executions_util.create_execution_object(liveaction_db)
- expected_msg = r'LiveAction %s already in a completed state %s\.' % \
- (str(liveaction_db.id), liveaction_db.status)
-
- self.assertRaisesRegexp(ValueError, expected_msg,
- executions_util.abandon_execution_if_incomplete,
- liveaction_id=str(liveaction_db.id))
+ expected_msg = r"LiveAction %s already in a completed state %s\." % (
+ str(liveaction_db.id),
+ liveaction_db.status,
+ )
+
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ executions_util.abandon_execution_if_incomplete,
+ liveaction_id=str(liveaction_db.id),
+ )
runners_utils.invoke_post_run.assert_not_called()
@@ -184,12 +213,20 @@ def _get_action_execution(self, **kwargs):
# descendants test section
-DESCENDANTS_PACK = 'descendants'
+DESCENDANTS_PACK = "descendants"
DESCENDANTS_FIXTURES = {
- 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml',
- 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml',
- 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml']
+ "executions": [
+ "root_execution.yaml",
+ "child1_level1.yaml",
+ "child2_level1.yaml",
+ "child1_level2.yaml",
+ "child2_level2.yaml",
+ "child3_level2.yaml",
+ "child1_level3.yaml",
+ "child2_level3.yaml",
+ "child3_level3.yaml",
+ ]
}
@@ -200,75 +237,90 @@ def __init__(self, *args, **kwargs):
def setUp(self):
super(ExecutionsUtilDescendantsTestCase, self).setUp()
- self.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK,
- fixtures_dict=DESCENDANTS_FIXTURES)
+ self.MODELS = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES
+ )
def test_get_all_descendants_sorted(self):
- root_execution = self.MODELS['executions']['root_execution.yaml']
- all_descendants = executions_util.get_descendants(str(root_execution.id),
- result_fmt='sorted')
+ root_execution = self.MODELS["executions"]["root_execution.yaml"]
+ all_descendants = executions_util.get_descendants(
+ str(root_execution.id), result_fmt="sorted"
+ )
all_descendants_ids = [str(descendant.id) for descendant in all_descendants]
all_descendants_ids.sort()
# everything except the root_execution
- expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions'])
- if v.id != root_execution.id]
+ expected_ids = [
+ str(v.id)
+ for _, v in six.iteritems(self.MODELS["executions"])
+ if v.id != root_execution.id
+ ]
expected_ids.sort()
self.assertListEqual(all_descendants_ids, expected_ids)
# verify sort order
for idx in range(len(all_descendants) - 1):
- self.assertLess(all_descendants[idx].start_timestamp,
- all_descendants[idx + 1].start_timestamp)
+ self.assertLess(
+ all_descendants[idx].start_timestamp,
+ all_descendants[idx + 1].start_timestamp,
+ )
def test_get_all_descendants(self):
- root_execution = self.MODELS['executions']['root_execution.yaml']
+ root_execution = self.MODELS["executions"]["root_execution.yaml"]
all_descendants = executions_util.get_descendants(str(root_execution.id))
all_descendants_ids = [str(descendant.id) for descendant in all_descendants]
all_descendants_ids.sort()
# everything except the root_execution
- expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions'])
- if v.id != root_execution.id]
+ expected_ids = [
+ str(v.id)
+ for _, v in six.iteritems(self.MODELS["executions"])
+ if v.id != root_execution.id
+ ]
expected_ids.sort()
self.assertListEqual(all_descendants_ids, expected_ids)
def test_get_1_level_descendants_sorted(self):
- root_execution = self.MODELS['executions']['root_execution.yaml']
- all_descendants = executions_util.get_descendants(str(root_execution.id),
- descendant_depth=1,
- result_fmt='sorted')
+ root_execution = self.MODELS["executions"]["root_execution.yaml"]
+ all_descendants = executions_util.get_descendants(
+ str(root_execution.id), descendant_depth=1, result_fmt="sorted"
+ )
all_descendants_ids = [str(descendant.id) for descendant in all_descendants]
all_descendants_ids.sort()
# All children of root_execution
- expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions'])
- if v.parent == str(root_execution.id)]
+ expected_ids = [
+ str(v.id)
+ for _, v in six.iteritems(self.MODELS["executions"])
+ if v.parent == str(root_execution.id)
+ ]
expected_ids.sort()
self.assertListEqual(all_descendants_ids, expected_ids)
# verify sort order
for idx in range(len(all_descendants) - 1):
- self.assertLess(all_descendants[idx].start_timestamp,
- all_descendants[idx + 1].start_timestamp)
+ self.assertLess(
+ all_descendants[idx].start_timestamp,
+ all_descendants[idx + 1].start_timestamp,
+ )
def test_get_2_level_descendants_sorted(self):
- root_execution = self.MODELS['executions']['root_execution.yaml']
- all_descendants = executions_util.get_descendants(str(root_execution.id),
- descendant_depth=2,
- result_fmt='sorted')
+ root_execution = self.MODELS["executions"]["root_execution.yaml"]
+ all_descendants = executions_util.get_descendants(
+ str(root_execution.id), descendant_depth=2, result_fmt="sorted"
+ )
all_descendants_ids = [str(descendant.id) for descendant in all_descendants]
all_descendants_ids.sort()
# All children of root_execution
- root_execution = self.MODELS['executions']['root_execution.yaml']
+ root_execution = self.MODELS["executions"]["root_execution.yaml"]
expected_ids = []
traverse = [(child_id, 1) for child_id in root_execution.children]
while traverse:
@@ -282,7 +334,7 @@ def test_get_2_level_descendants_sorted(self):
self.assertListEqual(all_descendants_ids, expected_ids)
def _get_action_execution(self, ae_id):
- for _, execution in six.iteritems(self.MODELS['executions']):
+ for _, execution in six.iteritems(self.MODELS["executions"]):
if str(execution.id) == ae_id:
return execution
return None
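
test_get_2_level_descendants_sorted builds its expected set with a breadth-first walk bounded by depth (the traverse list of (child_id, level) pairs). A compact standalone version of that traversal, with children given as a plain dict rather than DB models:

    def descendants(children_by_id, root_id, max_depth=None):
        # Breadth-first walk; stop expanding once max_depth is reached.
        result = []
        frontier = [(child, 1) for child in children_by_id.get(root_id, [])]
        while frontier:
            node, depth = frontier.pop(0)
            result.append(node)
            if max_depth is None or depth < max_depth:
                frontier.extend((c, depth + 1) for c in children_by_id.get(node, []))
        return result

    tree = {"root": ["a", "b"], "a": ["a1"], "a1": ["a1x"]}
    assert descendants(tree, "root", max_depth=2) == ["a", "b", "a1"]
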
diff --git a/st2common/tests/unit/test_greenpooldispatch.py b/st2common/tests/unit/test_greenpooldispatch.py
index 84c411d140..45cc568759 100644
--- a/st2common/tests/unit/test_greenpooldispatch.py
+++ b/st2common/tests/unit/test_greenpooldispatch.py
@@ -23,7 +23,6 @@
class TestGreenPoolDispatch(TestCase):
-
def test_dispatch_simple(self):
dispatcher = BufferedDispatcher(dispatch_pool_size=10)
mock_handler = mock.MagicMock()
@@ -34,13 +33,17 @@ def test_dispatch_simple(self):
while mock_handler.call_count < 10:
eventlet.sleep(0.01)
dispatcher.shutdown()
- call_args_list = [(args[0][0], args[0][1]) for args in mock_handler.call_args_list]
+ call_args_list = [
+ (args[0][0], args[0][1]) for args in mock_handler.call_args_list
+ ]
self.assertItemsEqual(expected, call_args_list)
def test_dispatch_starved(self):
- dispatcher = BufferedDispatcher(dispatch_pool_size=2,
- monitor_thread_empty_q_sleep_time=0.01,
- monitor_thread_no_workers_sleep_time=0.01)
+ dispatcher = BufferedDispatcher(
+ dispatch_pool_size=2,
+ monitor_thread_empty_q_sleep_time=0.01,
+ monitor_thread_no_workers_sleep_time=0.01,
+ )
mock_handler = mock.MagicMock()
expected = []
for i in range(10):
@@ -49,5 +52,7 @@ def test_dispatch_starved(self):
while mock_handler.call_count < 10:
eventlet.sleep(0.01)
dispatcher.shutdown()
- call_args_list = [(args[0][0], args[0][1]) for args in mock_handler.call_args_list]
+ call_args_list = [
+ (args[0][0], args[0][1]) for args in mock_handler.call_args_list
+ ]
self.assertItemsEqual(expected, call_args_list)
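
The pattern under test: work items are queued and a green-thread pool drains them asynchronously, which is why the tests spin on mock_handler.call_count before shutting down. A hedged usage sketch — the import path and the dispatch(handler, *args) signature are inferred from the test above, not confirmed here:

    import eventlet
    from st2common.util.greenpooldispatch import BufferedDispatcher  # assumed path

    def handler(a, b):
        print(a, b)

    dispatcher = BufferedDispatcher(dispatch_pool_size=2)
    for i in range(5):
        dispatcher.dispatch(handler, i, i * 2)  # queues a call into the pool
    eventlet.sleep(0.1)  # give the pool time to drain the queue
    dispatcher.shutdown()
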
diff --git a/st2common/tests/unit/test_hash.py b/st2common/tests/unit/test_hash.py
index 7211879ff6..234d4969da 100644
--- a/st2common/tests/unit/test_hash.py
+++ b/st2common/tests/unit/test_hash.py
@@ -22,15 +22,14 @@
class TestHashWithApiKeys(unittest2.TestCase):
-
def test_hash_repeatability(self):
api_key = auth_utils.generate_api_key()
hash1 = hash_utils.hash(api_key)
hash2 = hash_utils.hash(api_key)
- self.assertEqual(hash1, hash2, 'Expected a repeated hash.')
+ self.assertEqual(hash1, hash2, "Expected a repeated hash.")
def test_hash_uniqueness(self):
count = 10000
api_keys = [auth_utils.generate_api_key() for _ in range(count)]
hashes = set([hash_utils.hash(api_key) for api_key in api_keys])
- self.assertEqual(len(hashes), count, 'Expected all unique hashes.')
+ self.assertEqual(len(hashes), count, "Expected all unique hashes.")
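
The two properties checked above are determinism (the same key always hashes to the same value) and practical collision-freedom across many random keys. A standalone illustration using hashlib — st2common's hash_utils may differ in algorithm:

    import hashlib
    import secrets

    def api_key_hash(api_key):
        return hashlib.sha256(api_key.encode("utf-8")).hexdigest()

    key = secrets.token_hex(32)
    assert api_key_hash(key) == api_key_hash(key)  # repeatable

    keys = [secrets.token_hex(32) for _ in range(1000)]
    assert len({api_key_hash(k) for k in keys}) == len(keys)  # all unique
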
diff --git a/st2common/tests/unit/test_ip_utils.py b/st2common/tests/unit/test_ip_utils.py
index a33c220d71..cd1339be73 100644
--- a/st2common/tests/unit/test_ip_utils.py
+++ b/st2common/tests/unit/test_ip_utils.py
@@ -20,73 +20,72 @@
class IPUtilsTests(unittest2.TestCase):
-
def test_host_port_split(self):
# Simple IPv4
- host_str = '1.2.3.4'
+ host_str = "1.2.3.4"
host, port = split_host_port(host_str)
self.assertEqual(host, host_str)
self.assertEqual(port, None)
# Simple IPv4 with port
- host_str = '1.2.3.4:55'
+ host_str = "1.2.3.4:55"
host, port = split_host_port(host_str)
- self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 55)
# Simple IPv6
- host_str = 'fec2::10'
+ host_str = "fec2::10"
host, port = split_host_port(host_str)
- self.assertEqual(host, 'fec2::10')
+ self.assertEqual(host, "fec2::10")
self.assertEqual(port, None)
# IPv6 with square brackets no port
- host_str = '[fec2::10]'
+ host_str = "[fec2::10]"
host, port = split_host_port(host_str)
- self.assertEqual(host, 'fec2::10')
+ self.assertEqual(host, "fec2::10")
self.assertEqual(port, None)
# IPv6 with square brackets with port
- host_str = '[fec2::10]:55'
+ host_str = "[fec2::10]:55"
host, port = split_host_port(host_str)
- self.assertEqual(host, 'fec2::10')
+ self.assertEqual(host, "fec2::10")
self.assertEqual(port, 55)
# IPv4 inside bracket
- host_str = '[1.2.3.4]'
+ host_str = "[1.2.3.4]"
host, port = split_host_port(host_str)
- self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, None)
# IPv4 inside bracket and port
- host_str = '[1.2.3.4]:55'
+ host_str = "[1.2.3.4]:55"
host, port = split_host_port(host_str)
- self.assertEqual(host, '1.2.3.4')
+ self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 55)
# Hostname inside bracket
- host_str = '[st2build001]:55'
+ host_str = "[st2build001]:55"
host, port = split_host_port(host_str)
- self.assertEqual(host, 'st2build001')
+ self.assertEqual(host, "st2build001")
self.assertEqual(port, 55)
# Simple hostname
- host_str = 'st2build001'
+ host_str = "st2build001"
host, port = split_host_port(host_str)
- self.assertEqual(host, 'st2build001')
+ self.assertEqual(host, "st2build001")
self.assertEqual(port, None)
# Simple hostname with port
- host_str = 'st2build001:55'
+ host_str = "st2build001:55"
host, port = split_host_port(host_str)
- self.assertEqual(host, 'st2build001')
+ self.assertEqual(host, "st2build001")
self.assertEqual(port, 55)
# No-bracket invalid port
- host_str = 'st2build001:abc'
+ host_str = "st2build001:abc"
self.assertRaises(Exception, split_host_port, host_str)
# Bracket invalid port
- host_str = '[fec2::10]:abc'
+ host_str = "[fec2::10]:abc"
self.assertRaises(Exception, split_host_port, host_str)
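
A simplified, standalone take on the split_host_port() contract these cases exercise — bracketed IPv6 and hostnames, an optional port, and an error on a non-numeric port; this is an illustration, not the st2common implementation:

    def split_host_port(host_str):
        if host_str.startswith("["):            # [host] or [host]:port
            host, _, rest = host_str[1:].partition("]")
            port_str = rest[1:] if rest.startswith(":") else ""
        elif host_str.count(":") == 1:          # host:port (IPv4 or hostname)
            host, _, port_str = host_str.partition(":")
        else:                                   # bare host, incl. raw IPv6
            host, port_str = host_str, ""
        if not port_str:
            return host, None
        if not port_str.isdigit():
            raise ValueError("Invalid port: %s" % port_str)
        return host, int(port_str)

    assert split_host_port("[fec2::10]:55") == ("fec2::10", 55)
    assert split_host_port("fec2::10") == ("fec2::10", None)
    assert split_host_port("st2build001:55") == ("st2build001", 55)
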
diff --git a/st2common/tests/unit/test_isotime_utils.py b/st2common/tests/unit/test_isotime_utils.py
index 5ec5495ca9..34d785031b 100644
--- a/st2common/tests/unit/test_isotime_utils.py
+++ b/st2common/tests/unit/test_isotime_utils.py
@@ -24,50 +24,54 @@
class IsoTimeUtilsTestCase(unittest.TestCase):
def test_validate(self):
- self.assertTrue(isotime.validate('2000-01-01 12:00:00Z'))
- self.assertTrue(isotime.validate('2000-01-01 12:00:00+00'))
- self.assertTrue(isotime.validate('2000-01-01 12:00:00+0000'))
- self.assertTrue(isotime.validate('2000-01-01 12:00:00+00:00'))
- self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000Z'))
- self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+00'))
- self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+0000'))
- self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+00:00'))
- self.assertTrue(isotime.validate('2000-01-01T12:00:00Z'))
- self.assertTrue(isotime.validate('2000-01-01T12:00:00.000000Z'))
- self.assertTrue(isotime.validate('2000-01-01T12:00:00+00:00'))
- self.assertTrue(isotime.validate('2000-01-01T12:00:00.000000+00:00'))
- self.assertTrue(isotime.validate('2015-02-10T21:21:53.399Z'))
- self.assertFalse(isotime.validate('2000-01-01', raise_exception=False))
- self.assertFalse(isotime.validate('2000-01-01T12:00:00', raise_exception=False))
- self.assertFalse(isotime.validate('2000-01-01T12:00:00+00:00Z', raise_exception=False))
- self.assertFalse(isotime.validate('2000-01-01T12:00:00.000000', raise_exception=False))
- self.assertFalse(isotime.validate('Epic!', raise_exception=False))
+ self.assertTrue(isotime.validate("2000-01-01 12:00:00Z"))
+ self.assertTrue(isotime.validate("2000-01-01 12:00:00+00"))
+ self.assertTrue(isotime.validate("2000-01-01 12:00:00+0000"))
+ self.assertTrue(isotime.validate("2000-01-01 12:00:00+00:00"))
+ self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000Z"))
+ self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+00"))
+ self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+0000"))
+ self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+00:00"))
+ self.assertTrue(isotime.validate("2000-01-01T12:00:00Z"))
+ self.assertTrue(isotime.validate("2000-01-01T12:00:00.000000Z"))
+ self.assertTrue(isotime.validate("2000-01-01T12:00:00+00:00"))
+ self.assertTrue(isotime.validate("2000-01-01T12:00:00.000000+00:00"))
+ self.assertTrue(isotime.validate("2015-02-10T21:21:53.399Z"))
+ self.assertFalse(isotime.validate("2000-01-01", raise_exception=False))
+ self.assertFalse(isotime.validate("2000-01-01T12:00:00", raise_exception=False))
+ self.assertFalse(
+ isotime.validate("2000-01-01T12:00:00+00:00Z", raise_exception=False)
+ )
+ self.assertFalse(
+ isotime.validate("2000-01-01T12:00:00.000000", raise_exception=False)
+ )
+ self.assertFalse(isotime.validate("Epic!", raise_exception=False))
self.assertFalse(isotime.validate(object(), raise_exception=False))
- self.assertRaises(ValueError, isotime.validate, 'Epic!', True)
+ self.assertRaises(ValueError, isotime.validate, "Epic!", True)
def test_parse(self):
dt = date.add_utc_tz(datetime.datetime(2000, 1, 1, 12))
- self.assertEqual(isotime.parse('2000-01-01 12:00:00Z'), dt)
- self.assertEqual(isotime.parse('2000-01-01 12:00:00+00'), dt)
- self.assertEqual(isotime.parse('2000-01-01 12:00:00+0000'), dt)
- self.assertEqual(isotime.parse('2000-01-01 12:00:00+00:00'), dt)
- self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000Z'), dt)
- self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+00'), dt)
- self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+0000'), dt)
- self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+00:00'), dt)
- self.assertEqual(isotime.parse('2000-01-01T12:00:00Z'), dt)
- self.assertEqual(isotime.parse('2000-01-01T12:00:00+00:00'), dt)
- self.assertEqual(isotime.parse('2000-01-01T12:00:00.000000Z'), dt)
- self.assertEqual(isotime.parse('2000-01-01T12:00:00.000000+00:00'), dt)
- self.assertEqual(isotime.parse('2000-01-01T12:00:00.000Z'), dt)
+ self.assertEqual(isotime.parse("2000-01-01 12:00:00Z"), dt)
+ self.assertEqual(isotime.parse("2000-01-01 12:00:00+00"), dt)
+ self.assertEqual(isotime.parse("2000-01-01 12:00:00+0000"), dt)
+ self.assertEqual(isotime.parse("2000-01-01 12:00:00+00:00"), dt)
+ self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000Z"), dt)
+ self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+00"), dt)
+ self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+0000"), dt)
+ self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+00:00"), dt)
+ self.assertEqual(isotime.parse("2000-01-01T12:00:00Z"), dt)
+ self.assertEqual(isotime.parse("2000-01-01T12:00:00+00:00"), dt)
+ self.assertEqual(isotime.parse("2000-01-01T12:00:00.000000Z"), dt)
+ self.assertEqual(isotime.parse("2000-01-01T12:00:00.000000+00:00"), dt)
+ self.assertEqual(isotime.parse("2000-01-01T12:00:00.000Z"), dt)
def test_format(self):
dt = date.add_utc_tz(datetime.datetime(2000, 1, 1, 12))
- dt_str_usec_offset = '2000-01-01T12:00:00.000000+00:00'
- dt_str_usec = '2000-01-01T12:00:00.000000Z'
- dt_str_offset = '2000-01-01T12:00:00+00:00'
- dt_str = '2000-01-01T12:00:00Z'
- dt_unicode = u'2000-01-01T12:00:00Z'
+ dt_str_usec_offset = "2000-01-01T12:00:00.000000+00:00"
+ dt_str_usec = "2000-01-01T12:00:00.000000Z"
+ dt_str_offset = "2000-01-01T12:00:00+00:00"
+ dt_str = "2000-01-01T12:00:00Z"
+ dt_unicode = "2000-01-01T12:00:00Z"
# datetime object
self.assertEqual(isotime.format(dt, usec=True, offset=True), dt_str_usec_offset)
@@ -75,16 +79,22 @@ def test_format(self):
self.assertEqual(isotime.format(dt, usec=False, offset=True), dt_str_offset)
self.assertEqual(isotime.format(dt, usec=False, offset=False), dt_str)
self.assertEqual(isotime.format(dt_str, usec=False, offset=False), dt_str)
- self.assertEqual(isotime.format(dt_unicode, usec=False, offset=False), dt_unicode)
+ self.assertEqual(
+ isotime.format(dt_unicode, usec=False, offset=False), dt_unicode
+ )
# unix timestamp (epoch)
dt = 1557390483
- self.assertEqual(isotime.format(dt, usec=True, offset=True),
- '2019-05-09T08:28:03.000000+00:00')
- self.assertEqual(isotime.format(dt, usec=False, offset=False),
- '2019-05-09T08:28:03Z')
- self.assertEqual(isotime.format(dt, usec=False, offset=True),
- '2019-05-09T08:28:03+00:00')
+ self.assertEqual(
+ isotime.format(dt, usec=True, offset=True),
+ "2019-05-09T08:28:03.000000+00:00",
+ )
+ self.assertEqual(
+ isotime.format(dt, usec=False, offset=False), "2019-05-09T08:28:03Z"
+ )
+ self.assertEqual(
+ isotime.format(dt, usec=False, offset=True), "2019-05-09T08:28:03+00:00"
+ )
def test_format_tz_naive(self):
dt1 = datetime.datetime.utcnow()
@@ -99,6 +109,8 @@ def test_format_tz_aware(self):
def test_format_sec_truncated(self):
dt1 = date.add_utc_tz(datetime.datetime.utcnow())
dt2 = isotime.parse(isotime.format(dt1, usec=False))
- dt3 = datetime.datetime(dt1.year, dt1.month, dt1.day, dt1.hour, dt1.minute, dt1.second)
+ dt3 = datetime.datetime(
+ dt1.year, dt1.month, dt1.day, dt1.hour, dt1.minute, dt1.second
+ )
self.assertLess(dt2, dt1)
self.assertEqual(dt2, date.add_utc_tz(dt3))
diff --git a/st2common/tests/unit/test_jinja_render_crypto_filters.py b/st2common/tests/unit/test_jinja_render_crypto_filters.py
index 1a026e83ed..f58edb1309 100644
--- a/st2common/tests/unit/test_jinja_render_crypto_filters.py
+++ b/st2common/tests/unit/test_jinja_render_crypto_filters.py
@@ -38,72 +38,101 @@ def setUp(self):
crypto_key_path = cfg.CONF.keyvalue.encryption_key_path
crypto_key = read_crypto_key(key_path=crypto_key_path)
- self.secret = 'Build a wall'
- self.secret_value = symmetric_encrypt(encrypt_key=crypto_key, plaintext=self.secret)
+ self.secret = "Build a wall"
+ self.secret_value = symmetric_encrypt(
+ encrypt_key=crypto_key, plaintext=self.secret
+ )
self.env = jinja_utils.get_jinja_environment()
def test_filter_decrypt_kv(self):
- KeyValuePair.add_or_update(KeyValuePairDB(name='k8', value=self.secret_value,
- scope=FULL_SYSTEM_SCOPE,
- secret=True))
+ KeyValuePair.add_or_update(
+ KeyValuePairDB(
+ name="k8", value=self.secret_value, scope=FULL_SYSTEM_SCOPE, secret=True
+ )
+ )
context = {}
context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)})
- context.update({
- DATASTORE_PARENT_SCOPE: {
- SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
+ context.update(
+ {
+ DATASTORE_PARENT_SCOPE: {
+ SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
+ }
}
- })
+ )
- template = '{{st2kv.system.k8 | decrypt_kv}}'
+ template = "{{st2kv.system.k8 | decrypt_kv}}"
actual = self.env.from_string(template).render(context)
self.assertEqual(actual, self.secret)
def test_filter_decrypt_kv_datastore_value_doesnt_exist(self):
context = {}
context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)})
- context.update({
- DATASTORE_PARENT_SCOPE: {
- SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
+ context.update(
+ {
+ DATASTORE_PARENT_SCOPE: {
+ SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
+ }
}
- })
+ )
- template = '{{st2kv.system.doesnt_exist | decrypt_kv}}'
+ template = "{{st2kv.system.doesnt_exist | decrypt_kv}}"
- expected_msg = ('Referenced datastore item "st2kv.system.doesnt_exist" doesn\'t exist or '
- 'it contains an empty string')
- self.assertRaisesRegexp(ValueError, expected_msg, self.env.from_string(template).render,
- context)
+ expected_msg = (
+ 'Referenced datastore item "st2kv.system.doesnt_exist" doesn\'t exist or '
+ "it contains an empty string"
+ )
+ self.assertRaisesRegexp(
+ ValueError, expected_msg, self.env.from_string(template).render, context
+ )
def test_filter_decrypt_kv_with_user_scope_value(self):
- KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:k8', value=self.secret_value,
- scope=FULL_USER_SCOPE,
- secret=True))
+ KeyValuePair.add_or_update(
+ KeyValuePairDB(
+ name="stanley:k8",
+ value=self.secret_value,
+ scope=FULL_USER_SCOPE,
+ secret=True,
+ )
+ )
context = {}
- context.update({USER_SCOPE: UserKeyValueLookup(user='stanley', scope=USER_SCOPE)})
- context.update({
- DATASTORE_PARENT_SCOPE: {
- USER_SCOPE: UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE)
+ context.update(
+ {USER_SCOPE: UserKeyValueLookup(user="stanley", scope=USER_SCOPE)}
+ )
+ context.update(
+ {
+ DATASTORE_PARENT_SCOPE: {
+ USER_SCOPE: UserKeyValueLookup(
+ user="stanley", scope=FULL_USER_SCOPE
+ )
+ }
}
- })
+ )
- template = '{{st2kv.user.k8 | decrypt_kv}}'
+ template = "{{st2kv.user.k8 | decrypt_kv}}"
actual = self.env.from_string(template).render(context)
self.assertEqual(actual, self.secret)
def test_filter_decrypt_kv_with_user_scope_value_datastore_value_doesnt_exist(self):
context = {}
context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)})
- context.update({
- DATASTORE_PARENT_SCOPE: {
- USER_SCOPE: UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE)
+ context.update(
+ {
+ DATASTORE_PARENT_SCOPE: {
+ USER_SCOPE: UserKeyValueLookup(
+ user="stanley", scope=FULL_USER_SCOPE
+ )
+ }
}
- })
+ )
- template = '{{st2kv.user.doesnt_exist | decrypt_kv}}'
+ template = "{{st2kv.user.doesnt_exist | decrypt_kv}}"
- expected_msg = ('Referenced datastore item "st2kv.user.doesnt_exist" doesn\'t exist or '
- 'it contains an empty string')
- self.assertRaisesRegexp(ValueError, expected_msg, self.env.from_string(template).render,
- context)
+ expected_msg = (
+ 'Referenced datastore item "st2kv.user.doesnt_exist" doesn\'t exist or '
+ "it contains an empty string"
+ )
+ self.assertRaisesRegexp(
+ ValueError, expected_msg, self.env.from_string(template).render, context
+ )
diff --git a/st2common/tests/unit/test_jinja_render_data_filters.py b/st2common/tests/unit/test_jinja_render_data_filters.py
index fd923e870f..44d2f296f9 100644
--- a/st2common/tests/unit/test_jinja_render_data_filters.py
+++ b/st2common/tests/unit/test_jinja_render_data_filters.py
@@ -24,77 +24,68 @@
class JinjaUtilsDataFilterTestCase(unittest2.TestCase):
-
def test_filter_from_json_string(self):
env = jinja_utils.get_jinja_environment()
- expected_obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}}
+ expected_obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}}
obj_json_str = '{"a": "b", "c": {"d": "e", "f": 1, "g": true}}'
- template = '{{k1 | from_json_string}}'
+ template = "{{k1 | from_json_string}}"
- obj_str = env.from_string(template).render({'k1': obj_json_str})
+ obj_str = env.from_string(template).render({"k1": obj_json_str})
obj = eval(obj_str)
self.assertDictEqual(obj, expected_obj)
# With KeyValueLookup object
env = jinja_utils.get_jinja_environment()
obj_json_str = '["a", "b", "c"]'
- expected_obj = ['a', 'b', 'c']
+ expected_obj = ["a", "b", "c"]
- template = '{{ k1 | from_json_string}}'
+ template = "{{ k1 | from_json_string}}"
- lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix='a')
- lookup._value_cache['a'] = obj_json_str
- obj_str = env.from_string(template).render({'k1': lookup})
+ lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix="a")
+ lookup._value_cache["a"] = obj_json_str
+ obj_str = env.from_string(template).render({"k1": lookup})
obj = eval(obj_str)
self.assertEqual(obj, expected_obj)
def test_filter_from_yaml_string(self):
env = jinja_utils.get_jinja_environment()
- expected_obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}}
- obj_yaml_str = ("---\n"
- "a: b\n"
- "c:\n"
- " d: e\n"
- " f: 1\n"
- " g: true\n")
-
- template = '{{k1 | from_yaml_string}}'
- obj_str = env.from_string(template).render({'k1': obj_yaml_str})
+ expected_obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}}
+ obj_yaml_str = "---\n" "a: b\n" "c:\n" " d: e\n" " f: 1\n" " g: true\n"
+
+ template = "{{k1 | from_yaml_string}}"
+ obj_str = env.from_string(template).render({"k1": obj_yaml_str})
obj = eval(obj_str)
self.assertDictEqual(obj, expected_obj)
# With KeyValueLookup object
env = jinja_utils.get_jinja_environment()
- obj_yaml_str = ("---\n"
- "- a\n"
- "- b\n"
- "- c\n")
- expected_obj = ['a', 'b', 'c']
+ obj_yaml_str = "---\n" "- a\n" "- b\n" "- c\n"
+ expected_obj = ["a", "b", "c"]
- template = '{{ k1 | from_yaml_string }}'
+ template = "{{ k1 | from_yaml_string }}"
- lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix='b')
- lookup._value_cache['b'] = obj_yaml_str
- obj_str = env.from_string(template).render({'k1': lookup})
+ lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix="b")
+ lookup._value_cache["b"] = obj_yaml_str
+ obj_str = env.from_string(template).render({"k1": lookup})
obj = eval(obj_str)
self.assertEqual(obj, expected_obj)
def test_filter_to_json_string(self):
env = jinja_utils.get_jinja_environment()
- obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}}
+ obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}}
- template = '{{k1 | to_json_string}}'
+ template = "{{k1 | to_json_string}}"
- obj_json_str = env.from_string(template).render({'k1': obj})
+ obj_json_str = env.from_string(template).render({"k1": obj})
actual_obj = json.loads(obj_json_str)
self.assertDictEqual(obj, actual_obj)
def test_filter_to_yaml_string(self):
env = jinja_utils.get_jinja_environment()
- obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}}
+ obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}}
- template = '{{k1 | to_yaml_string}}'
- obj_yaml_str = env.from_string(template).render({'k1': obj})
+ template = "{{k1 | to_yaml_string}}"
+ obj_yaml_str = env.from_string(template).render({"k1": obj})
actual_obj = yaml.safe_load(obj_yaml_str)
self.assertDictEqual(obj, actual_obj)
diff --git a/st2common/tests/unit/test_jinja_render_json_escape_filters.py b/st2common/tests/unit/test_jinja_render_json_escape_filters.py
index 82534100c5..48fef776c1 100644
--- a/st2common/tests/unit/test_jinja_render_json_escape_filters.py
+++ b/st2common/tests/unit/test_jinja_render_json_escape_filters.py
@@ -21,52 +21,51 @@
class JinjaUtilsJsonEscapeTestCase(unittest2.TestCase):
-
def test_doublequotes(self):
env = jinja_utils.get_jinja_environment()
- template = '{{ test_str | json_escape }}'
- actual = env.from_string(template).render({'test_str': 'foo """ bar'})
+ template = "{{ test_str | json_escape }}"
+ actual = env.from_string(template).render({"test_str": 'foo """ bar'})
expected = 'foo \\"\\"\\" bar'
self.assertEqual(actual, expected)
def test_backslashes(self):
env = jinja_utils.get_jinja_environment()
- template = '{{ test_str | json_escape }}'
- actual = env.from_string(template).render({'test_str': r'foo \ bar'})
- expected = 'foo \\\\ bar'
+ template = "{{ test_str | json_escape }}"
+ actual = env.from_string(template).render({"test_str": r"foo \ bar"})
+ expected = "foo \\\\ bar"
self.assertEqual(actual, expected)
def test_backspace(self):
env = jinja_utils.get_jinja_environment()
- template = '{{ test_str | json_escape }}'
- actual = env.from_string(template).render({'test_str': 'foo \b bar'})
- expected = 'foo \\b bar'
+ template = "{{ test_str | json_escape }}"
+ actual = env.from_string(template).render({"test_str": "foo \b bar"})
+ expected = "foo \\b bar"
self.assertEqual(actual, expected)
def test_formfeed(self):
env = jinja_utils.get_jinja_environment()
- template = '{{ test_str | json_escape }}'
- actual = env.from_string(template).render({'test_str': 'foo \f bar'})
- expected = 'foo \\f bar'
+ template = "{{ test_str | json_escape }}"
+ actual = env.from_string(template).render({"test_str": "foo \f bar"})
+ expected = "foo \\f bar"
self.assertEqual(actual, expected)
def test_newline(self):
env = jinja_utils.get_jinja_environment()
- template = '{{ test_str | json_escape }}'
- actual = env.from_string(template).render({'test_str': 'foo \n bar'})
- expected = 'foo \\n bar'
+ template = "{{ test_str | json_escape }}"
+ actual = env.from_string(template).render({"test_str": "foo \n bar"})
+ expected = "foo \\n bar"
self.assertEqual(actual, expected)
def test_carriagereturn(self):
env = jinja_utils.get_jinja_environment()
- template = '{{ test_str | json_escape }}'
- actual = env.from_string(template).render({'test_str': 'foo \r bar'})
- expected = 'foo \\r bar'
+ template = "{{ test_str | json_escape }}"
+ actual = env.from_string(template).render({"test_str": "foo \r bar"})
+ expected = "foo \\r bar"
self.assertEqual(actual, expected)
def test_tab(self):
env = jinja_utils.get_jinja_environment()
- template = '{{ test_str | json_escape }}'
- actual = env.from_string(template).render({'test_str': 'foo \t bar'})
- expected = 'foo \\t bar'
+ template = "{{ test_str | json_escape }}"
+ actual = env.from_string(template).render({"test_str": "foo \t bar"})
+ expected = "foo \\t bar"
self.assertEqual(actual, expected)
diff --git a/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py b/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py
index fd199ebf64..934aa04de8 100644
--- a/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py
+++ b/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py
@@ -21,49 +21,58 @@
class JinjaUtilsJsonpathQueryTestCase(unittest2.TestCase):
-
def test_jsonpath_query_static(self):
env = jinja_utils.get_jinja_environment()
- obj = {'people': [{'first': 'James', 'last': 'd'},
- {'first': 'Jacob', 'last': 'e'},
- {'first': 'Jayden', 'last': 'f'},
- {'missing': 'different'}],
- 'foo': {'bar': 'baz'}}
+ obj = {
+ "people": [
+ {"first": "James", "last": "d"},
+ {"first": "Jacob", "last": "e"},
+ {"first": "Jayden", "last": "f"},
+ {"missing": "different"},
+ ],
+ "foo": {"bar": "baz"},
+ }
template = '{{ obj | jsonpath_query("people[*].first") }}'
- actual_str = env.from_string(template).render({'obj': obj})
+ actual_str = env.from_string(template).render({"obj": obj})
actual = eval(actual_str)
- expected = ['James', 'Jacob', 'Jayden']
+ expected = ["James", "Jacob", "Jayden"]
self.assertEqual(actual, expected)
def test_jsonpath_query_dynamic(self):
env = jinja_utils.get_jinja_environment()
- obj = {'people': [{'first': 'James', 'last': 'd'},
- {'first': 'Jacob', 'last': 'e'},
- {'first': 'Jayden', 'last': 'f'},
- {'missing': 'different'}],
- 'foo': {'bar': 'baz'}}
+ obj = {
+ "people": [
+ {"first": "James", "last": "d"},
+ {"first": "Jacob", "last": "e"},
+ {"first": "Jayden", "last": "f"},
+ {"missing": "different"},
+ ],
+ "foo": {"bar": "baz"},
+ }
query = "people[*].last"
- template = '{{ obj | jsonpath_query(query) }}'
- actual_str = env.from_string(template).render({'obj': obj,
- 'query': query})
+ template = "{{ obj | jsonpath_query(query) }}"
+ actual_str = env.from_string(template).render({"obj": obj, "query": query})
actual = eval(actual_str)
- expected = ['d', 'e', 'f']
+ expected = ["d", "e", "f"]
self.assertEqual(actual, expected)
def test_jsonpath_query_no_results(self):
env = jinja_utils.get_jinja_environment()
- obj = {'people': [{'first': 'James', 'last': 'd'},
- {'first': 'Jacob', 'last': 'e'},
- {'first': 'Jayden', 'last': 'f'},
- {'missing': 'different'}],
- 'foo': {'bar': 'baz'}}
+ obj = {
+ "people": [
+ {"first": "James", "last": "d"},
+ {"first": "Jacob", "last": "e"},
+ {"first": "Jayden", "last": "f"},
+ {"missing": "different"},
+ ],
+ "foo": {"bar": "baz"},
+ }
query = "query_returns_no_results"
- template = '{{ obj | jsonpath_query(query) }}'
- actual_str = env.from_string(template).render({'obj': obj,
- 'query': query})
+ template = "{{ obj | jsonpath_query(query) }}"
+ actual_str = env.from_string(template).render({"obj": obj, "query": query})
actual = eval(actual_str)
expected = None
self.assertEqual(actual, expected)
diff --git a/st2common/tests/unit/test_jinja_render_path_filters.py b/st2common/tests/unit/test_jinja_render_path_filters.py
index 504b6454bb..23507bbbc1 100644
--- a/st2common/tests/unit/test_jinja_render_path_filters.py
+++ b/st2common/tests/unit/test_jinja_render_path_filters.py
@@ -21,29 +21,28 @@
class JinjaUtilsPathFilterTestCase(unittest2.TestCase):
-
def test_basename(self):
env = jinja_utils.get_jinja_environment()
- template = '{{k1 | basename}}'
- actual = env.from_string(template).render({'k1': '/some/path/to/file.txt'})
- self.assertEqual(actual, 'file.txt')
+ template = "{{k1 | basename}}"
+ actual = env.from_string(template).render({"k1": "/some/path/to/file.txt"})
+ self.assertEqual(actual, "file.txt")
- actual = env.from_string(template).render({'k1': '/some/path/to/dir'})
- self.assertEqual(actual, 'dir')
+ actual = env.from_string(template).render({"k1": "/some/path/to/dir"})
+ self.assertEqual(actual, "dir")
- actual = env.from_string(template).render({'k1': '/some/path/to/dir/'})
- self.assertEqual(actual, '')
+ actual = env.from_string(template).render({"k1": "/some/path/to/dir/"})
+ self.assertEqual(actual, "")
def test_dirname(self):
env = jinja_utils.get_jinja_environment()
- template = '{{k1 | dirname}}'
- actual = env.from_string(template).render({'k1': '/some/path/to/file.txt'})
- self.assertEqual(actual, '/some/path/to')
+ template = "{{k1 | dirname}}"
+ actual = env.from_string(template).render({"k1": "/some/path/to/file.txt"})
+ self.assertEqual(actual, "/some/path/to")
- actual = env.from_string(template).render({'k1': '/some/path/to/dir'})
- self.assertEqual(actual, '/some/path/to')
+ actual = env.from_string(template).render({"k1": "/some/path/to/dir"})
+ self.assertEqual(actual, "/some/path/to")
- actual = env.from_string(template).render({'k1': '/some/path/to/dir/'})
- self.assertEqual(actual, '/some/path/to/dir')
+ actual = env.from_string(template).render({"k1": "/some/path/to/dir/"})
+ self.assertEqual(actual, "/some/path/to/dir")
diff --git a/st2common/tests/unit/test_jinja_render_regex_filters.py b/st2common/tests/unit/test_jinja_render_regex_filters.py
index 081d068682..df2e347779 100644
--- a/st2common/tests/unit/test_jinja_render_regex_filters.py
+++ b/st2common/tests/unit/test_jinja_render_regex_filters.py
@@ -20,54 +20,53 @@
class JinjaUtilsRegexFilterTestCase(unittest2.TestCase):
-
def test_filters_regex_match(self):
env = jinja_utils.get_jinja_environment()
template = '{{k1 | regex_match("x")}}'
- actual = env.from_string(template).render({'k1': 'xyz'})
- expected = 'True'
+ actual = env.from_string(template).render({"k1": "xyz"})
+ expected = "True"
self.assertEqual(actual, expected)
template = '{{k1 | regex_match("y")}}'
- actual = env.from_string(template).render({'k1': 'xyz'})
- expected = 'False'
+ actual = env.from_string(template).render({"k1": "xyz"})
+ expected = "False"
self.assertEqual(actual, expected)
template = '{{k1 | regex_match("^v(\\d+\\.)?(\\d+\\.)?(\\*|\\d+)$")}}'
- actual = env.from_string(template).render({'k1': 'v0.10.1'})
- expected = 'True'
+ actual = env.from_string(template).render({"k1": "v0.10.1"})
+ expected = "True"
self.assertEqual(actual, expected)
def test_filters_regex_replace(self):
env = jinja_utils.get_jinja_environment()
template = '{{k1 | regex_replace("x", "y")}}'
- actual = env.from_string(template).render({'k1': 'xyz'})
- expected = 'yyz'
+ actual = env.from_string(template).render({"k1": "xyz"})
+ expected = "yyz"
self.assertEqual(actual, expected)
template = '{{k1 | regex_replace("(blue|white|red)", "color")}}'
- actual = env.from_string(template).render({'k1': 'blue socks and red shoes'})
- expected = 'color socks and color shoes'
+ actual = env.from_string(template).render({"k1": "blue socks and red shoes"})
+ expected = "color socks and color shoes"
self.assertEqual(actual, expected)
def test_filters_regex_search(self):
env = jinja_utils.get_jinja_environment()
template = '{{k1 | regex_search("x")}}'
- actual = env.from_string(template).render({'k1': 'xyz'})
- expected = 'True'
+ actual = env.from_string(template).render({"k1": "xyz"})
+ expected = "True"
self.assertEqual(actual, expected)
template = '{{k1 | regex_search("y")}}'
- actual = env.from_string(template).render({'k1': 'xyz'})
- expected = 'True'
+ actual = env.from_string(template).render({"k1": "xyz"})
+ expected = "True"
self.assertEqual(actual, expected)
template = '{{k1 | regex_search("^v(\\d+\\.)?(\\d+\\.)?(\\*|\\d+)$")}}'
- actual = env.from_string(template).render({'k1': 'v0.10.1'})
- expected = 'True'
+ actual = env.from_string(template).render({"k1": "v0.10.1"})
+ expected = "True"
self.assertEqual(actual, expected)
def test_filters_regex_substring(self):
@@ -76,29 +75,31 @@ def test_filters_regex_substring(self):
# Normal (match)
template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))")}}'
actual = env.from_string(template).render(
- {'input_str': 'My address is 123 Somewhere Ave. See you soon!'}
+ {"input_str": "My address is 123 Somewhere Ave. See you soon!"}
)
- expected = '123 Somewhere Ave'
+ expected = "123 Somewhere Ave"
self.assertEqual(actual, expected)
# Selecting the second match explicitly
template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))", 1)}}'
actual = env.from_string(template).render(
- {'input_str': 'Your address is 567 Elsewhere Dr. My address is 123 Somewhere Ave.'}
+ {
+ "input_str": "Your address is 567 Elsewhere Dr. My address is 123 Somewhere Ave."
+ }
)
- expected = '123 Somewhere Ave'
+ expected = "123 Somewhere Ave"
self.assertEqual(actual, expected)
# Selecting the second match explicitly, but it doesn't exist
template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))", 1)}}'
with self.assertRaises(IndexError):
actual = env.from_string(template).render(
- {'input_str': 'Your address is 567 Elsewhere Dr.'}
+ {"input_str": "Your address is 567 Elsewhere Dr."}
)
# No match
template = r'{{input_str | regex_substring("([0-3]{3} \w+ (?:Ave|St|Dr))")}}'
with self.assertRaises(IndexError):
actual = env.from_string(template).render(
- {'input_str': 'My address is 986 Somewhere Ave. See you soon!'}
+ {"input_str": "My address is 986 Somewhere Ave. See you soon!"}
)
diff --git a/st2common/tests/unit/test_jinja_render_time_filters.py b/st2common/tests/unit/test_jinja_render_time_filters.py
index 5151cec695..2cf002a0e3 100644
--- a/st2common/tests/unit/test_jinja_render_time_filters.py
+++ b/st2common/tests/unit/test_jinja_render_time_filters.py
@@ -20,16 +20,16 @@
class JinjaUtilsTimeFilterTestCase(unittest2.TestCase):
-
def test_to_human_time_filter(self):
env = jinja_utils.get_jinja_environment()
- template = '{{k1 | to_human_time_from_seconds}}'
- actual = env.from_string(template).render({'k1': 12345})
- self.assertEqual(actual, '3h25m45s')
+ template = "{{k1 | to_human_time_from_seconds}}"
+ actual = env.from_string(template).render({"k1": 12345})
+ self.assertEqual(actual, "3h25m45s")
- actual = env.from_string(template).render({'k1': 0})
- self.assertEqual(actual, '0s')
+ actual = env.from_string(template).render({"k1": 0})
+ self.assertEqual(actual, "0s")
- self.assertRaises(AssertionError, env.from_string(template).render,
- {'k1': 'stuff'})
+ self.assertRaises(
+ AssertionError, env.from_string(template).render, {"k1": "stuff"}
+ )
diff --git a/st2common/tests/unit/test_jinja_render_version_filters.py b/st2common/tests/unit/test_jinja_render_version_filters.py
index 9cbacd7dcb..41b2b23670 100644
--- a/st2common/tests/unit/test_jinja_render_version_filters.py
+++ b/st2common/tests/unit/test_jinja_render_version_filters.py
@@ -21,134 +21,133 @@
class JinjaUtilsVersionsFilterTestCase(unittest2.TestCase):
-
def test_version_compare(self):
env = jinja_utils.get_jinja_environment()
template = '{{version | version_compare("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.9.0'})
- expected = '-1'
+ actual = env.from_string(template).render({"version": "0.9.0"})
+ expected = "-1"
self.assertEqual(actual, expected)
template = '{{version | version_compare("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = '1'
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "1"
self.assertEqual(actual, expected)
template = '{{version | version_compare("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.0'})
- expected = '0'
+ actual = env.from_string(template).render({"version": "0.10.0"})
+ expected = "0"
self.assertEqual(actual, expected)
def test_version_more_than(self):
env = jinja_utils.get_jinja_environment()
template = '{{version | version_more_than("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.9.0'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "0.9.0"})
+ expected = "False"
self.assertEqual(actual, expected)
template = '{{version | version_more_than("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = 'True'
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "True"
self.assertEqual(actual, expected)
template = '{{version | version_more_than("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.0'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "0.10.0"})
+ expected = "False"
self.assertEqual(actual, expected)
def test_version_less_than(self):
env = jinja_utils.get_jinja_environment()
template = '{{version | version_less_than("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.9.0'})
- expected = 'True'
+ actual = env.from_string(template).render({"version": "0.9.0"})
+ expected = "True"
self.assertEqual(actual, expected)
template = '{{version | version_less_than("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "False"
self.assertEqual(actual, expected)
template = '{{version | version_less_than("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.0'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "0.10.0"})
+ expected = "False"
self.assertEqual(actual, expected)
def test_version_equal(self):
env = jinja_utils.get_jinja_environment()
template = '{{version | version_equal("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.9.0'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "0.9.0"})
+ expected = "False"
self.assertEqual(actual, expected)
template = '{{version | version_equal("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "False"
self.assertEqual(actual, expected)
template = '{{version | version_equal("0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.0'})
- expected = 'True'
+ actual = env.from_string(template).render({"version": "0.10.0"})
+ expected = "True"
self.assertEqual(actual, expected)
def test_version_match(self):
env = jinja_utils.get_jinja_environment()
template = '{{version | version_match(">0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = 'True'
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "True"
self.assertEqual(actual, expected)
- actual = env.from_string(template).render({'version': '0.1.1'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "0.1.1"})
+ expected = "False"
self.assertEqual(actual, expected)
template = '{{version | version_match("<0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.1.0'})
- expected = 'True'
+ actual = env.from_string(template).render({"version": "0.1.0"})
+ expected = "True"
self.assertEqual(actual, expected)
- actual = env.from_string(template).render({'version': '1.1.0'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "1.1.0"})
+ expected = "False"
self.assertEqual(actual, expected)
template = '{{version | version_match("==0.10.0")}}'
- actual = env.from_string(template).render({'version': '0.10.0'})
- expected = 'True'
+ actual = env.from_string(template).render({"version": "0.10.0"})
+ expected = "True"
self.assertEqual(actual, expected)
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = 'False'
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "False"
self.assertEqual(actual, expected)
def test_version_bump_major(self):
env = jinja_utils.get_jinja_environment()
- template = '{{version | version_bump_major}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = '1.0.0'
+ template = "{{version | version_bump_major}}"
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "1.0.0"
self.assertEqual(actual, expected)
def test_version_bump_minor(self):
env = jinja_utils.get_jinja_environment()
- template = '{{version | version_bump_minor}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = '0.11.0'
+ template = "{{version | version_bump_minor}}"
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "0.11.0"
self.assertEqual(actual, expected)
def test_version_bump_patch(self):
env = jinja_utils.get_jinja_environment()
- template = '{{version | version_bump_patch}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = '0.10.2'
+ template = "{{version | version_bump_patch}}"
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "0.10.2"
self.assertEqual(actual, expected)
def test_version_strip_patch(self):
env = jinja_utils.get_jinja_environment()
- template = '{{version | version_strip_patch}}'
- actual = env.from_string(template).render({'version': '0.10.1'})
- expected = '0.10'
+ template = "{{version | version_strip_patch}}"
+ actual = env.from_string(template).render({"version": "0.10.1"})
+ expected = "0.10"
self.assertEqual(actual, expected)
diff --git a/st2common/tests/unit/test_json_schema.py b/st2common/tests/unit/test_json_schema.py
index 42b94efb70..892e2604ed 100644
--- a/st2common/tests/unit/test_json_schema.py
+++ b/st2common/tests/unit/test_json_schema.py
@@ -20,158 +20,127 @@
from st2common.util import schema as util_schema
TEST_SCHEMA_1 = {
- 'additionalProperties': False,
- 'title': 'foo',
- 'description': 'Foo.',
- 'type': 'object',
- 'properties': {
- 'arg_required_no_default': {
- 'description': 'Foo',
- 'required': True,
- 'type': 'string'
+ "additionalProperties": False,
+ "title": "foo",
+ "description": "Foo.",
+ "type": "object",
+ "properties": {
+ "arg_required_no_default": {
+ "description": "Foo",
+ "required": True,
+ "type": "string",
},
- 'arg_optional_no_type': {
- 'description': 'Bar'
+ "arg_optional_no_type": {"description": "Bar"},
+ "arg_optional_multi_type": {
+ "description": "Mirror mirror",
+ "type": ["string", "boolean", "number"],
},
- 'arg_optional_multi_type': {
- 'description': 'Mirror mirror',
- 'type': ['string', 'boolean', 'number']
+ "arg_optional_multi_type_none": {
+ "description": "Mirror mirror on the wall",
+ "type": ["string", "boolean", "number", "null"],
},
- 'arg_optional_multi_type_none': {
- 'description': 'Mirror mirror on the wall',
- 'type': ['string', 'boolean', 'number', 'null']
+ "arg_optional_type_array": {
+ "description": "Who" "s the fairest?",
+ "type": "array",
},
- 'arg_optional_type_array': {
- 'description': 'Who''s the fairest?',
- 'type': 'array'
+ "arg_optional_type_object": {
+ "description": "Who" "s the fairest of them?",
+ "type": "object",
},
- 'arg_optional_type_object': {
- 'description': 'Who''s the fairest of them?',
- 'type': 'object'
+ "arg_optional_multi_collection_type": {
+ "description": "Who" "s the fairest of them all?",
+ "type": ["array", "object"],
},
- 'arg_optional_multi_collection_type': {
- 'description': 'Who''s the fairest of them all?',
- 'type': ['array', 'object']
- }
- }
+ },
}
TEST_SCHEMA_2 = {
- 'additionalProperties': False,
- 'title': 'foo',
- 'description': 'Foo.',
- 'type': 'object',
- 'properties': {
- 'arg_required_default': {
- 'default': 'date',
- 'description': 'Foo',
- 'required': True,
- 'type': 'string'
+ "additionalProperties": False,
+ "title": "foo",
+ "description": "Foo.",
+ "type": "object",
+ "properties": {
+ "arg_required_default": {
+ "default": "date",
+ "description": "Foo",
+ "required": True,
+ "type": "string",
}
- }
+ },
}
TEST_SCHEMA_3 = {
- 'additionalProperties': False,
- 'title': 'foo',
- 'description': 'Foo.',
- 'type': 'object',
- 'properties': {
- 'arg_optional_default': {
- 'default': 'bar',
- 'description': 'Foo',
- 'type': 'string'
+ "additionalProperties": False,
+ "title": "foo",
+ "description": "Foo.",
+ "type": "object",
+ "properties": {
+ "arg_optional_default": {
+ "default": "bar",
+ "description": "Foo",
+ "type": "string",
},
- 'arg_optional_default_none': {
- 'default': None,
- 'description': 'Foo',
- 'type': 'string'
+ "arg_optional_default_none": {
+ "default": None,
+ "description": "Foo",
+ "type": "string",
},
- 'arg_optional_no_default': {
- 'description': 'Foo',
- 'type': 'string'
- }
- }
+ "arg_optional_no_default": {"description": "Foo", "type": "string"},
+ },
}
TEST_SCHEMA_4 = {
- 'additionalProperties': False,
- 'title': 'foo',
- 'description': 'Foo.',
- 'type': 'object',
- 'properties': {
- 'arg_optional_default': {
- 'default': 'bar',
- 'description': 'Foo',
- 'anyOf': [
- {'type': 'string'},
- {'type': 'boolean'}
- ]
+ "additionalProperties": False,
+ "title": "foo",
+ "description": "Foo.",
+ "type": "object",
+ "properties": {
+ "arg_optional_default": {
+ "default": "bar",
+ "description": "Foo",
+ "anyOf": [{"type": "string"}, {"type": "boolean"}],
},
- 'arg_optional_default_none': {
- 'default': None,
- 'description': 'Foo',
- 'anyOf': [
- {'type': 'string'},
- {'type': 'boolean'}
- ]
+ "arg_optional_default_none": {
+ "default": None,
+ "description": "Foo",
+ "anyOf": [{"type": "string"}, {"type": "boolean"}],
},
- 'arg_optional_no_default': {
- 'description': 'Foo',
- 'anyOf': [
- {'type': 'string'},
- {'type': 'boolean'}
- ]
+ "arg_optional_no_default": {
+ "description": "Foo",
+ "anyOf": [{"type": "string"}, {"type": "boolean"}],
},
- 'arg_optional_no_default_anyof_none': {
- 'description': 'Foo',
- 'anyOf': [
- {'type': 'string'},
- {'type': 'boolean'},
- {'type': 'null'}
- ]
- }
- }
+ "arg_optional_no_default_anyof_none": {
+ "description": "Foo",
+ "anyOf": [{"type": "string"}, {"type": "boolean"}, {"type": "null"}],
+ },
+ },
}
TEST_SCHEMA_5 = {
- 'additionalProperties': False,
- 'title': 'foo',
- 'description': 'Foo.',
- 'type': 'object',
- 'properties': {
- 'arg_optional_default': {
- 'default': 'bar',
- 'description': 'Foo',
- 'oneOf': [
- {'type': 'string'},
- {'type': 'boolean'}
- ]
+ "additionalProperties": False,
+ "title": "foo",
+ "description": "Foo.",
+ "type": "object",
+ "properties": {
+ "arg_optional_default": {
+ "default": "bar",
+ "description": "Foo",
+ "oneOf": [{"type": "string"}, {"type": "boolean"}],
},
- 'arg_optional_default_none': {
- 'default': None,
- 'description': 'Foo',
- 'oneOf': [
- {'type': 'string'},
- {'type': 'boolean'}
- ]
+ "arg_optional_default_none": {
+ "default": None,
+ "description": "Foo",
+ "oneOf": [{"type": "string"}, {"type": "boolean"}],
},
- 'arg_optional_no_default': {
- 'description': 'Foo',
- 'oneOf': [
- {'type': 'string'},
- {'type': 'boolean'}
- ]
+ "arg_optional_no_default": {
+ "description": "Foo",
+ "oneOf": [{"type": "string"}, {"type": "boolean"}],
},
- 'arg_optional_no_default_oneof_none': {
- 'description': 'Foo',
- 'oneOf': [
- {'type': 'string'},
- {'type': 'boolean'},
- {'type': 'null'}
- ]
- }
- }
+ "arg_optional_no_default_oneof_none": {
+ "description": "Foo",
+ "oneOf": [{"type": "string"}, {"type": "boolean"}, {"type": "null"}],
+ },
+ },
}
@@ -181,192 +150,265 @@ def test_use_default_value(self):
instance = {}
validator = util_schema.get_validator()
- expected_msg = '\'arg_required_no_default\' is a required property'
- self.assertRaisesRegexp(ValidationError, expected_msg, util_schema.validate,
- instance=instance, schema=TEST_SCHEMA_1, cls=validator,
- use_default=True)
+ expected_msg = "'arg_required_no_default' is a required property"
+ self.assertRaisesRegexp(
+ ValidationError,
+ expected_msg,
+ util_schema.validate,
+ instance=instance,
+ schema=TEST_SCHEMA_1,
+ cls=validator,
+ use_default=True,
+ )
# No default, value provided
- instance = {'arg_required_no_default': 'foo'}
- util_schema.validate(instance=instance, schema=TEST_SCHEMA_1, cls=validator,
- use_default=True)
+ instance = {"arg_required_no_default": "foo"}
+ util_schema.validate(
+ instance=instance, schema=TEST_SCHEMA_1, cls=validator, use_default=True
+ )
# default value provided, no value, should pass
instance = {}
validator = util_schema.get_validator()
- util_schema.validate(instance=instance, schema=TEST_SCHEMA_2, cls=validator,
- use_default=True)
+ util_schema.validate(
+ instance=instance, schema=TEST_SCHEMA_2, cls=validator, use_default=True
+ )
# default value provided, value provided, should pass
- instance = {'arg_required_default': 'foo'}
+ instance = {"arg_required_default": "foo"}
validator = util_schema.get_validator()
- util_schema.validate(instance=instance, schema=TEST_SCHEMA_2, cls=validator,
- use_default=True)
+ util_schema.validate(
+ instance=instance, schema=TEST_SCHEMA_2, cls=validator, use_default=True
+ )
def test_allow_default_none(self):
# Let validator take care of default
validator = util_schema.get_validator()
- util_schema.validate(instance=dict(), schema=TEST_SCHEMA_3, cls=validator,
- use_default=True, allow_default_none=True)
+ util_schema.validate(
+ instance=dict(),
+ schema=TEST_SCHEMA_3,
+ cls=validator,
+ use_default=True,
+ allow_default_none=True,
+ )
def test_allow_default_explicit_none(self):
# Explicitly pass None to arguments
instance = {
- 'arg_optional_default': None,
- 'arg_optional_default_none': None,
- 'arg_optional_no_default': None
+ "arg_optional_default": None,
+ "arg_optional_default_none": None,
+ "arg_optional_no_default": None,
}
validator = util_schema.get_validator()
- util_schema.validate(instance=instance, schema=TEST_SCHEMA_3, cls=validator,
- use_default=True, allow_default_none=True)
+ util_schema.validate(
+ instance=instance,
+ schema=TEST_SCHEMA_3,
+ cls=validator,
+ use_default=True,
+ allow_default_none=True,
+ )
def test_anyof_type_allow_default_none(self):
# Let validator take care of default
validator = util_schema.get_validator()
- util_schema.validate(instance=dict(), schema=TEST_SCHEMA_4, cls=validator,
- use_default=True, allow_default_none=True)
+ util_schema.validate(
+ instance=dict(),
+ schema=TEST_SCHEMA_4,
+ cls=validator,
+ use_default=True,
+ allow_default_none=True,
+ )
def test_anyof_allow_default_explicit_none(self):
# Explicitly pass None to arguments
instance = {
- 'arg_optional_default': None,
- 'arg_optional_default_none': None,
- 'arg_optional_no_default': None,
- 'arg_optional_no_default_anyof_none': None
+ "arg_optional_default": None,
+ "arg_optional_default_none": None,
+ "arg_optional_no_default": None,
+ "arg_optional_no_default_anyof_none": None,
}
validator = util_schema.get_validator()
- util_schema.validate(instance=instance, schema=TEST_SCHEMA_4, cls=validator,
- use_default=True, allow_default_none=True)
+ util_schema.validate(
+ instance=instance,
+ schema=TEST_SCHEMA_4,
+ cls=validator,
+ use_default=True,
+ allow_default_none=True,
+ )
def test_oneof_type_allow_default_none(self):
# Let validator take care of default
validator = util_schema.get_validator()
- util_schema.validate(instance=dict(), schema=TEST_SCHEMA_5, cls=validator,
- use_default=True, allow_default_none=True)
+ util_schema.validate(
+ instance=dict(),
+ schema=TEST_SCHEMA_5,
+ cls=validator,
+ use_default=True,
+ allow_default_none=True,
+ )
def test_oneof_allow_default_explicit_none(self):
# Explicitly pass None to arguments
instance = {
- 'arg_optional_default': None,
- 'arg_optional_default_none': None,
- 'arg_optional_no_default': None,
- 'arg_optional_no_default_oneof_none': None
+ "arg_optional_default": None,
+ "arg_optional_default_none": None,
+ "arg_optional_no_default": None,
+ "arg_optional_no_default_oneof_none": None,
}
validator = util_schema.get_validator()
- util_schema.validate(instance=instance, schema=TEST_SCHEMA_5, cls=validator,
- use_default=True, allow_default_none=True)
+ util_schema.validate(
+ instance=instance,
+ schema=TEST_SCHEMA_5,
+ cls=validator,
+ use_default=True,
+ allow_default_none=True,
+ )
def test_is_property_type_single(self):
- typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default']
+ typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"]
self.assertTrue(util_schema.is_property_type_single(typed_property))
- untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type']
+ untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"]
self.assertTrue(util_schema.is_property_type_single(untyped_property))
- multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type']
+ multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"]
self.assertFalse(util_schema.is_property_type_single(multi_typed_property))
- anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default']
+ anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"]
self.assertFalse(util_schema.is_property_type_single(anyof_property))
- oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default']
+ oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"]
self.assertFalse(util_schema.is_property_type_single(oneof_property))
def test_is_property_type_anyof(self):
- anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default']
+ anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"]
self.assertTrue(util_schema.is_property_type_anyof(anyof_property))
- typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default']
+ typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"]
self.assertFalse(util_schema.is_property_type_anyof(typed_property))
- untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type']
+ untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"]
self.assertFalse(util_schema.is_property_type_anyof(untyped_property))
- multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type']
+ multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"]
self.assertFalse(util_schema.is_property_type_anyof(multi_typed_property))
- oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default']
+ oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"]
self.assertFalse(util_schema.is_property_type_anyof(oneof_property))
def test_is_property_type_oneof(self):
- oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default']
+ oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"]
self.assertTrue(util_schema.is_property_type_oneof(oneof_property))
- typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default']
+ typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"]
self.assertFalse(util_schema.is_property_type_oneof(typed_property))
- untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type']
+ untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"]
self.assertFalse(util_schema.is_property_type_oneof(untyped_property))
- multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type']
+ multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"]
self.assertFalse(util_schema.is_property_type_oneof(multi_typed_property))
- anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default']
+ anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"]
self.assertFalse(util_schema.is_property_type_oneof(anyof_property))
def test_is_property_type_list(self):
- multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type']
+ multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"]
self.assertTrue(util_schema.is_property_type_list(multi_typed_property))
- typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default']
+ typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"]
self.assertFalse(util_schema.is_property_type_list(typed_property))
- untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type']
+ untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"]
self.assertFalse(util_schema.is_property_type_list(untyped_property))
- anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default']
+ anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"]
self.assertFalse(util_schema.is_property_type_list(anyof_property))
- oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default']
+ oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"]
self.assertFalse(util_schema.is_property_type_list(oneof_property))
def test_is_property_nullable(self):
- multi_typed_prop_nullable = TEST_SCHEMA_1['properties']['arg_optional_multi_type_none']
- self.assertTrue(util_schema.is_property_nullable(multi_typed_prop_nullable.get('type')))
-
- anyof_property_nullable = TEST_SCHEMA_4['properties']['arg_optional_no_default_anyof_none']
- self.assertTrue(util_schema.is_property_nullable(anyof_property_nullable.get('anyOf')))
-
- oneof_property_nullable = TEST_SCHEMA_5['properties']['arg_optional_no_default_oneof_none']
- self.assertTrue(util_schema.is_property_nullable(oneof_property_nullable.get('oneOf')))
-
- typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default']
+ multi_typed_prop_nullable = TEST_SCHEMA_1["properties"][
+ "arg_optional_multi_type_none"
+ ]
+ self.assertTrue(
+ util_schema.is_property_nullable(multi_typed_prop_nullable.get("type"))
+ )
+
+ anyof_property_nullable = TEST_SCHEMA_4["properties"][
+ "arg_optional_no_default_anyof_none"
+ ]
+ self.assertTrue(
+ util_schema.is_property_nullable(anyof_property_nullable.get("anyOf"))
+ )
+
+ oneof_property_nullable = TEST_SCHEMA_5["properties"][
+ "arg_optional_no_default_oneof_none"
+ ]
+ self.assertTrue(
+ util_schema.is_property_nullable(oneof_property_nullable.get("oneOf"))
+ )
+
+ typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"]
self.assertFalse(util_schema.is_property_nullable(typed_property))
- multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type']
- self.assertFalse(util_schema.is_property_nullable(multi_typed_property.get('type')))
+ multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"]
+ self.assertFalse(
+ util_schema.is_property_nullable(multi_typed_property.get("type"))
+ )
- anyof_property = TEST_SCHEMA_4['properties']['arg_optional_no_default']
- self.assertFalse(util_schema.is_property_nullable(anyof_property.get('anyOf')))
+ anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_no_default"]
+ self.assertFalse(util_schema.is_property_nullable(anyof_property.get("anyOf")))
- oneof_property = TEST_SCHEMA_5['properties']['arg_optional_no_default']
- self.assertFalse(util_schema.is_property_nullable(oneof_property.get('oneOf')))
+ oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_no_default"]
+ self.assertFalse(util_schema.is_property_nullable(oneof_property.get("oneOf")))
def test_is_attribute_type_array(self):
- multi_coll_typed_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_collection_type']
- self.assertTrue(util_schema.is_attribute_type_array(multi_coll_typed_prop.get('type')))
-
- array_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_array']
- self.assertTrue(util_schema.is_attribute_type_array(array_type_property.get('type')))
-
- multi_non_coll_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_type']
- self.assertFalse(util_schema.is_attribute_type_array(multi_non_coll_prop.get('type')))
-
- object_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_object']
- self.assertFalse(util_schema.is_attribute_type_array(object_type_property.get('type')))
+ multi_coll_typed_prop = TEST_SCHEMA_1["properties"][
+ "arg_optional_multi_collection_type"
+ ]
+ self.assertTrue(
+ util_schema.is_attribute_type_array(multi_coll_typed_prop.get("type"))
+ )
+
+ array_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_array"]
+ self.assertTrue(
+ util_schema.is_attribute_type_array(array_type_property.get("type"))
+ )
+
+ multi_non_coll_prop = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"]
+ self.assertFalse(
+ util_schema.is_attribute_type_array(multi_non_coll_prop.get("type"))
+ )
+
+ object_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_object"]
+ self.assertFalse(
+ util_schema.is_attribute_type_array(object_type_property.get("type"))
+ )
def test_is_attribute_type_object(self):
- multi_coll_typed_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_collection_type']
- self.assertTrue(util_schema.is_attribute_type_object(multi_coll_typed_prop.get('type')))
-
- object_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_object']
- self.assertTrue(util_schema.is_attribute_type_object(object_type_property.get('type')))
-
- multi_non_coll_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_type']
- self.assertFalse(util_schema.is_attribute_type_object(multi_non_coll_prop.get('type')))
-
- array_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_array']
- self.assertFalse(util_schema.is_attribute_type_object(array_type_property.get('type')))
+ multi_coll_typed_prop = TEST_SCHEMA_1["properties"][
+ "arg_optional_multi_collection_type"
+ ]
+ self.assertTrue(
+ util_schema.is_attribute_type_object(multi_coll_typed_prop.get("type"))
+ )
+
+ object_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_object"]
+ self.assertTrue(
+ util_schema.is_attribute_type_object(object_type_property.get("type"))
+ )
+
+ multi_non_coll_prop = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"]
+ self.assertFalse(
+ util_schema.is_attribute_type_object(multi_non_coll_prop.get("type"))
+ )
+
+ array_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_array"]
+ self.assertFalse(
+ util_schema.is_attribute_type_object(array_type_property.get("type"))
+ )
diff --git a/st2common/tests/unit/test_jsonify.py b/st2common/tests/unit/test_jsonify.py
index 801d912333..1feaac96b0 100644
--- a/st2common/tests/unit/test_jsonify.py
+++ b/st2common/tests/unit/test_jsonify.py
@@ -20,33 +20,32 @@
class JsonifyTests(unittest2.TestCase):
-
def test_none_object(self):
obj = None
self.assertIsNone(jsonify.json_loads(obj))
def test_no_keys(self):
- obj = {'foo': '{"bar": "baz"}'}
+ obj = {"foo": '{"bar": "baz"}'}
transformed_obj = jsonify.json_loads(obj)
- self.assertTrue(transformed_obj['foo']['bar'] == 'baz')
+ self.assertTrue(transformed_obj["foo"]["bar"] == "baz")
def test_no_json_value(self):
- obj = {'foo': 'bar'}
+ obj = {"foo": "bar"}
transformed_obj = jsonify.json_loads(obj)
- self.assertTrue(transformed_obj['foo'] == 'bar')
+ self.assertTrue(transformed_obj["foo"] == "bar")
def test_happy_case(self):
- obj = {'foo': '{"bar": "baz"}', 'yo': 'bibimbao'}
- transformed_obj = jsonify.json_loads(obj, ['yo'])
- self.assertTrue(transformed_obj['yo'] == 'bibimbao')
+ obj = {"foo": '{"bar": "baz"}', "yo": "bibimbao"}
+ transformed_obj = jsonify.json_loads(obj, ["yo"])
+ self.assertTrue(transformed_obj["yo"] == "bibimbao")
def test_try_loads(self):
# The function json.loads will fail, so try_loads should return the original value.
- values = ['abc', 123, True, object()]
+ values = ["abc", 123, True, object()]
for value in values:
self.assertEqual(jsonify.try_loads(value), value)
# The function json.loads succeeds.
d = '{"a": 1, "b": true}'
- expected = {'a': 1, 'b': True}
+ expected = {"a": 1, "b": True}
self.assertDictEqual(jsonify.try_loads(d), expected)
diff --git a/st2common/tests/unit/test_keyvalue_lookup.py b/st2common/tests/unit/test_keyvalue_lookup.py
index f37cc04dc9..afcd76901a 100644
--- a/st2common/tests/unit/test_keyvalue_lookup.py
+++ b/st2common/tests/unit/test_keyvalue_lookup.py
@@ -24,23 +24,29 @@
class TestKeyValueLookup(CleanDbTestCase):
def test_lookup_with_key_prefix(self):
- KeyValuePair.add_or_update(KeyValuePairDB(name='some:prefix:stanley:k5', value='v5',
- scope=FULL_USER_SCOPE))
+ KeyValuePair.add_or_update(
+ KeyValuePairDB(
+ name="some:prefix:stanley:k5", value="v5", scope=FULL_USER_SCOPE
+ )
+ )
# No prefix provided, so the lookup should resolve to an empty string
- lookup = UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE)
- self.assertEqual(str(lookup.k5), '')
+ lookup = UserKeyValueLookup(user="stanley", scope=FULL_USER_SCOPE)
+ self.assertEqual(str(lookup.k5), "")
# Prefix provided
- lookup = UserKeyValueLookup(prefix='some:prefix', user='stanley', scope=FULL_USER_SCOPE)
- self.assertEqual(str(lookup.k5), 'v5')
+ lookup = UserKeyValueLookup(
+ prefix="some:prefix", user="stanley", scope=FULL_USER_SCOPE
+ )
+ self.assertEqual(str(lookup.k5), "v5")
def test_non_hierarchical_lookup(self):
- k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1'))
- k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2'))
- k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='k3', value='v3'))
- k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:k4', value='v4',
- scope=FULL_USER_SCOPE))
+ k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1"))
+ k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2"))
+ k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="k3", value="v3"))
+ k4 = KeyValuePair.add_or_update(
+ KeyValuePairDB(name="stanley:k4", value="v4", scope=FULL_USER_SCOPE)
+ )
lookup = KeyValueLookup()
self.assertEqual(str(lookup.k1), k1.value)
@@ -49,108 +55,119 @@ def test_non_hierarchical_lookup(self):
# Scoped lookup
lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
- self.assertEqual(str(lookup.k4), '')
- user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley')
+ self.assertEqual(str(lookup.k4), "")
+ user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley")
self.assertEqual(str(user_lookup.k4), k4.value)
def test_hierarchical_lookup_dotted(self):
- k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1'))
- k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='v2'))
- k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='b.c', value='v3'))
- k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4',
- scope=FULL_USER_SCOPE))
+ k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b", value="v1"))
+ k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="v2"))
+ k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="b.c", value="v3"))
+ k4 = KeyValuePair.add_or_update(
+ KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE)
+ )
lookup = KeyValueLookup()
self.assertEqual(str(lookup.a.b), k1.value)
self.assertEqual(str(lookup.a.b.c), k2.value)
self.assertEqual(str(lookup.b.c), k3.value)
- self.assertEqual(str(lookup.a), '')
+ self.assertEqual(str(lookup.a), "")
# Scoped lookup
lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
- self.assertEqual(str(lookup.r.i.p), '')
- user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley')
+ self.assertEqual(str(lookup.r.i.p), "")
+ user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley")
self.assertEqual(str(user_lookup.r.i.p), k4.value)
def test_hierarchical_lookup_dict(self):
- k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1'))
- k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='v2'))
- k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='b.c', value='v3'))
- k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4',
- scope=FULL_USER_SCOPE))
+ k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b", value="v1"))
+ k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="v2"))
+ k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="b.c", value="v3"))
+ k4 = KeyValuePair.add_or_update(
+ KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE)
+ )
lookup = KeyValueLookup()
- self.assertEqual(str(lookup['a']['b']), k1.value)
- self.assertEqual(str(lookup['a']['b']['c']), k2.value)
- self.assertEqual(str(lookup['b']['c']), k3.value)
- self.assertEqual(str(lookup['a']), '')
+ self.assertEqual(str(lookup["a"]["b"]), k1.value)
+ self.assertEqual(str(lookup["a"]["b"]["c"]), k2.value)
+ self.assertEqual(str(lookup["b"]["c"]), k3.value)
+ self.assertEqual(str(lookup["a"]), "")
# Scoped lookup
lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
- self.assertEqual(str(lookup['r']['i']['p']), '')
- user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley')
- self.assertEqual(str(user_lookup['r']['i']['p']), k4.value)
+ self.assertEqual(str(lookup["r"]["i"]["p"]), "")
+ user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley")
+ self.assertEqual(str(user_lookup["r"]["i"]["p"]), k4.value)
def test_lookups_older_scope_names_backward_compatibility(self):
- k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1',
- scope=FULL_SYSTEM_SCOPE))
+ k1 = KeyValuePair.add_or_update(
+ KeyValuePairDB(name="a.b", value="v1", scope=FULL_SYSTEM_SCOPE)
+ )
lookup = KeyValueLookup(scope=SYSTEM_SCOPE)
- self.assertEqual(str(lookup['a']['b']), k1.value)
+ self.assertEqual(str(lookup["a"]["b"]), k1.value)
- k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4',
- scope=FULL_USER_SCOPE))
- user_lookup = UserKeyValueLookup(scope=USER_SCOPE, user='stanley')
- self.assertEqual(str(user_lookup['r']['i']['p']), k2.value)
+ k2 = KeyValuePair.add_or_update(
+ KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE)
+ )
+ user_lookup = UserKeyValueLookup(scope=USER_SCOPE, user="stanley")
+ self.assertEqual(str(user_lookup["r"]["i"]["p"]), k2.value)
def test_user_scope_lookups_dot_in_user(self):
- KeyValuePair.add_or_update(KeyValuePairDB(name='first.last:r.i.p', value='v4',
- scope=FULL_USER_SCOPE))
- lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='first.last')
- self.assertEqual(str(lookup.r.i.p), 'v4')
- self.assertEqual(str(lookup['r']['i']['p']), 'v4')
+ KeyValuePair.add_or_update(
+ KeyValuePairDB(name="first.last:r.i.p", value="v4", scope=FULL_USER_SCOPE)
+ )
+ lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="first.last")
+ self.assertEqual(str(lookup.r.i.p), "v4")
+ self.assertEqual(str(lookup["r"]["i"]["p"]), "v4")
def test_user_scope_lookups_user_sep_in_name(self):
- KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r:i:p', value='v4',
- scope=FULL_USER_SCOPE))
- lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley')
+ KeyValuePair.add_or_update(
+ KeyValuePairDB(name="stanley:r:i:p", value="v4", scope=FULL_USER_SCOPE)
+ )
+ lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley")
# This is the only way to lookup because USER_SEPARATOR (':') cannot be a part of
# variable name in Python.
- self.assertEqual(str(lookup['r:i:p']), 'v4')
+ self.assertEqual(str(lookup["r:i:p"]), "v4")
def test_missing_key_lookup(self):
lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
- self.assertEqual(str(lookup.missing_key), '')
- self.assertTrue(lookup.missing_key, 'Should be not none.')
+ self.assertEqual(str(lookup.missing_key), "")
+ self.assertTrue(lookup.missing_key, "Should not be None.")
- user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley')
- self.assertEqual(str(user_lookup.missing_key), '')
- self.assertTrue(user_lookup.missing_key, 'Should be not none.')
+ user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley")
+ self.assertEqual(str(user_lookup.missing_key), "")
+ self.assertTrue(user_lookup.missing_key, "Should not be None.")
def test_secret_lookup(self):
- secret_value = '0055A2D9A09E1071931925933744965EEA7E23DCF59A8D1D7A3' + \
- '64338294916D37E83C4796283C584751750E39844E2FD97A3727DB5D553F638'
- k1 = KeyValuePair.add_or_update(KeyValuePairDB(
- name='k1', value=secret_value,
- secret=True)
+ secret_value = (
+ "0055A2D9A09E1071931925933744965EEA7E23DCF59A8D1D7A3"
+ + "64338294916D37E83C4796283C584751750E39844E2FD97A3727DB5D553F638"
+ )
+ k1 = KeyValuePair.add_or_update(
+ KeyValuePairDB(name="k1", value=secret_value, secret=True)
)
- k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2'))
- k3 = KeyValuePair.add_or_update(KeyValuePairDB(
- name='stanley:k3', value=secret_value, scope=FULL_USER_SCOPE,
- secret=True)
+ k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2"))
+ k3 = KeyValuePair.add_or_update(
+ KeyValuePairDB(
+ name="stanley:k3",
+ value=secret_value,
+ scope=FULL_USER_SCOPE,
+ secret=True,
+ )
)
lookup = KeyValueLookup()
self.assertEqual(str(lookup.k1), k1.value)
self.assertEqual(str(lookup.k2), k2.value)
- self.assertEqual(str(lookup.k3), '')
+ self.assertEqual(str(lookup.k3), "")
- user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley')
+ user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley")
self.assertEqual(str(user_lookup.k3), k3.value)
def test_lookup_cast(self):
- KeyValuePair.add_or_update(KeyValuePairDB(name='count', value='5.5'))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="count", value="5.5"))
lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE)
- self.assertEqual(str(lookup.count), '5.5')
+ self.assertEqual(str(lookup.count), "5.5")
self.assertEqual(float(lookup.count), 5.5)
self.assertEqual(int(lookup.count), 5)
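Every hunk in this file applies the same mechanical transformation: black normalizes string quotes to double quotes and re-wraps any call that overflows its 88-column default. A minimal sketch of reproducing that via black's library API (assuming black is installed; the snippet is illustrative input, not a line from this patch):

import black

# Pre-patch style: single quotes, arguments wrapped by hand.
source = (
    "def test_prefix_lookup():\n"
    "    lookup = UserKeyValueLookup(prefix='some:prefix', user='stanley',\n"
    "                                scope=FULL_USER_SCOPE)\n"
)

# black.Mode() defaults to double quotes and 88-column lines -- the style
# these hunks converge on. The over-long call is exploded onto its own
# indented line with the closing paren dedented, exactly as seen above.
print(black.format_str(source, mode=black.Mode()))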
diff --git a/st2common/tests/unit/test_keyvalue_system_model.py b/st2common/tests/unit/test_keyvalue_system_model.py
index ff834f2d6d..a8ea10822b 100644
--- a/st2common/tests/unit/test_keyvalue_system_model.py
+++ b/st2common/tests/unit/test_keyvalue_system_model.py
@@ -21,15 +21,19 @@
class UserKeyReferenceSystemModelTest(unittest2.TestCase):
-
def test_to_string_reference(self):
- key_ref = UserKeyReference.to_string_reference(user='stanley', name='foo')
- self.assertEqual(key_ref, 'stanley:foo')
- self.assertRaises(ValueError, UserKeyReference.to_string_reference, user=None, name='foo')
+ key_ref = UserKeyReference.to_string_reference(user="stanley", name="foo")
+ self.assertEqual(key_ref, "stanley:foo")
+ self.assertRaises(
+ ValueError, UserKeyReference.to_string_reference, user=None, name="foo"
+ )
def test_from_string_reference(self):
- user, name = UserKeyReference.from_string_reference('stanley:foo')
- self.assertEqual(user, 'stanley')
- self.assertEqual(name, 'foo')
- self.assertRaises(InvalidUserKeyReferenceError, UserKeyReference.from_string_reference,
- 'this_key_has_no_sep')
+ user, name = UserKeyReference.from_string_reference("stanley:foo")
+ self.assertEqual(user, "stanley")
+ self.assertEqual(name, "foo")
+ self.assertRaises(
+ InvalidUserKeyReferenceError,
+ UserKeyReference.from_string_reference,
+ "this_key_has_no_sep",
+ )
diff --git a/st2common/tests/unit/test_logger.py b/st2common/tests/unit/test_logger.py
index 30d18e9f89..79158b8b7f 100644
--- a/st2common/tests/unit/test_logger.py
+++ b/st2common/tests/unit/test_logger.py
@@ -36,13 +36,13 @@
import st2tests.config as tests_config
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
-RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources'))
-CONFIG_FILE_PATH = os.path.join(RESOURCES_DIR, 'logging.conf')
+RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources"))
+CONFIG_FILE_PATH = os.path.join(RESOURCES_DIR, "logging.conf")
MOCK_MASKED_ATTRIBUTES_BLACKLIST = [
- 'blacklisted_1',
- 'blacklisted_2',
- 'blacklisted_3',
+ "blacklisted_1",
+ "blacklisted_2",
+ "blacklisted_3",
]
@@ -69,9 +69,8 @@ def setUp(self):
self.cfg_fd, self.cfg_path = tempfile.mkstemp()
self.info_log_fd, self.info_log_path = tempfile.mkstemp()
self.audit_log_fd, self.audit_log_path = tempfile.mkstemp()
- with open(self.cfg_path, 'a') as f:
- f.write(self.config_text.format(self.info_log_path,
- self.audit_log_path))
+ with open(self.cfg_path, "a") as f:
+ f.write(self.config_text.format(self.info_log_path, self.audit_log_path))
def tearDown(self):
self._remove_tempfile(self.cfg_fd, self.cfg_path)
@@ -84,7 +83,7 @@ def _remove_tempfile(self, fd, path):
os.unlink(path)
def test_logger_setup_failure(self):
- config_file = '/tmp/abc123'
+ config_file = "/tmp/abc123"
self.assertFalse(os.path.exists(config_file))
self.assertRaises(Exception, logging.setup, config_file)
@@ -146,7 +145,7 @@ def test_format(self):
formatter = ConsoleLogFormatter()
# No extra attributes
- mock_message = 'test message 1'
+ mock_message = "test message 1"
record = MockRecord()
record.msg = mock_message
@@ -155,94 +154,109 @@ def test_format(self):
self.assertEqual(message, mock_message)
# Some extra attributes
- mock_message = 'test message 2'
+ mock_message = "test message 2"
record = MockRecord()
record.msg = mock_message
# Add "extra" attributes
record._user_id = 1
- record._value = 'bar'
- record.ignored = 'foo' # this one is ignored since it doesnt have a prefix
+ record._value = "bar"
+ record.ignored = "foo" # this one is ignored since it doesn't have a prefix
message = formatter.format(record=record)
- expected = 'test message 2 (value=\'bar\',user_id=1)'
+ expected = "test message 2 (value='bar',user_id=1)"
self.assertEqual(sorted(message), sorted(expected))
- @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST',
- MOCK_MASKED_ATTRIBUTES_BLACKLIST)
+ @mock.patch(
+ "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST",
+ MOCK_MASKED_ATTRIBUTES_BLACKLIST,
+ )
def test_format_blacklisted_attributes_are_masked(self):
formatter = ConsoleLogFormatter()
- mock_message = 'test message 1'
+ mock_message = "test message 1"
record = MockRecord()
record.msg = mock_message
# Add "extra" attributes
- record._blacklisted_1 = 'test value 1'
- record._blacklisted_2 = 'test value 2'
- record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'}
- record._foo1 = 'bar'
+ record._blacklisted_1 = "test value 1"
+ record._blacklisted_2 = "test value 2"
+ record._blacklisted_3 = {
+ "key1": "val1",
+ "blacklisted_1": "val2",
+ "key3": "val3",
+ }
+ record._foo1 = "bar"
message = formatter.format(record=record)
- expected = ("test message 1 (blacklisted_1='********',blacklisted_2='********',"
- "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'},"
- "foo1='bar')")
+ expected = (
+ "test message 1 (blacklisted_1='********',blacklisted_2='********',"
+ "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'},"
+ "foo1='bar')"
+ )
self.assertEqual(sorted(message), sorted(expected))
- @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST',
- MOCK_MASKED_ATTRIBUTES_BLACKLIST)
+ @mock.patch(
+ "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST",
+ MOCK_MASKED_ATTRIBUTES_BLACKLIST,
+ )
def test_format_custom_blacklist_attributes_are_masked(self):
- cfg.CONF.set_override(group='log', name='mask_secrets_blacklist',
- override=['blacklisted_4', 'blacklisted_5'])
+ cfg.CONF.set_override(
+ group="log",
+ name="mask_secrets_blacklist",
+ override=["blacklisted_4", "blacklisted_5"],
+ )
formatter = ConsoleLogFormatter()
- mock_message = 'test message 1'
+ mock_message = "test message 1"
record = MockRecord()
record.msg = mock_message
# Add "extra" attributes
- record._blacklisted_1 = 'test value 1'
- record._blacklisted_2 = 'test value 2'
- record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'}
- record._blacklisted_4 = 'fowa'
- record._blacklisted_5 = 'fiva'
- record._foo1 = 'bar'
+ record._blacklisted_1 = "test value 1"
+ record._blacklisted_2 = "test value 2"
+ record._blacklisted_3 = {
+ "key1": "val1",
+ "blacklisted_1": "val2",
+ "key3": "val3",
+ }
+ record._blacklisted_4 = "fowa"
+ record._blacklisted_5 = "fiva"
+ record._foo1 = "bar"
message = formatter.format(record=record)
- expected = ("test message 1 (foo1='bar',blacklisted_1='********',blacklisted_2='********',"
- "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'},"
- "blacklisted_4='********',blacklisted_5='********')")
+ expected = (
+ "test message 1 (foo1='bar',blacklisted_1='********',blacklisted_2='********',"
+ "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'},"
+ "blacklisted_4='********',blacklisted_5='********')"
+ )
self.assertEqual(sorted(message), sorted(expected))
- @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST',
- MOCK_MASKED_ATTRIBUTES_BLACKLIST)
+ @mock.patch(
+ "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST",
+ MOCK_MASKED_ATTRIBUTES_BLACKLIST,
+ )
def test_format_secret_action_parameters_are_masked(self):
formatter = ConsoleLogFormatter()
- mock_message = 'test message 1'
+ mock_message = "test message 1"
parameters = {
- 'parameter1': {
- 'type': 'string',
- 'required': False
- },
- 'parameter2': {
- 'type': 'string',
- 'required': False,
- 'secret': True
- }
+ "parameter1": {"type": "string", "required": False},
+ "parameter2": {"type": "string", "required": False, "secret": True},
}
- mock_action_db = ActionDB(pack='testpack', name='test.action', parameters=parameters)
+ mock_action_db = ActionDB(
+ pack="testpack", name="test.action", parameters=parameters
+ )
action = mock_action_db.to_serializable_dict()
- parameters = {
- 'parameter1': 'value1',
- 'parameter2': 'value2'
- }
- mock_action_execution_db = ActionExecutionDB(action=action, parameters=parameters)
+ parameters = {"parameter1": "value1", "parameter2": "value2"}
+ mock_action_execution_db = ActionExecutionDB(
+ action=action, parameters=parameters
+ )
record = MockRecord()
record.msg = mock_message
@@ -250,97 +264,94 @@ def test_format_secret_action_parameters_are_masked(self):
# Add "extra" attributes
record._action_execution_db = mock_action_execution_db
- expected_msg_part = (r"'parameters': {u?'parameter1': u?'value1', "
- r"u?'parameter2': u?'\*\*\*\*\*\*\*\*'}")
+ expected_msg_part = (
+ r"'parameters': {u?'parameter1': u?'value1', "
+ r"u?'parameter2': u?'\*\*\*\*\*\*\*\*'}"
+ )
message = formatter.format(record=record)
- self.assertIn('test message 1', message)
+ self.assertIn("test message 1", message)
self.assertRegexpMatches(message, expected_msg_part)
- @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST',
- MOCK_MASKED_ATTRIBUTES_BLACKLIST)
+ @mock.patch(
+ "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST",
+ MOCK_MASKED_ATTRIBUTES_BLACKLIST,
+ )
def test_format_rule(self):
expected_result = {
- 'description': 'Test description',
- 'tags': [],
- 'type': {
- 'ref': 'standard',
- 'parameters': {}},
- 'enabled': True,
- 'trigger': 'test tigger',
- 'metadata_file': None,
- 'context': {},
- 'criteria': {},
- 'action': {
- 'ref': '1234',
- 'parameters': {'b': 2}},
- 'uid': 'rule:testpack:test.action',
- 'pack': 'testpack',
- 'ref': 'testpack.test.action',
- 'id': None,
- 'name': 'test.action'
+ "description": "Test description",
+ "tags": [],
+ "type": {"ref": "standard", "parameters": {}},
+ "enabled": True,
+ "trigger": "test tigger",
+ "metadata_file": None,
+ "context": {},
+ "criteria": {},
+ "action": {"ref": "1234", "parameters": {"b": 2}},
+ "uid": "rule:testpack:test.action",
+ "pack": "testpack",
+ "ref": "testpack.test.action",
+ "id": None,
+ "name": "test.action",
}
- mock_rule_db = RuleDB(pack='testpack',
- name='test.action',
- description='Test description',
- trigger='test tigger',
- action={'ref': '1234', 'parameters': {'b': 2}})
+ mock_rule_db = RuleDB(
+ pack="testpack",
+ name="test.action",
+ description="Test description",
+ trigger="test tigger",
+ action={"ref": "1234", "parameters": {"b": 2}},
+ )
result = mock_rule_db.to_serializable_dict()
self.assertEqual(expected_result, result)
- @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST',
- MOCK_MASKED_ATTRIBUTES_BLACKLIST)
- @mock.patch('st2common.models.db.rule.RuleDB._get_referenced_action_model')
- def test_format_secret_rule_parameters_are_masked(self, mock__get_referenced_action_model):
+ @mock.patch(
+ "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST",
+ MOCK_MASKED_ATTRIBUTES_BLACKLIST,
+ )
+ @mock.patch("st2common.models.db.rule.RuleDB._get_referenced_action_model")
+ def test_format_secret_rule_parameters_are_masked(
+ self, mock__get_referenced_action_model
+ ):
expected_result = {
- 'description': 'Test description',
- 'tags': [],
- 'type': {
- 'ref': 'standard',
- 'parameters': {}},
- 'enabled': True,
- 'trigger': 'test tigger',
- 'metadata_file': None,
- 'context': {},
- 'criteria': {},
- 'action': {
- 'ref': '1234',
- 'parameters': {
- 'parameter1': 'value1',
- 'parameter2': '********'
- }},
- 'uid': 'rule:testpack:test.action',
- 'pack': 'testpack',
- 'ref': 'testpack.test.action',
- 'id': None,
- 'name': 'test.action'
+ "description": "Test description",
+ "tags": [],
+ "type": {"ref": "standard", "parameters": {}},
+ "enabled": True,
+ "trigger": "test tigger",
+ "metadata_file": None,
+ "context": {},
+ "criteria": {},
+ "action": {
+ "ref": "1234",
+ "parameters": {"parameter1": "value1", "parameter2": "********"},
+ },
+ "uid": "rule:testpack:test.action",
+ "pack": "testpack",
+ "ref": "testpack.test.action",
+ "id": None,
+ "name": "test.action",
}
parameters = {
- 'parameter1': {
- 'type': 'string',
- 'required': False
- },
- 'parameter2': {
- 'type': 'string',
- 'required': False,
- 'secret': True
- }
+ "parameter1": {"type": "string", "required": False},
+ "parameter2": {"type": "string", "required": False, "secret": True},
}
- mock_action_db = ActionDB(pack='testpack', name='test.action', parameters=parameters)
+ mock_action_db = ActionDB(
+ pack="testpack", name="test.action", parameters=parameters
+ )
mock__get_referenced_action_model.return_value = mock_action_db
- cfg.CONF.set_override(group='log', name='mask_secrets',
- override=True)
- mock_rule_db = RuleDB(pack='testpack',
- name='test.action',
- description='Test description',
- trigger='test tigger',
- action={'ref': '1234',
- 'parameters': {
- 'parameter1': 'value1',
- 'parameter2': 'value2'
- }})
+ cfg.CONF.set_override(group="log", name="mask_secrets", override=True)
+ mock_rule_db = RuleDB(
+ pack="testpack",
+ name="test.action",
+ description="Test description",
+ trigger="test tigger",
+ action={
+ "ref": "1234",
+ "parameters": {"parameter1": "value1", "parameter2": "value2"},
+ },
+ )
result = mock_rule_db.to_serializable_dict(True)
@@ -355,11 +366,18 @@ def setUpClass(cls):
def test_format(self):
formatter = GelfLogFormatter()
- expected_keys = ['version', 'host', 'short_message', 'full_message',
- 'timestamp', 'timestamp_f', 'level']
+ expected_keys = [
+ "version",
+ "host",
+ "short_message",
+ "full_message",
+ "timestamp",
+ "timestamp_f",
+ "level",
+ ]
# No extra attributes
- mock_message = 'test message 1'
+ mock_message = "test message 1"
record = MockRecord()
record.msg = mock_message
@@ -370,19 +388,19 @@ def test_format(self):
for key in expected_keys:
self.assertIn(key, parsed)
- self.assertEqual(parsed['short_message'], mock_message)
- self.assertEqual(parsed['full_message'], mock_message)
+ self.assertEqual(parsed["short_message"], mock_message)
+ self.assertEqual(parsed["full_message"], mock_message)
# Some extra attributes
- mock_message = 'test message 2'
+ mock_message = "test message 2"
record = MockRecord()
record.msg = mock_message
# Add "extra" attributes
record._user_id = 1
- record._value = 'bar'
- record.ignored = 'foo' # this one is ignored since it doesnt have a prefix
+ record._value = "bar"
+ record.ignored = "foo" # this one is ignored since it doesn't have a prefix
record.created = 1234.5678
message = formatter.format(record=record)
@@ -391,16 +409,16 @@ def test_format(self):
for key in expected_keys:
self.assertIn(key, parsed)
- self.assertEqual(parsed['short_message'], mock_message)
- self.assertEqual(parsed['full_message'], mock_message)
- self.assertEqual(parsed['_user_id'], 1)
- self.assertEqual(parsed['_value'], 'bar')
- self.assertEqual(parsed['timestamp'], 1234)
- self.assertEqual(parsed['timestamp_f'], 1234.5678)
- self.assertNotIn('ignored', parsed)
+ self.assertEqual(parsed["short_message"], mock_message)
+ self.assertEqual(parsed["full_message"], mock_message)
+ self.assertEqual(parsed["_user_id"], 1)
+ self.assertEqual(parsed["_value"], "bar")
+ self.assertEqual(parsed["timestamp"], 1234)
+ self.assertEqual(parsed["timestamp_f"], 1234.5678)
+ self.assertNotIn("ignored", parsed)
# Record with an exception
- mock_exception = Exception('mock exception bar')
+ mock_exception = Exception("mock exception bar")
try:
raise mock_exception
@@ -408,7 +426,7 @@ def test_format(self):
mock_exc_info = sys.exc_info()
# Some extra attributes
- mock_message = 'test message 3'
+ mock_message = "test message 3"
record = MockRecord()
record.msg = mock_message
@@ -420,69 +438,77 @@ def test_format(self):
for key in expected_keys:
self.assertIn(key, parsed)
- self.assertEqual(parsed['short_message'], mock_message)
- self.assertIn(mock_message, parsed['full_message'])
- self.assertIn('Traceback', parsed['full_message'])
- self.assertIn('_exception', parsed)
- self.assertIn('_traceback', parsed)
+ self.assertEqual(parsed["short_message"], mock_message)
+ self.assertIn(mock_message, parsed["full_message"])
+ self.assertIn("Traceback", parsed["full_message"])
+ self.assertIn("_exception", parsed)
+ self.assertIn("_traceback", parsed)
def test_extra_object_serialization(self):
class MyClass1(object):
def __repr__(self):
- return 'repr'
+ return "repr"
class MyClass2(object):
def to_dict(self):
- return 'to_dict'
+ return "to_dict"
class MyClass3(object):
def to_serializable_dict(self, mask_secrets=False):
- return 'to_serializable_dict'
+ return "to_serializable_dict"
formatter = GelfLogFormatter()
record = MockRecord()
- record.msg = 'message'
+ record.msg = "message"
record._obj1 = MyClass1()
record._obj2 = MyClass2()
record._obj3 = MyClass3()
message = formatter.format(record=record)
parsed = json.loads(message)
- self.assertEqual(parsed['_obj1'], 'repr')
- self.assertEqual(parsed['_obj2'], 'to_dict')
- self.assertEqual(parsed['_obj3'], 'to_serializable_dict')
-
- @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST',
- MOCK_MASKED_ATTRIBUTES_BLACKLIST)
+ self.assertEqual(parsed["_obj1"], "repr")
+ self.assertEqual(parsed["_obj2"], "to_dict")
+ self.assertEqual(parsed["_obj3"], "to_serializable_dict")
+
+ @mock.patch(
+ "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST",
+ MOCK_MASKED_ATTRIBUTES_BLACKLIST,
+ )
def test_format_blacklisted_attributes_are_masked(self):
formatter = GelfLogFormatter()
# Some extra attributes
- mock_message = 'test message 1'
+ mock_message = "test message 1"
record = MockRecord()
record.msg = mock_message
# Add "extra" attributes
- record._blacklisted_1 = 'test value 1'
- record._blacklisted_2 = 'test value 2'
- record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'}
- record._foo1 = 'bar'
+ record._blacklisted_1 = "test value 1"
+ record._blacklisted_2 = "test value 2"
+ record._blacklisted_3 = {
+ "key1": "val1",
+ "blacklisted_1": "val2",
+ "key3": "val3",
+ }
+ record._foo1 = "bar"
message = formatter.format(record=record)
parsed = json.loads(message)
- self.assertEqual(parsed['_blacklisted_1'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(parsed['_blacklisted_2'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(parsed['_blacklisted_3']['key1'], 'val1')
- self.assertEqual(parsed['_blacklisted_3']['blacklisted_1'], MASKED_ATTRIBUTE_VALUE)
- self.assertEqual(parsed['_blacklisted_3']['key3'], 'val3')
- self.assertEqual(parsed['_foo1'], 'bar')
+ self.assertEqual(parsed["_blacklisted_1"], MASKED_ATTRIBUTE_VALUE)
+ self.assertEqual(parsed["_blacklisted_2"], MASKED_ATTRIBUTE_VALUE)
+ self.assertEqual(parsed["_blacklisted_3"]["key1"], "val1")
+ self.assertEqual(
+ parsed["_blacklisted_3"]["blacklisted_1"], MASKED_ATTRIBUTE_VALUE
+ )
+ self.assertEqual(parsed["_blacklisted_3"]["key3"], "val3")
+ self.assertEqual(parsed["_foo1"], "bar")
# Assert that the original dict is left unmodified
- self.assertEqual(record._blacklisted_1, 'test value 1')
- self.assertEqual(record._blacklisted_2, 'test value 2')
- self.assertEqual(record._blacklisted_3['key1'], 'val1')
- self.assertEqual(record._blacklisted_3['blacklisted_1'], 'val2')
- self.assertEqual(record._blacklisted_3['key3'], 'val3')
+ self.assertEqual(record._blacklisted_1, "test value 1")
+ self.assertEqual(record._blacklisted_2, "test value 2")
+ self.assertEqual(record._blacklisted_3["key1"], "val1")
+ self.assertEqual(record._blacklisted_3["blacklisted_1"], "val2")
+ self.assertEqual(record._blacklisted_3["key3"], "val3")
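The masking behaviour pinned down by the assertions above reduces to a small rule: blacklisted keys are replaced with a fixed placeholder, nested dicts are scanned one level deep, and the input record is never mutated. A self-contained sketch under those assumptions (mask_blacklisted is a hypothetical helper, not the st2common formatter code):

import copy

MASKED_ATTRIBUTE_VALUE = "********"

def mask_blacklisted(attributes, blacklist):
    # Work on a deep copy so the caller's dict stays unmodified, matching
    # the "original dict is left unmodified" assertions above.
    masked = copy.deepcopy(attributes)
    for key, value in masked.items():
        if key in blacklist:
            masked[key] = MASKED_ATTRIBUTE_VALUE
        elif isinstance(value, dict):
            for nested_key in value:
                if nested_key in blacklist:
                    value[nested_key] = MASKED_ATTRIBUTE_VALUE
    return masked

attrs = {"blacklisted_1": "test value 1", "foo1": "bar"}
assert mask_blacklisted(attrs, ["blacklisted_1"])["blacklisted_1"] == "********"
assert attrs["blacklisted_1"] == "test value 1"  # input untouched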
diff --git a/st2common/tests/unit/test_logging.py b/st2common/tests/unit/test_logging.py
index 7dc4fc1b6d..ebb75b6f1d 100644
--- a/st2common/tests/unit/test_logging.py
+++ b/st2common/tests/unit/test_logging.py
@@ -21,25 +21,29 @@
from python_runner import python_runner
from st2common import runners
-__all__ = [
- 'LoggingMiscUtilsTestCase'
-]
+__all__ = ["LoggingMiscUtilsTestCase"]
class LoggingMiscUtilsTestCase(unittest2.TestCase):
def test_get_logger_name_for_module(self):
logger_name = get_logger_name_for_module(sensormanager)
- self.assertEqual(logger_name, 'st2reactor.cmd.sensormanager')
+ self.assertEqual(logger_name, "st2reactor.cmd.sensormanager")
logger_name = get_logger_name_for_module(python_runner)
- result = logger_name.endswith('contrib.runners.python_runner.python_runner.python_runner')
+ result = logger_name.endswith(
+ "contrib.runners.python_runner.python_runner.python_runner"
+ )
self.assertTrue(result)
- logger_name = get_logger_name_for_module(python_runner, exclude_module_name=True)
- self.assertTrue(logger_name.endswith('contrib.runners.python_runner.python_runner'))
+ logger_name = get_logger_name_for_module(
+ python_runner, exclude_module_name=True
+ )
+ self.assertTrue(
+ logger_name.endswith("contrib.runners.python_runner.python_runner")
+ )
logger_name = get_logger_name_for_module(runners)
- self.assertEqual(logger_name, 'st2common.runners.__init__')
+ self.assertEqual(logger_name, "st2common.runners.__init__")
logger_name = get_logger_name_for_module(runners, exclude_module_name=True)
- self.assertEqual(logger_name, 'st2common.runners')
+ self.assertEqual(logger_name, "st2common.runners")
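Those assertions specify the logger-naming rule: the module's dotted path, optionally without the trailing module component (the "st2common.runners.__init__" case shows the real helper derives names from file paths rather than __name__). A simplified stand-in based on __name__ alone, for illustration only:

import logging.handlers

def logger_name_for_module(module, exclude_module_name=False):
    # Hypothetical reimplementation: derive the logger name from the
    # module's dotted path; optionally drop the final component.
    name = module.__name__
    if exclude_module_name:
        name = name.rsplit(".", 1)[0]
    return name

assert logger_name_for_module(logging.handlers) == "logging.handlers"
assert logger_name_for_module(logging.handlers, exclude_module_name=True) == "logging"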
diff --git a/st2common/tests/unit/test_logging_middleware.py b/st2common/tests/unit/test_logging_middleware.py
index 8a59177beb..b7d34de0bc 100644
--- a/st2common/tests/unit/test_logging_middleware.py
+++ b/st2common/tests/unit/test_logging_middleware.py
@@ -21,18 +21,15 @@
from st2common.middleware.logging import LoggingMiddleware
from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE
-__all__ = [
- 'LoggingMiddlewareTestCase'
-]
+__all__ = ["LoggingMiddlewareTestCase"]
class LoggingMiddlewareTestCase(unittest2.TestCase):
- @mock.patch('st2common.middleware.logging.LOG')
- @mock.patch('st2common.middleware.logging.Request')
+ @mock.patch("st2common.middleware.logging.LOG")
+ @mock.patch("st2common.middleware.logging.Request")
def test_secret_parameters_are_masked_in_log_message(self, mock_request, mock_log):
-
def app(environ, custom_start_response):
- custom_start_response(status='200 OK', headers=[('Content-Length', 100)])
+ custom_start_response(status="200 OK", headers=[("Content-Length", 100)])
return [None]
router = mock.Mock()
@@ -40,35 +37,38 @@ def app(environ, custom_start_response):
router.match.return_value = (endpoint, None)
middleware = LoggingMiddleware(app=app, router=router)
- cfg.CONF.set_override(group='log', name='mask_secrets_blacklist',
- override=['blacklisted_4', 'blacklisted_5'])
+ cfg.CONF.set_override(
+ group="log",
+ name="mask_secrets_blacklist",
+ override=["blacklisted_4", "blacklisted_5"],
+ )
environ = {}
mock_request.return_value.GET.dict_of_lists.return_value = {
- 'foo': 'bar',
- 'bar': 'baz',
- 'x-auth-token': 'secret',
- 'st2-api-key': 'secret',
- 'password': 'secret',
- 'st2_auth_token': 'secret',
- 'token': 'secret',
- 'blacklisted_4': 'super secret',
- 'blacklisted_5': 'super secret',
+ "foo": "bar",
+ "bar": "baz",
+ "x-auth-token": "secret",
+ "st2-api-key": "secret",
+ "password": "secret",
+ "st2_auth_token": "secret",
+ "token": "secret",
+ "blacklisted_4": "super secret",
+ "blacklisted_5": "super secret",
}
middleware(environ=environ, start_response=mock.Mock())
expected_query = {
- 'foo': 'bar',
- 'bar': 'baz',
- 'x-auth-token': MASKED_ATTRIBUTE_VALUE,
- 'st2-api-key': MASKED_ATTRIBUTE_VALUE,
- 'password': MASKED_ATTRIBUTE_VALUE,
- 'token': MASKED_ATTRIBUTE_VALUE,
- 'st2_auth_token': MASKED_ATTRIBUTE_VALUE,
- 'blacklisted_4': MASKED_ATTRIBUTE_VALUE,
- 'blacklisted_5': MASKED_ATTRIBUTE_VALUE,
+ "foo": "bar",
+ "bar": "baz",
+ "x-auth-token": MASKED_ATTRIBUTE_VALUE,
+ "st2-api-key": MASKED_ATTRIBUTE_VALUE,
+ "password": MASKED_ATTRIBUTE_VALUE,
+ "token": MASKED_ATTRIBUTE_VALUE,
+ "st2_auth_token": MASKED_ATTRIBUTE_VALUE,
+ "blacklisted_4": MASKED_ATTRIBUTE_VALUE,
+ "blacklisted_5": MASKED_ATTRIBUTE_VALUE,
}
call_kwargs = mock_log.info.call_args_list[0][1]
- query = call_kwargs['extra']['query']
+ query = call_kwargs["extra"]["query"]
self.assertEqual(query, expected_query)
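The expected_query above encodes the middleware's rule: well-known credential parameter names, plus any names added via the log.mask_secrets_blacklist option, come back masked while everything else passes through. A compact sketch of that rule (mask_secret_params is a hypothetical helper, not the middleware code):

MASKED_ATTRIBUTE_VALUE = "********"
DEFAULT_SECRET_PARAMS = {"x-auth-token", "st2-api-key", "password", "token", "st2_auth_token"}

def mask_secret_params(query, extra_blacklist=()):
    blacklist = DEFAULT_SECRET_PARAMS | set(extra_blacklist)
    return {
        key: MASKED_ATTRIBUTE_VALUE if key in blacklist else value
        for key, value in query.items()
    }

masked = mask_secret_params({"foo": "bar", "password": "secret"}, ["blacklisted_4"])
assert masked == {"foo": "bar", "password": "********"}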
diff --git a/st2common/tests/unit/test_metrics.py b/st2common/tests/unit/test_metrics.py
index 084db97c62..4b0df66aa1 100644
--- a/st2common/tests/unit/test_metrics.py
+++ b/st2common/tests/unit/test_metrics.py
@@ -29,16 +29,16 @@
from st2common.util.date import get_datetime_utc_now
__all__ = [
- 'TestBaseMetricsDriver',
- 'TestStatsDMetricsDriver',
- 'TestCounterContextManager',
- 'TestTimerContextManager',
- 'TestCounterWithTimerContextManager'
+ "TestBaseMetricsDriver",
+ "TestStatsDMetricsDriver",
+ "TestCounterContextManager",
+ "TestTimerContextManager",
+ "TestCounterWithTimerContextManager",
]
-cfg.CONF.set_override('driver', 'noop', group='metrics')
-cfg.CONF.set_override('host', '127.0.0.1', group='metrics')
-cfg.CONF.set_override('port', 8080, group='metrics')
+cfg.CONF.set_override("driver", "noop", group="metrics")
+cfg.CONF.set_override("host", "127.0.0.1", group="metrics")
+cfg.CONF.set_override("port", 8080, group="metrics")
class TestBaseMetricsDriver(unittest2.TestCase):
@@ -48,45 +48,43 @@ def setUp(self):
self._driver = base.BaseMetricsDriver()
def test_time(self):
- self._driver.time('test', 10)
+ self._driver.time("test", 10)
def test_inc_counter(self):
- self._driver.inc_counter('test')
+ self._driver.inc_counter("test")
def test_dec_timer(self):
- self._driver.dec_counter('test')
+ self._driver.dec_counter("test")
class TestStatsDMetricsDriver(unittest2.TestCase):
_driver = None
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def setUp(self, statsd):
- cfg.CONF.set_override(name='prefix', group='metrics', override=None)
+ cfg.CONF.set_override(name="prefix", group="metrics", override=None)
self._driver = StatsdDriver()
statsd.Connection.set_defaults.assert_called_once_with(
- host=cfg.CONF.metrics.host,
- port=cfg.CONF.metrics.port,
- sample_rate=1.0
+ host=cfg.CONF.metrics.host, port=cfg.CONF.metrics.port, sample_rate=1.0
)
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_time(self, statsd):
mock_timer = MagicMock()
- statsd.Timer('').send.side_effect = mock_timer
- params = ('test', 10)
+ statsd.Timer("").send.side_effect = mock_timer
+ params = ("test", 10)
self._driver.time(*params)
- statsd.Timer('').send.assert_called_with('st2.test', 10)
+ statsd.Timer("").send.assert_called_with("st2.test", 10)
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_time_with_float(self, statsd):
mock_timer = MagicMock()
- statsd.Timer('').send.side_effect = mock_timer
- params = ('test', 10.5)
+ statsd.Timer("").send.side_effect = mock_timer
+ params = ("test", 10.5)
self._driver.time(*params)
- statsd.Timer().send.assert_called_with('st2.test', 10.5)
+ statsd.Timer().send.assert_called_with("st2.test", 10.5)
def test_time_with_invalid_key(self):
params = (2, 2)
@@ -94,21 +92,21 @@ def test_time_with_invalid_key(self):
self._driver.time(*params)
def test_time_with_invalid_time(self):
- params = ('test', '1')
+ params = ("test", "1")
with self.assertRaises(AssertionError):
self._driver.time(*params)
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_inc_counter_with_default_amount(self, statsd):
- key = 'test'
+ key = "test"
mock_counter = MagicMock()
statsd.Counter(key).increment.side_effect = mock_counter
self._driver.inc_counter(key)
mock_counter.assert_called_once_with(delta=1)
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_inc_counter_with_amount(self, statsd):
- params = ('test', 2)
+ params = ("test", 2)
mock_counter = MagicMock()
statsd.Counter(params[0]).increment.side_effect = mock_counter
self._driver.inc_counter(*params)
@@ -120,21 +118,21 @@ def test_inc_timer_with_invalid_key(self):
self._driver.inc_counter(*params)
def test_inc_timer_with_invalid_amount(self):
- params = ('test', '1')
+ params = ("test", "1")
with self.assertRaises(AssertionError):
self._driver.inc_counter(*params)
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_dec_timer_with_default_amount(self, statsd):
- key = 'test'
+ key = "test"
mock_counter = MagicMock()
statsd.Counter().decrement.side_effect = mock_counter
self._driver.dec_counter(key)
mock_counter.assert_called_once_with(delta=1)
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_dec_timer_with_amount(self, statsd):
- params = ('test', 2)
+ params = ("test", 2)
mock_counter = MagicMock()
statsd.Counter().decrement.side_effect = mock_counter
self._driver.dec_counter(*params)
@@ -146,41 +144,41 @@ def test_dec_timer_with_invalid_key(self):
self._driver.dec_counter(*params)
def test_dec_timer_with_invalid_amount(self):
- params = ('test', '1')
+ params = ("test", "1")
with self.assertRaises(AssertionError):
self._driver.dec_counter(*params)
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_set_gauge_success(self, statsd):
- params = ('key', 100)
+ params = ("key", 100)
mock_gauge = MagicMock()
statsd.Gauge().send.side_effect = mock_gauge
self._driver.set_gauge(*params)
mock_gauge.assert_called_once_with(None, params[1])
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_inc_gauge_success(self, statsd):
- params = ('key1',)
+ params = ("key1",)
mock_gauge = MagicMock()
statsd.Gauge().increment.side_effect = mock_gauge
self._driver.inc_gauge(*params)
mock_gauge.assert_called_once_with(None, 1)
- params = ('key2', 100)
+ params = ("key2", 100)
mock_gauge = MagicMock()
statsd.Gauge().increment.side_effect = mock_gauge
self._driver.inc_gauge(*params)
mock_gauge.assert_called_once_with(None, params[1])
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_dec_gauge_success(self, statsd):
- params = ('key1',)
+ params = ("key1",)
mock_gauge = MagicMock()
statsd.Gauge().decrement.side_effect = mock_gauge
self._driver.dec_gauge(*params)
mock_gauge.assert_called_once_with(None, 1)
- params = ('key2', 100)
+ params = ("key2", 100)
mock_gauge = MagicMock()
statsd.Gauge().decrement.side_effect = mock_gauge
self._driver.dec_gauge(*params)
@@ -188,71 +186,71 @@ def test_dec_gauge_success(self, statsd):
def test_get_full_key_name(self):
# No prefix specified in the config
- cfg.CONF.set_override(name='prefix', group='metrics', override=None)
+ cfg.CONF.set_override(name="prefix", group="metrics", override=None)
- result = get_full_key_name('api.requests')
- self.assertEqual(result, 'st2.api.requests')
+ result = get_full_key_name("api.requests")
+ self.assertEqual(result, "st2.api.requests")
# Prefix is defined in the config
- cfg.CONF.set_override(name='prefix', group='metrics', override='staging')
+ cfg.CONF.set_override(name="prefix", group="metrics", override="staging")
- result = get_full_key_name('api.requests')
- self.assertEqual(result, 'st2.staging.api.requests')
+ result = get_full_key_name("api.requests")
+ self.assertEqual(result, "st2.staging.api.requests")
- cfg.CONF.set_override(name='prefix', group='metrics', override='prod')
+ cfg.CONF.set_override(name="prefix", group="metrics", override="prod")
- result = get_full_key_name('api.requests')
- self.assertEqual(result, 'st2.prod.api.requests')
+ result = get_full_key_name("api.requests")
+ self.assertEqual(result, "st2.prod.api.requests")
- @patch('st2common.metrics.drivers.statsd_driver.LOG')
- @patch('st2common.metrics.drivers.statsd_driver.statsd')
+ @patch("st2common.metrics.drivers.statsd_driver.LOG")
+ @patch("st2common.metrics.drivers.statsd_driver.statsd")
def test_driver_socket_exceptions_are_not_fatal(self, statsd, mock_log):
# Socket errors such as DNS resolution errors should be considered non fatal and ignored
mock_logger = mock.Mock()
StatsdDriver.logger = mock_logger
# 1. timer
- mock_timer = MagicMock(side_effect=socket.error('error 1'))
- statsd.Timer('').send.side_effect = mock_timer
- params = ('test', 10)
+ mock_timer = MagicMock(side_effect=socket.error("error 1"))
+ statsd.Timer("").send.side_effect = mock_timer
+ params = ("test", 10)
self._driver.time(*params)
- statsd.Timer('').send.assert_called_with('st2.test', 10)
+ statsd.Timer("").send.assert_called_with("st2.test", 10)
# 2. counter
- key = 'test'
- mock_counter = MagicMock(side_effect=socket.error('error 2'))
+ key = "test"
+ mock_counter = MagicMock(side_effect=socket.error("error 2"))
statsd.Counter(key).increment.side_effect = mock_counter
self._driver.inc_counter(key)
mock_counter.assert_called_once_with(delta=1)
- key = 'test'
- mock_counter = MagicMock(side_effect=socket.error('error 3'))
+ key = "test"
+ mock_counter = MagicMock(side_effect=socket.error("error 3"))
statsd.Counter(key).decrement.side_effect = mock_counter
self._driver.dec_counter(key)
mock_counter.assert_called_once_with(delta=1)
# 3. gauge
- params = ('key', 100)
- mock_gauge = MagicMock(side_effect=socket.error('error 4'))
+ params = ("key", 100)
+ mock_gauge = MagicMock(side_effect=socket.error("error 4"))
statsd.Gauge().send.side_effect = mock_gauge
self._driver.set_gauge(*params)
mock_gauge.assert_called_once_with(None, params[1])
- params = ('key1',)
- mock_gauge = MagicMock(side_effect=socket.error('error 5'))
+ params = ("key1",)
+ mock_gauge = MagicMock(side_effect=socket.error("error 5"))
statsd.Gauge().increment.side_effect = mock_gauge
self._driver.inc_gauge(*params)
mock_gauge.assert_called_once_with(None, 1)
- params = ('key1',)
- mock_gauge = MagicMock(side_effect=socket.error('error 6'))
+ params = ("key1",)
+ mock_gauge = MagicMock(side_effect=socket.error("error 6"))
statsd.Gauge().decrement.side_effect = mock_gauge
self._driver.dec_gauge(*params)
mock_gauge.assert_called_once_with(None, 1)
class TestCounterContextManager(unittest2.TestCase):
- @patch('st2common.metrics.base.METRICS')
+ @patch("st2common.metrics.base.METRICS")
def test_counter(self, metrics_patch):
test_key = "test_key"
with base.Counter(test_key):
@@ -261,8 +259,8 @@ def test_counter(self, metrics_patch):
class TestTimerContextManager(unittest2.TestCase):
- @patch('st2common.metrics.base.get_datetime_utc_now')
- @patch('st2common.metrics.base.METRICS')
+ @patch("st2common.metrics.base.get_datetime_utc_now")
+ @patch("st2common.metrics.base.METRICS")
def test_time(self, metrics_patch, datetime_patch):
start_time = get_datetime_utc_now()
middle_time = start_time + timedelta(seconds=1)
@@ -272,7 +270,7 @@ def test_time(self, metrics_patch, datetime_patch):
middle_time,
middle_time,
middle_time,
- end_time
+ end_time,
]
test_key = "test_key"
with base.Timer(test_key) as timer:
@@ -280,23 +278,19 @@ def test_time(self, metrics_patch, datetime_patch):
metrics_patch.time.assert_not_called()
timer.send_time()
metrics_patch.time.assert_called_with(
- test_key,
- (end_time - middle_time).total_seconds()
+ test_key, (end_time - middle_time).total_seconds()
)
second_test_key = "lakshmi_has_toes"
timer.send_time(second_test_key)
metrics_patch.time.assert_called_with(
- second_test_key,
- (end_time - middle_time).total_seconds()
+ second_test_key, (end_time - middle_time).total_seconds()
)
time_delta = timer.get_time_delta()
self.assertEqual(
- time_delta.total_seconds(),
- (end_time - middle_time).total_seconds()
+ time_delta.total_seconds(), (end_time - middle_time).total_seconds()
)
metrics_patch.time.assert_called_with(
- test_key,
- (end_time - start_time).total_seconds()
+ test_key, (end_time - start_time).total_seconds()
)
@@ -306,46 +300,44 @@ def setUp(self):
self.middle_time = self.start_time + timedelta(seconds=1)
self.end_time = self.middle_time + timedelta(seconds=1)
- @patch('st2common.metrics.base.get_datetime_utc_now')
- @patch('st2common.metrics.base.METRICS')
+ @patch("st2common.metrics.base.get_datetime_utc_now")
+ @patch("st2common.metrics.base.METRICS")
def test_time(self, metrics_patch, datetime_patch):
datetime_patch.side_effect = [
self.start_time,
self.middle_time,
self.middle_time,
self.middle_time,
- self.end_time
+ self.end_time,
]
test_key = "test_key"
with base.CounterWithTimer(test_key) as timer:
self.assertIsInstance(timer._start_time, datetime)
metrics_patch.time.assert_not_called()
timer.send_time()
- metrics_patch.time.assert_called_with(test_key,
- (self.end_time - self.middle_time).total_seconds()
+ metrics_patch.time.assert_called_with(
+ test_key, (self.end_time - self.middle_time).total_seconds()
)
second_test_key = "lakshmi_has_a_nose"
timer.send_time(second_test_key)
metrics_patch.time.assert_called_with(
- second_test_key,
- (self.end_time - self.middle_time).total_seconds()
+ second_test_key, (self.end_time - self.middle_time).total_seconds()
)
time_delta = timer.get_time_delta()
self.assertEqual(
time_delta.total_seconds(),
- (self.end_time - self.middle_time).total_seconds()
+ (self.end_time - self.middle_time).total_seconds(),
)
metrics_patch.inc_counter.assert_called_with(test_key)
metrics_patch.dec_counter.assert_not_called()
metrics_patch.time.assert_called_with(
- test_key,
- (self.end_time - self.start_time).total_seconds()
+ test_key, (self.end_time - self.start_time).total_seconds()
)
class TestCounterWithTimerDecorator(unittest2.TestCase):
- @patch('st2common.metrics.base.get_datetime_utc_now')
- @patch('st2common.metrics.base.METRICS')
+ @patch("st2common.metrics.base.get_datetime_utc_now")
+ @patch("st2common.metrics.base.METRICS")
def test_time(self, metrics_patch, datetime_patch):
start_time = get_datetime_utc_now()
middle_time = start_time + timedelta(seconds=1)
@@ -355,7 +347,7 @@ def test_time(self, metrics_patch, datetime_patch):
middle_time,
middle_time,
middle_time,
- end_time
+ end_time,
]
test_key = "test_key"
@@ -364,32 +356,30 @@ def _get_tested(metrics_counter_with_timer=None):
self.assertIsInstance(metrics_counter_with_timer._start_time, datetime)
metrics_patch.time.assert_not_called()
metrics_counter_with_timer.send_time()
- metrics_patch.time.assert_called_with(test_key,
- (end_time - middle_time).total_seconds()
+ metrics_patch.time.assert_called_with(
+ test_key, (end_time - middle_time).total_seconds()
)
second_test_key = "lakshmi_has_a_nose"
metrics_counter_with_timer.send_time(second_test_key)
metrics_patch.time.assert_called_with(
- second_test_key,
- (end_time - middle_time).total_seconds()
+ second_test_key, (end_time - middle_time).total_seconds()
)
time_delta = metrics_counter_with_timer.get_time_delta()
self.assertEqual(
- time_delta.total_seconds(),
- (end_time - middle_time).total_seconds()
+ time_delta.total_seconds(), (end_time - middle_time).total_seconds()
)
metrics_patch.inc_counter.assert_called_with(test_key)
metrics_patch.dec_counter.assert_not_called()
_get_tested()
- metrics_patch.time.assert_called_with(test_key,
- (end_time - start_time).total_seconds()
+ metrics_patch.time.assert_called_with(
+ test_key, (end_time - start_time).total_seconds()
)
class TestCounterDecorator(unittest2.TestCase):
- @patch('st2common.metrics.base.METRICS')
+ @patch("st2common.metrics.base.METRICS")
def test_counter(self, metrics_patch):
test_key = "test_key"
@@ -397,12 +387,13 @@ def test_counter(self, metrics_patch):
def _get_tested():
metrics_patch.inc_counter.assert_called_with(test_key)
metrics_patch.dec_counter.assert_not_called()
+
_get_tested()
class TestTimerDecorator(unittest2.TestCase):
- @patch('st2common.metrics.base.get_datetime_utc_now')
- @patch('st2common.metrics.base.METRICS')
+ @patch("st2common.metrics.base.get_datetime_utc_now")
+ @patch("st2common.metrics.base.METRICS")
def test_time(self, metrics_patch, datetime_patch):
start_time = get_datetime_utc_now()
middle_time = start_time + timedelta(seconds=1)
@@ -412,7 +403,7 @@ def test_time(self, metrics_patch, datetime_patch):
middle_time,
middle_time,
middle_time,
- end_time
+ end_time,
]
test_key = "test_key"
@@ -422,22 +413,19 @@ def _get_tested(metrics_timer=None):
metrics_patch.time.assert_not_called()
metrics_timer.send_time()
metrics_patch.time.assert_called_with(
- test_key,
- (end_time - middle_time).total_seconds()
+ test_key, (end_time - middle_time).total_seconds()
)
second_test_key = "lakshmi_has_toes"
metrics_timer.send_time(second_test_key)
metrics_patch.time.assert_called_with(
- second_test_key,
- (end_time - middle_time).total_seconds()
+ second_test_key, (end_time - middle_time).total_seconds()
)
time_delta = metrics_timer.get_time_delta()
self.assertEqual(
- time_delta.total_seconds(),
- (end_time - middle_time).total_seconds()
+ time_delta.total_seconds(), (end_time - middle_time).total_seconds()
)
+
_get_tested()
metrics_patch.time.assert_called_with(
- test_key,
- (end_time - start_time).total_seconds()
+ test_key, (end_time - start_time).total_seconds()
)
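All the Timer variants exercised above share one contract: record the entry time, report elapsed seconds on demand via send_time (against the original key or an override), and report the total on exit. A self-contained sketch of that contract (not the st2common.metrics.base implementation; emit stands in for METRICS.time):

from datetime import datetime

class Timer:
    def __init__(self, key, emit=print):
        self._key = key
        self._emit = emit  # stand-in for METRICS.time(key, seconds)
        self._start_time = None

    def __enter__(self):
        self._start_time = datetime.utcnow()
        return self

    def get_time_delta(self):
        return datetime.utcnow() - self._start_time

    def send_time(self, key=None):
        self._emit(key or self._key, self.get_time_delta().total_seconds())

    def __exit__(self, *exc_info):
        self.send_time()  # total elapsed time goes to the original key

with Timer("test_key") as timer:
    timer.send_time("intermediate_key")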
diff --git a/st2common/tests/unit/test_misc_utils.py b/st2common/tests/unit/test_misc_utils.py
index d7008e921c..f05615573b 100644
--- a/st2common/tests/unit/test_misc_utils.py
+++ b/st2common/tests/unit/test_misc_utils.py
@@ -24,71 +24,61 @@
from st2common.util.misc import sanitize_output
from st2common.util.ujson import fast_deepcopy
-__all__ = [
- 'MiscUtilTestCase'
-]
+__all__ = ["MiscUtilTestCase"]
class MiscUtilTestCase(unittest2.TestCase):
def test_rstrip_last_char(self):
- self.assertEqual(rstrip_last_char(None, '\n'), None)
- self.assertEqual(rstrip_last_char('stuff', None), 'stuff')
- self.assertEqual(rstrip_last_char('', '\n'), '')
- self.assertEqual(rstrip_last_char('foo', '\n'), 'foo')
- self.assertEqual(rstrip_last_char('foo\n', '\n'), 'foo')
- self.assertEqual(rstrip_last_char('foo\n\n', '\n'), 'foo\n')
- self.assertEqual(rstrip_last_char('foo\r', '\r'), 'foo')
- self.assertEqual(rstrip_last_char('foo\r\r', '\r'), 'foo\r')
- self.assertEqual(rstrip_last_char('foo\r\n', '\r\n'), 'foo')
- self.assertEqual(rstrip_last_char('foo\r\r\n', '\r\n'), 'foo\r')
- self.assertEqual(rstrip_last_char('foo\n\r', '\r\n'), 'foo\n\r')
+ self.assertEqual(rstrip_last_char(None, "\n"), None)
+ self.assertEqual(rstrip_last_char("stuff", None), "stuff")
+ self.assertEqual(rstrip_last_char("", "\n"), "")
+ self.assertEqual(rstrip_last_char("foo", "\n"), "foo")
+ self.assertEqual(rstrip_last_char("foo\n", "\n"), "foo")
+ self.assertEqual(rstrip_last_char("foo\n\n", "\n"), "foo\n")
+ self.assertEqual(rstrip_last_char("foo\r", "\r"), "foo")
+ self.assertEqual(rstrip_last_char("foo\r\r", "\r"), "foo\r")
+ self.assertEqual(rstrip_last_char("foo\r\n", "\r\n"), "foo")
+ self.assertEqual(rstrip_last_char("foo\r\r\n", "\r\n"), "foo\r")
+ self.assertEqual(rstrip_last_char("foo\n\r", "\r\n"), "foo\n\r")
def test_strip_shell_chars(self):
self.assertEqual(strip_shell_chars(None), None)
- self.assertEqual(strip_shell_chars('foo'), 'foo')
- self.assertEqual(strip_shell_chars('foo\r'), 'foo')
- self.assertEqual(strip_shell_chars('fo\ro\r'), 'fo\ro')
- self.assertEqual(strip_shell_chars('foo\n'), 'foo')
- self.assertEqual(strip_shell_chars('fo\no\n'), 'fo\no')
- self.assertEqual(strip_shell_chars('foo\r\n'), 'foo')
- self.assertEqual(strip_shell_chars('fo\no\r\n'), 'fo\no')
- self.assertEqual(strip_shell_chars('foo\r\n\r\n'), 'foo\r\n')
+ self.assertEqual(strip_shell_chars("foo"), "foo")
+ self.assertEqual(strip_shell_chars("foo\r"), "foo")
+ self.assertEqual(strip_shell_chars("fo\ro\r"), "fo\ro")
+ self.assertEqual(strip_shell_chars("foo\n"), "foo")
+ self.assertEqual(strip_shell_chars("fo\no\n"), "fo\no")
+ self.assertEqual(strip_shell_chars("foo\r\n"), "foo")
+ self.assertEqual(strip_shell_chars("fo\no\r\n"), "fo\no")
+ self.assertEqual(strip_shell_chars("foo\r\n\r\n"), "foo\r\n")
def test_lowercase_value(self):
- value = 'TEST'
- expected_value = 'test'
+ value = "TEST"
+ expected_value = "test"
self.assertEqual(expected_value, lowercase_value(value=value))
- value = ['testA', 'TESTb', 'TESTC']
- expected_value = ['testa', 'testb', 'testc']
+ value = ["testA", "TESTb", "TESTC"]
+ expected_value = ["testa", "testb", "testc"]
self.assertEqual(expected_value, lowercase_value(value=value))
- value = {
- 'testA': 'testB',
- 'testC': 'TESTD',
- 'TESTE': 'TESTE'
- }
- expected_value = {
- 'testa': 'testb',
- 'testc': 'testd',
- 'teste': 'teste'
- }
+ value = {"testA": "testB", "testC": "TESTD", "TESTE": "TESTE"}
+ expected_value = {"testa": "testb", "testc": "testd", "teste": "teste"}
self.assertEqual(expected_value, lowercase_value(value=value))
def test_fast_deepcopy_success(self):
values = [
- 'a',
- u'٩(̾●̮̮̃̾•̃̾)۶',
+ "a",
+ "٩(̾●̮̮̃̾•̃̾)۶",
1,
- [1, 2, '3', 'b'],
- {'a': 1, 'b': '3333', 'c': 'd'},
+ [1, 2, "3", "b"],
+ {"a": 1, "b": "3333", "c": "d"},
]
expected_values = [
- 'a',
- u'٩(̾●̮̮̃̾•̃̾)۶',
+ "a",
+ "٩(̾●̮̮̃̾•̃̾)۶",
1,
- [1, 2, '3', 'b'],
- {'a': 1, 'b': '3333', 'c': 'd'},
+ [1, 2, "3", "b"],
+ {"a": 1, "b": "3333", "c": "d"},
]
for value, expected_value in zip(values, expected_values):
@@ -99,18 +89,18 @@ def test_fast_deepcopy_success(self):
def test_sanitize_output_use_pyt_false(self):
# pty is not used, \r\n shouldn't be replaced with \n
input_strs = [
- 'foo',
- 'foo\n',
- 'foo\r\n',
- 'foo\nbar\nbaz\n',
- 'foo\r\nbar\r\nbaz\r\n',
+ "foo",
+ "foo\n",
+ "foo\r\n",
+ "foo\nbar\nbaz\n",
+ "foo\r\nbar\r\nbaz\r\n",
]
expected = [
- 'foo',
- 'foo',
- 'foo',
- 'foo\nbar\nbaz',
- 'foo\r\nbar\r\nbaz',
+ "foo",
+ "foo",
+ "foo",
+ "foo\nbar\nbaz",
+ "foo\r\nbar\r\nbaz",
]
for input_str, expected_output in zip(input_strs, expected):
@@ -120,18 +110,18 @@ def test_sanitize_output_use_pyt_false(self):
def test_sanitize_output_use_pyt_true(self):
# pty is used, \r\n should be replaced with \n
input_strs = [
- 'foo',
- 'foo\n',
- 'foo\r\n',
- 'foo\nbar\nbaz\n',
- 'foo\r\nbar\r\nbaz\r\n',
+ "foo",
+ "foo\n",
+ "foo\r\n",
+ "foo\nbar\nbaz\n",
+ "foo\r\nbar\r\nbaz\r\n",
]
expected = [
- 'foo',
- 'foo',
- 'foo',
- 'foo\nbar\nbaz',
- 'foo\nbar\nbaz',
+ "foo",
+ "foo",
+ "foo",
+ "foo\nbar\nbaz",
+ "foo\nbar\nbaz",
]
for input_str, expected_output in zip(input_strs, expected):
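The assertion tables above are a complete specification of rstrip_last_char: strip at most one trailing occurrence of the given suffix, and hand back falsy inputs unchanged. A faithful sketch (hypothetical reimplementation, checked against the cases above):

def rstrip_last_char(value, char_to_strip):
    if not value or not char_to_strip:
        return value
    if value.endswith(char_to_strip):
        return value[: -len(char_to_strip)]
    return value

assert rstrip_last_char("foo\n\n", "\n") == "foo\n"      # only one newline removed
assert rstrip_last_char("foo\r\r\n", "\r\n") == "foo\r"
assert rstrip_last_char("foo\n\r", "\r\n") == "foo\n\r"  # suffix must match exactly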
diff --git a/st2common/tests/unit/test_model_utils_profiling.py b/st2common/tests/unit/test_model_utils_profiling.py
index 2225e39a7e..db37039c80 100644
--- a/st2common/tests/unit/test_model_utils_profiling.py
+++ b/st2common/tests/unit/test_model_utils_profiling.py
@@ -28,31 +28,37 @@ def setUp(self):
super(MongoDBProfilingTestCase, self).setUp()
disable_profiling()
- @mock.patch('st2common.models.utils.profiling.LOG')
+ @mock.patch("st2common.models.utils.profiling.LOG")
def test_logging_profiling_is_disabled(self, mock_log):
disable_profiling()
- queryset = User.query(name__in=['test1', 'test2'], order_by=['+aa', '-bb'], limit=1)
+ queryset = User.query(
+ name__in=["test1", "test2"], order_by=["+aa", "-bb"], limit=1
+ )
result = log_query_and_profile_data_for_queryset(queryset=queryset)
self.assertEqual(queryset, result)
call_args_list = mock_log.debug.call_args_list
self.assertItemsEqual(call_args_list, [])
- @mock.patch('st2common.models.utils.profiling.LOG')
+ @mock.patch("st2common.models.utils.profiling.LOG")
def test_logging_profiling_is_enabled(self, mock_log):
enable_profiling()
- queryset = User.query(name__in=['test1', 'test2'], order_by=['+aa', '-bb'], limit=1)
+ queryset = User.query(
+ name__in=["test1", "test2"], order_by=["+aa", "-bb"], limit=1
+ )
result = log_query_and_profile_data_for_queryset(queryset=queryset)
call_args_list = mock_log.debug.call_args_list
call_args = call_args_list[0][0]
call_kwargs = call_args_list[0][1]
- expected_result = ("db.user_d_b.find({'name': {'$in': ['test1', 'test2']}})"
- ".sort({aa: 1, bb: -1}).limit(1);")
+ expected_result = (
+ "db.user_d_b.find({'name': {'$in': ['test1', 'test2']}})"
+ ".sort({aa: 1, bb: -1}).limit(1);"
+ )
self.assertEqual(queryset, result)
self.assertIn(expected_result, call_args[0])
- self.assertIn('mongo_query', call_kwargs['extra'])
- self.assertIn('mongo_shell_query', call_kwargs['extra'])
+ self.assertIn("mongo_query", call_kwargs["extra"])
+ self.assertIn("mongo_shell_query", call_kwargs["extra"])
def test_logging_profiling_is_enabled_non_queryset_object(self):
enable_profiling()
diff --git a/st2common/tests/unit/test_mongoescape.py b/st2common/tests/unit/test_mongoescape.py
index 05e3b7962f..0ad12e2823 100644
--- a/st2common/tests/unit/test_mongoescape.py
+++ b/st2common/tests/unit/test_mongoescape.py
@@ -21,68 +21,70 @@
class TestMongoEscape(unittest.TestCase):
def test_unnested(self):
- field = {'k1.k1.k1': 'v1', 'k2$': 'v2', '$k3.': 'v3'}
+ field = {"k1.k1.k1": "v1", "k2$": "v2", "$k3.": "v3"}
escaped = mongoescape.escape_chars(field)
- self.assertEqual(escaped, {u'k1\uff0ek1\uff0ek1': 'v1',
- u'k2\uff04': 'v2',
- u'\uff04k3\uff0e': 'v3'}, 'Escaping failed.')
+ self.assertEqual(
+ escaped,
+ {"k1\uff0ek1\uff0ek1": "v1", "k2\uff04": "v2", "\uff04k3\uff0e": "v3"},
+ "Escaping failed.",
+ )
unescaped = mongoescape.unescape_chars(escaped)
- self.assertEqual(unescaped, field, 'Unescaping failed.')
+ self.assertEqual(unescaped, field, "Unescaping failed.")
def test_nested(self):
- nested_field = {'nk1.nk1.nk1': 'v1', 'nk2$': 'v2', '$nk3.': 'v3'}
- field = {'k1.k1.k1': nested_field, 'k2$': 'v2', '$k3.': 'v3'}
+ nested_field = {"nk1.nk1.nk1": "v1", "nk2$": "v2", "$nk3.": "v3"}
+ field = {"k1.k1.k1": nested_field, "k2$": "v2", "$k3.": "v3"}
escaped = mongoescape.escape_chars(field)
- self.assertEqual(escaped, {u'k1\uff0ek1\uff0ek1': {u'\uff04nk3\uff0e': 'v3',
- u'nk1\uff0enk1\uff0enk1': 'v1',
- u'nk2\uff04': 'v2'},
- u'k2\uff04': 'v2',
- u'\uff04k3\uff0e': 'v3'}, 'un-escaping failed.')
+ self.assertEqual(
+ escaped,
+ {
+ "k1\uff0ek1\uff0ek1": {
+ "\uff04nk3\uff0e": "v3",
+ "nk1\uff0enk1\uff0enk1": "v1",
+ "nk2\uff04": "v2",
+ },
+ "k2\uff04": "v2",
+ "\uff04k3\uff0e": "v3",
+ },
+ "un-escaping failed.",
+ )
unescaped = mongoescape.unescape_chars(escaped)
- self.assertEqual(unescaped, field, 'Unescaping failed.')
+ self.assertEqual(unescaped, field, "Unescaping failed.")
def test_unescaping_of_rule_criteria(self):
# Verify that dot escaped in rule criteria is correctly escaped.
# Note: In the past we used different character to escape dot in the
# rule criteria.
- escaped = {
- u'k1\u2024k1\u2024k1': 'v1',
- u'k2$': 'v2',
- u'$k3\u2024': 'v3'
- }
- unescaped = {
- 'k1.k1.k1': 'v1',
- 'k2$': 'v2',
- '$k3.': 'v3'
- }
+ escaped = {"k1\u2024k1\u2024k1": "v1", "k2$": "v2", "$k3\u2024": "v3"}
+ unescaped = {"k1.k1.k1": "v1", "k2$": "v2", "$k3.": "v3"}
result = mongoescape.unescape_chars(escaped)
self.assertEqual(result, unescaped)
def test_original_value(self):
- field = {'k1.k2.k3': 'v1'}
+ field = {"k1.k2.k3": "v1"}
escaped = mongoescape.escape_chars(field)
- self.assertIn('k1.k2.k3', list(field.keys()))
- self.assertIn(u'k1\uff0ek2\uff0ek3', list(escaped.keys()))
+ self.assertIn("k1.k2.k3", list(field.keys()))
+ self.assertIn("k1\uff0ek2\uff0ek3", list(escaped.keys()))
unescaped = mongoescape.unescape_chars(escaped)
- self.assertIn('k1.k2.k3', list(unescaped.keys()))
- self.assertIn(u'k1\uff0ek2\uff0ek3', list(escaped.keys()))
+ self.assertIn("k1.k2.k3", list(unescaped.keys()))
+ self.assertIn("k1\uff0ek2\uff0ek3", list(escaped.keys()))
def test_complex(self):
field = {
- 'k1.k2': [{'l1.l2': '123'}, {'l3.l4': '456'}],
- 'k3': [{'l5.l6': '789'}],
- 'k4.k5': [1, 2, 3],
- 'k6': ['a', 'b']
+ "k1.k2": [{"l1.l2": "123"}, {"l3.l4": "456"}],
+ "k3": [{"l5.l6": "789"}],
+ "k4.k5": [1, 2, 3],
+ "k6": ["a", "b"],
}
expected = {
- u'k1\uff0ek2': [{u'l1\uff0el2': '123'}, {u'l3\uff0el4': '456'}],
- 'k3': [{u'l5\uff0el6': '789'}],
- u'k4\uff0ek5': [1, 2, 3],
- 'k6': ['a', 'b']
+ "k1\uff0ek2": [{"l1\uff0el2": "123"}, {"l3\uff0el4": "456"}],
+ "k3": [{"l5\uff0el6": "789"}],
+ "k4\uff0ek5": [1, 2, 3],
+ "k6": ["a", "b"],
}
escaped = mongoescape.escape_chars(field)
@@ -93,17 +95,17 @@ def test_complex(self):
def test_complex_list(self):
field = [
- {'k1.k2': [{'l1.l2': '123'}, {'l3.l4': '456'}]},
- {'k3': [{'l5.l6': '789'}]},
- {'k4.k5': [1, 2, 3]},
- {'k6': ['a', 'b']}
+ {"k1.k2": [{"l1.l2": "123"}, {"l3.l4": "456"}]},
+ {"k3": [{"l5.l6": "789"}]},
+ {"k4.k5": [1, 2, 3]},
+ {"k6": ["a", "b"]},
]
expected = [
- {u'k1\uff0ek2': [{u'l1\uff0el2': '123'}, {u'l3\uff0el4': '456'}]},
- {'k3': [{u'l5\uff0el6': '789'}]},
- {u'k4\uff0ek5': [1, 2, 3]},
- {'k6': ['a', 'b']}
+ {"k1\uff0ek2": [{"l1\uff0el2": "123"}, {"l3\uff0el4": "456"}]},
+ {"k3": [{"l5\uff0el6": "789"}]},
+ {"k4\uff0ek5": [1, 2, 3]},
+ {"k6": ["a", "b"]},
]
escaped = mongoescape.escape_chars(field)
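For readers skimming this hunk: MongoDB rejects "." and "$" inside document keys, which is why these fixtures expect the full-width U+FF0E and U+FF04 stand-ins. A minimal sketch of the escaping idea, assuming a plain recursive walk -- the helper shape here is illustrative, not st2common's actual implementation, and the real unescape additionally handles the legacy U+2024 escape exercised in test_unescaping_of_rule_criteria:

ESCAPE_TABLE = str.maketrans({".": "\uff0e", "$": "\uff04"})
UNESCAPE_TABLE = str.maketrans({"\uff0e": ".", "\uff04": "$"})


def escape_chars(value, table=ESCAPE_TABLE):
    # Only dict keys are rewritten; values and list items are walked recursively,
    # which is what the test_complex and test_complex_list cases exercise.
    if isinstance(value, dict):
        return {k.translate(table): escape_chars(v, table) for k, v in value.items()}
    if isinstance(value, list):
        return [escape_chars(item, table) for item in value]
    return value


def unescape_chars(value):
    return escape_chars(value, table=UNESCAPE_TABLE)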
diff --git a/st2common/tests/unit/test_notification_helper.py b/st2common/tests/unit/test_notification_helper.py
index d169dd5a5f..9c00ea4771 100644
--- a/st2common/tests/unit/test_notification_helper.py
+++ b/st2common/tests/unit/test_notification_helper.py
@@ -20,7 +20,6 @@
class NotificationsHelperTestCase(unittest2.TestCase):
-
def test_model_transformations(self):
notify = {}
@@ -31,42 +30,56 @@ def test_model_transformations(self):
notify_api = NotificationsHelper.from_model(notify_model)
self.assertEqual(notify_api, {})
- notify['on-complete'] = {
- 'message': 'Action completed.',
- 'routes': [
- '66'
- ],
- 'data': {
- 'foo': '{{foo}}',
- 'bar': 1,
- 'baz': [1, 2, 3]
- }
+ notify["on-complete"] = {
+ "message": "Action completed.",
+ "routes": ["66"],
+ "data": {"foo": "{{foo}}", "bar": 1, "baz": [1, 2, 3]},
}
- notify['on-success'] = {
- 'message': 'Action succeeded.',
- 'routes': [
- '100'
- ],
- 'data': {
- 'foo': '{{foo}}',
- 'bar': 1,
- }
+ notify["on-success"] = {
+ "message": "Action succeeded.",
+ "routes": ["100"],
+ "data": {
+ "foo": "{{foo}}",
+ "bar": 1,
+ },
}
notify_model = NotificationsHelper.to_model(notify)
- self.assertEqual(notify['on-complete']['message'], notify_model.on_complete.message)
- self.assertDictEqual(notify['on-complete']['data'], notify_model.on_complete.data)
- self.assertListEqual(notify['on-complete']['routes'], notify_model.on_complete.routes)
- self.assertEqual(notify['on-success']['message'], notify_model.on_success.message)
- self.assertDictEqual(notify['on-success']['data'], notify_model.on_success.data)
- self.assertListEqual(notify['on-success']['routes'], notify_model.on_success.routes)
+ self.assertEqual(
+ notify["on-complete"]["message"], notify_model.on_complete.message
+ )
+ self.assertDictEqual(
+ notify["on-complete"]["data"], notify_model.on_complete.data
+ )
+ self.assertListEqual(
+ notify["on-complete"]["routes"], notify_model.on_complete.routes
+ )
+ self.assertEqual(
+ notify["on-success"]["message"], notify_model.on_success.message
+ )
+ self.assertDictEqual(notify["on-success"]["data"], notify_model.on_success.data)
+ self.assertListEqual(
+ notify["on-success"]["routes"], notify_model.on_success.routes
+ )
notify_api = NotificationsHelper.from_model(notify_model)
- self.assertEqual(notify['on-complete']['message'], notify_api['on-complete']['message'])
- self.assertDictEqual(notify['on-complete']['data'], notify_api['on-complete']['data'])
- self.assertListEqual(notify['on-complete']['routes'], notify_api['on-complete']['routes'])
- self.assertEqual(notify['on-success']['message'], notify_api['on-success']['message'])
- self.assertDictEqual(notify['on-success']['data'], notify_api['on-success']['data'])
- self.assertListEqual(notify['on-success']['routes'], notify_api['on-success']['routes'])
+ self.assertEqual(
+ notify["on-complete"]["message"], notify_api["on-complete"]["message"]
+ )
+ self.assertDictEqual(
+ notify["on-complete"]["data"], notify_api["on-complete"]["data"]
+ )
+ self.assertListEqual(
+ notify["on-complete"]["routes"], notify_api["on-complete"]["routes"]
+ )
+ self.assertEqual(
+ notify["on-success"]["message"], notify_api["on-success"]["message"]
+ )
+ self.assertDictEqual(
+ notify["on-success"]["data"], notify_api["on-success"]["data"]
+ )
+ self.assertListEqual(
+ notify["on-success"]["routes"], notify_api["on-success"]["routes"]
+ )
def test_model_transformations_missing_fields(self):
notify = {}
@@ -78,33 +91,39 @@ def test_model_transformations_missing_fields(self):
notify_api = NotificationsHelper.from_model(notify_model)
self.assertEqual(notify_api, {})
- notify['on-complete'] = {
- 'routes': [
- '66'
- ],
- 'data': {
- 'foo': '{{foo}}',
- 'bar': 1,
- 'baz': [1, 2, 3]
- }
+ notify["on-complete"] = {
+ "routes": ["66"],
+ "data": {"foo": "{{foo}}", "bar": 1, "baz": [1, 2, 3]},
}
- notify['on-success'] = {
- 'routes': [
- '100'
- ],
- 'data': {
- 'foo': '{{foo}}',
- 'bar': 1,
- }
+ notify["on-success"] = {
+ "routes": ["100"],
+ "data": {
+ "foo": "{{foo}}",
+ "bar": 1,
+ },
}
notify_model = NotificationsHelper.to_model(notify)
- self.assertDictEqual(notify['on-complete']['data'], notify_model.on_complete.data)
- self.assertListEqual(notify['on-complete']['routes'], notify_model.on_complete.routes)
- self.assertDictEqual(notify['on-success']['data'], notify_model.on_success.data)
- self.assertListEqual(notify['on-success']['routes'], notify_model.on_success.routes)
+ self.assertDictEqual(
+ notify["on-complete"]["data"], notify_model.on_complete.data
+ )
+ self.assertListEqual(
+ notify["on-complete"]["routes"], notify_model.on_complete.routes
+ )
+ self.assertDictEqual(notify["on-success"]["data"], notify_model.on_success.data)
+ self.assertListEqual(
+ notify["on-success"]["routes"], notify_model.on_success.routes
+ )
notify_api = NotificationsHelper.from_model(notify_model)
- self.assertDictEqual(notify['on-complete']['data'], notify_api['on-complete']['data'])
- self.assertListEqual(notify['on-complete']['routes'], notify_api['on-complete']['routes'])
- self.assertDictEqual(notify['on-success']['data'], notify_api['on-success']['data'])
- self.assertListEqual(notify['on-success']['routes'], notify_api['on-success']['routes'])
+ self.assertDictEqual(
+ notify["on-complete"]["data"], notify_api["on-complete"]["data"]
+ )
+ self.assertListEqual(
+ notify["on-complete"]["routes"], notify_api["on-complete"]["routes"]
+ )
+ self.assertDictEqual(
+ notify["on-success"]["data"], notify_api["on-success"]["data"]
+ )
+ self.assertListEqual(
+ notify["on-success"]["routes"], notify_api["on-success"]["routes"]
+ )
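The round trip these tests assert (API dict -> model -> API dict) hinges on mapping hyphenated trigger keys such as "on-complete" to attribute names such as on_complete. A minimal sketch of that mapping, assuming a plain namespace object in place of the real NotificationsHelper model classes (which also apply defaults for missing fields):

from types import SimpleNamespace


def to_model(notify):
    # "on-complete" -> model.on_complete, "on-success" -> model.on_success
    return SimpleNamespace(
        **{k.replace("-", "_"): SimpleNamespace(**v) for k, v in notify.items()}
    )


def from_model(model):
    # Inverse mapping: attribute names back to hyphenated API keys.
    return {k.replace("_", "-"): dict(vars(v)) for k, v in vars(model).items()}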
diff --git a/st2common/tests/unit/test_operators.py b/st2common/tests/unit/test_operators.py
index 48f693af30..5917e4277c 100644
--- a/st2common/tests/unit/test_operators.py
+++ b/st2common/tests/unit/test_operators.py
@@ -44,6 +44,7 @@ class ListOfDictsStrictEqualTest(unittest2.TestCase):
We should test our comparison functions, even if they're only used in our
other tests.
"""
+
def test_empty_lists(self):
self.assertTrue(list_of_dicts_strict_equal([], []))
@@ -54,65 +55,105 @@ def test_multiple_empty_dicts(self):
self.assertTrue(list_of_dicts_strict_equal([{}, {}], [{}, {}]))
def test_simple_dicts(self):
- self.assertTrue(list_of_dicts_strict_equal([
- {'a': 1},
- ], [
- {'a': 1},
- ]))
-
- self.assertFalse(list_of_dicts_strict_equal([
- {'a': 1},
- ], [
- {'a': 2},
- ]))
+ self.assertTrue(
+ list_of_dicts_strict_equal(
+ [
+ {"a": 1},
+ ],
+ [
+ {"a": 1},
+ ],
+ )
+ )
+
+ self.assertFalse(
+ list_of_dicts_strict_equal(
+ [
+ {"a": 1},
+ ],
+ [
+ {"a": 2},
+ ],
+ )
+ )
def test_lists_of_different_lengths(self):
- self.assertFalse(list_of_dicts_strict_equal([
- {'a': 1},
- ], [
- {'a': 1},
- {'b': 2},
- ]))
-
- self.assertFalse(list_of_dicts_strict_equal([
- {'a': 1},
- {'b': 2},
- ], [
- {'a': 1},
- ]))
+ self.assertFalse(
+ list_of_dicts_strict_equal(
+ [
+ {"a": 1},
+ ],
+ [
+ {"a": 1},
+ {"b": 2},
+ ],
+ )
+ )
+
+ self.assertFalse(
+ list_of_dicts_strict_equal(
+ [
+ {"a": 1},
+ {"b": 2},
+ ],
+ [
+ {"a": 1},
+ ],
+ )
+ )
def test_less_simple_dicts(self):
- self.assertTrue(list_of_dicts_strict_equal([
- {'a': 1},
- {'b': 2},
- ], [
- {'a': 1},
- {'b': 2},
- ]))
-
- self.assertTrue(list_of_dicts_strict_equal([
- {'a': 1},
- {'a': 1},
- ], [
- {'a': 1},
- {'a': 1},
- ]))
-
- self.assertFalse(list_of_dicts_strict_equal([
- {'a': 1},
- {'a': 1},
- ], [
- {'a': 1},
- {'b': 2},
- ]))
-
- self.assertFalse(list_of_dicts_strict_equal([
- {'a': 1},
- {'b': 2},
- ], [
- {'a': 1},
- {'a': 1},
- ]))
+ self.assertTrue(
+ list_of_dicts_strict_equal(
+ [
+ {"a": 1},
+ {"b": 2},
+ ],
+ [
+ {"a": 1},
+ {"b": 2},
+ ],
+ )
+ )
+
+ self.assertTrue(
+ list_of_dicts_strict_equal(
+ [
+ {"a": 1},
+ {"a": 1},
+ ],
+ [
+ {"a": 1},
+ {"a": 1},
+ ],
+ )
+ )
+
+ self.assertFalse(
+ list_of_dicts_strict_equal(
+ [
+ {"a": 1},
+ {"a": 1},
+ ],
+ [
+ {"a": 1},
+ {"b": 2},
+ ],
+ )
+ )
+
+ self.assertFalse(
+ list_of_dicts_strict_equal(
+ [
+ {"a": 1},
+ {"b": 2},
+ ],
+ [
+ {"a": 1},
+ {"a": 1},
+ ],
+ )
+ )
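The comparison helper exercised above is length-, order-, and duplicate-sensitive. A sketch consistent with these test cases (the suite's actual helper may be implemented differently):

def list_of_dicts_strict_equal(left, right):
    # Same length required; order and duplicates are both significant,
    # so [{"a": 1}, {"a": 1}] never matches [{"a": 1}, {"b": 2}].
    if len(left) != len(right):
        return False
    return all(a == b for a, b in zip(left, right))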
class SearchOperatorTest(unittest2.TestCase):
@@ -120,774 +161,850 @@ class SearchOperatorTest(unittest2.TestCase):
# parser. As such, its tests are much more complex than other commands, so we
# pull its tests out into their own test case.
def test_search_with_weird_condition(self):
- op = operators.get_operator('search')
+ op = operators.get_operator("search")
with self.assertRaises(operators.UnrecognizedConditionError):
- op([], [], 'weird', None)
+ op([], [], "weird", None)
def test_search_any_true(self):
- op = operators.get_operator('search')
+ op = operators.get_operator("search")
called_function_args = []
def record_function_args(criterion_k, criterion_v, payload_lookup):
- called_function_args.append({
- 'criterion_k': criterion_k,
- 'criterion_v': criterion_v,
- 'payload_lookup': {
- 'field_name': payload_lookup.get_value('item.field_name')[0],
- 'to_value': payload_lookup.get_value('item.to_value')[0],
- },
- })
- return (len(called_function_args) < 3)
+ called_function_args.append(
+ {
+ "criterion_k": criterion_k,
+ "criterion_v": criterion_v,
+ "payload_lookup": {
+ "field_name": payload_lookup.get_value("item.field_name")[0],
+ "to_value": payload_lookup.get_value("item.to_value")[0],
+ },
+ }
+ )
+ return len(called_function_args) < 3
payload = [
{
- 'field_name': "Status",
- 'to_value': "Approved",
- }, {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- }
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
]
criteria_pattern = {
- 'item.field_name': {
- 'type': "equals",
- 'pattern': "Status",
+ "item.field_name": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "item.to_value": {
+ "type": "equals",
+ "pattern": "Approved",
},
- 'item.to_value': {
- 'type': "equals",
- 'pattern': "Approved",
- }
}
- result = op(payload, criteria_pattern, 'any', record_function_args)
+ result = op(payload, criteria_pattern, "any", record_function_args)
self.assertTrue(result)
- self.assertTrue(list_of_dicts_strict_equal(called_function_args, [
- # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"}
- {
- # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"}
- 'criterion_k': 'item.field_name',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Status",
- },
- 'payload_lookup': {
- 'field_name': "Status",
- 'to_value': "Approved",
- },
- }, {
- # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.to_value',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Approved",
- },
- 'payload_lookup': {
- 'field_name': "Status",
- 'to_value': "Approved",
- },
- },
- # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"}
- {
- # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"}
- 'criterion_k': 'item.field_name',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Status",
- },
- 'payload_lookup': {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- },
- }, {
- # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.to_value',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Approved",
- },
- 'payload_lookup': {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- },
- }
- ]))
+ self.assertTrue(
+ list_of_dicts_strict_equal(
+ called_function_args,
+ [
+ # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"}
+ {
+ # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"}
+ "criterion_k": "item.field_name",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "payload_lookup": {
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ },
+ {
+ # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
+ "criterion_k": "item.to_value",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Approved",
+ },
+ "payload_lookup": {
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ },
+ # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"}
+ {
+ # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"}
+ "criterion_k": "item.field_name",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "payload_lookup": {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
+ },
+ {
+ # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
+ "criterion_k": "item.to_value",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Approved",
+ },
+ "payload_lookup": {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
+ },
+ ],
+ )
+ )
def test_search_any_false(self):
- op = operators.get_operator('search')
+ op = operators.get_operator("search")
called_function_args = []
def record_function_args(criterion_k, criterion_v, payload_lookup):
- called_function_args.append({
- 'criterion_k': criterion_k,
- 'criterion_v': criterion_v,
- 'payload_lookup': {
- 'field_name': payload_lookup.get_value('item.field_name')[0],
- 'to_value': payload_lookup.get_value('item.to_value')[0],
- },
- })
+ called_function_args.append(
+ {
+ "criterion_k": criterion_k,
+ "criterion_v": criterion_v,
+ "payload_lookup": {
+ "field_name": payload_lookup.get_value("item.field_name")[0],
+ "to_value": payload_lookup.get_value("item.to_value")[0],
+ },
+ }
+ )
return (len(called_function_args) % 2) == 0
payload = [
{
- 'field_name': "Status",
- 'to_value': "Denied",
- }, {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- }
+ "field_name": "Status",
+ "to_value": "Denied",
+ },
+ {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
]
criteria_pattern = {
- 'item.field_name': {
- 'type': "equals",
- 'pattern': "Status",
+ "item.field_name": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "item.to_value": {
+ "type": "equals",
+ "pattern": "Approved",
},
- 'item.to_value': {
- 'type': "equals",
- 'pattern': "Approved",
- }
}
- result = op(payload, criteria_pattern, 'any', record_function_args)
+ result = op(payload, criteria_pattern, "any", record_function_args)
self.assertFalse(result)
- self.assertEqual(called_function_args, [
- # Outer loop: payload -> {'field_name': "Status", 'to_value': "Denied"}
- {
- # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"}
- 'criterion_k': 'item.field_name',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Status",
- },
- 'payload_lookup': {
- 'field_name': "Status",
- 'to_value': "Denied",
- },
- }, {
- # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.to_value',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Approved",
- },
- 'payload_lookup': {
- 'field_name': "Status",
- 'to_value': "Denied",
- },
- },
- # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"}
- {
- # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.field_name',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Status",
- },
- 'payload_lookup': {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- },
- }, {
- # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.to_value',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Approved",
- },
- 'payload_lookup': {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- },
- }
- ])
+ self.assertEqual(
+ called_function_args,
+ [
+ # Outer loop: payload -> {'field_name': "Status", 'to_value': "Denied"}
+ {
+ # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"}
+ "criterion_k": "item.field_name",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "payload_lookup": {
+ "field_name": "Status",
+ "to_value": "Denied",
+ },
+ },
+ {
+ # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
+ "criterion_k": "item.to_value",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Approved",
+ },
+ "payload_lookup": {
+ "field_name": "Status",
+ "to_value": "Denied",
+ },
+ },
+ # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"}
+ {
+ # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"}
+ "criterion_k": "item.field_name",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "payload_lookup": {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
+ },
+ {
+ # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"}
+ "criterion_k": "item.to_value",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Approved",
+ },
+ "payload_lookup": {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
+ },
+ ],
+ )
def test_search_all_false(self):
- op = operators.get_operator('search')
+ op = operators.get_operator("search")
called_function_args = []
def record_function_args(criterion_k, criterion_v, payload_lookup):
- called_function_args.append({
- 'criterion_k': criterion_k,
- 'criterion_v': criterion_v,
- 'payload_lookup': {
- 'field_name': payload_lookup.get_value('item.field_name')[0],
- 'to_value': payload_lookup.get_value('item.to_value')[0],
- },
- })
+ called_function_args.append(
+ {
+ "criterion_k": criterion_k,
+ "criterion_v": criterion_v,
+ "payload_lookup": {
+ "field_name": payload_lookup.get_value("item.field_name")[0],
+ "to_value": payload_lookup.get_value("item.to_value")[0],
+ },
+ }
+ )
return (len(called_function_args) % 2) == 0
payload = [
{
- 'field_name': "Status",
- 'to_value': "Approved",
- }, {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- }
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
]
criteria_pattern = {
- 'item.field_name': {
- 'type': "equals",
- 'pattern': "Status",
+ "item.field_name": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "item.to_value": {
+ "type": "equals",
+ "pattern": "Approved",
},
- 'item.to_value': {
- 'type': "equals",
- 'pattern': "Approved",
- }
}
- result = op(payload, criteria_pattern, 'all', record_function_args)
+ result = op(payload, criteria_pattern, "all", record_function_args)
self.assertFalse(result)
- self.assertEqual(called_function_args, [
- # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"}
- {
- # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"}
- 'criterion_k': 'item.field_name',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Status",
- },
- 'payload_lookup': {
- 'field_name': "Status",
- 'to_value': "Approved",
- },
- }, {
- # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.to_value',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Approved",
- },
- 'payload_lookup': {
- 'field_name': "Status",
- 'to_value': "Approved",
- },
- },
- # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"}
- {
- # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"}
- 'criterion_k': 'item.field_name',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Status",
- },
- 'payload_lookup': {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- },
- }, {
- # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.to_value',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Approved",
- },
- 'payload_lookup': {
- 'field_name': "Assigned to",
- 'to_value': "Stanley",
- },
- }
- ])
+ self.assertEqual(
+ called_function_args,
+ [
+ # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"}
+ {
+ # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"}
+ "criterion_k": "item.field_name",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "payload_lookup": {
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ },
+ {
+ # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"}
+ "criterion_k": "item.to_value",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Approved",
+ },
+ "payload_lookup": {
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ },
+ # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"}
+ {
+ # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"}
+ "criterion_k": "item.field_name",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Status",
+ },
+ "payload_lookup": {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
+ },
+ {
+ # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"}
+ "criterion_k": "item.to_value",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Approved",
+ },
+ "payload_lookup": {
+ "field_name": "Assigned to",
+ "to_value": "Stanley",
+ },
+ },
+ ],
+ )
def test_search_all_true(self):
- op = operators.get_operator('search')
+ op = operators.get_operator("search")
called_function_args = []
def record_function_args(criterion_k, criterion_v, payload_lookup):
- called_function_args.append({
- 'criterion_k': criterion_k,
- 'criterion_v': criterion_v,
- 'payload_lookup': {
- 'field_name': payload_lookup.get_value('item.field_name')[0],
- 'to_value': payload_lookup.get_value('item.to_value')[0],
- },
- })
+ called_function_args.append(
+ {
+ "criterion_k": criterion_k,
+ "criterion_v": criterion_v,
+ "payload_lookup": {
+ "field_name": payload_lookup.get_value("item.field_name")[0],
+ "to_value": payload_lookup.get_value("item.to_value")[0],
+ },
+ }
+ )
return True
payload = [
{
- 'field_name': "Status",
- 'to_value': "Approved",
- }, {
- 'field_name': "Signed off by",
- 'to_value': "Approved",
- }
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ {
+ "field_name": "Signed off by",
+ "to_value": "Approved",
+ },
]
criteria_pattern = {
- 'item.field_name': {
- 'type': "startswith",
- 'pattern': "S",
+ "item.field_name": {
+ "type": "startswith",
+ "pattern": "S",
+ },
+ "item.to_value": {
+ "type": "equals",
+ "pattern": "Approved",
},
- 'item.to_value': {
- 'type': "equals",
- 'pattern': "Approved",
- }
}
- result = op(payload, criteria_pattern, 'all', record_function_args)
+ result = op(payload, criteria_pattern, "all", record_function_args)
self.assertTrue(result)
- self.assertEqual(called_function_args, [
- # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"}
- {
- # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"}
- 'criterion_k': 'item.field_name',
- 'criterion_v': {
- 'type': "startswith",
- 'pattern': "S",
- },
- 'payload_lookup': {
- 'field_name': "Status",
- 'to_value': "Approved",
- },
- }, {
- # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.to_value',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Approved",
- },
- 'payload_lookup': {
- 'field_name': "Status",
- 'to_value': "Approved",
- },
- },
- # Outer loop: payload -> {'field_name': "Signed off by", 'to_value': "Approved"}
- {
- # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"}
- 'criterion_k': 'item.field_name',
- 'criterion_v': {
- 'type': "startswith",
- 'pattern': "S",
- },
- 'payload_lookup': {
- 'field_name': "Signed off by",
- 'to_value': "Approved",
- },
- }, {
- # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"}
- 'criterion_k': 'item.to_value',
- 'criterion_v': {
- 'type': "equals",
- 'pattern': "Approved",
- },
- 'payload_lookup': {
- 'field_name': "Signed off by",
- 'to_value': "Approved",
- },
- }
- ])
+ self.assertEqual(
+ called_function_args,
+ [
+ # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"}
+ {
+ # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"}
+ "criterion_k": "item.field_name",
+ "criterion_v": {
+ "type": "startswith",
+ "pattern": "S",
+ },
+ "payload_lookup": {
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ },
+ {
+ # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"}
+ "criterion_k": "item.to_value",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Approved",
+ },
+ "payload_lookup": {
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ },
+ # Outer loop: payload -> {'field_name': "Signed off by", 'to_value': "Approved"}
+ {
+ # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"}
+ "criterion_k": "item.field_name",
+ "criterion_v": {
+ "type": "startswith",
+ "pattern": "S",
+ },
+ "payload_lookup": {
+ "field_name": "Signed off by",
+ "to_value": "Approved",
+ },
+ },
+ {
+ # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"}
+ "criterion_k": "item.to_value",
+ "criterion_v": {
+ "type": "equals",
+ "pattern": "Approved",
+ },
+ "payload_lookup": {
+ "field_name": "Signed off by",
+ "to_value": "Approved",
+ },
+ },
+ ],
+ )
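Taken together, the four tests above pin down the search operator's contract: every criterion is evaluated for every payload item with no short-circuiting (which is what lets the tests assert the complete list of recorded calls), and the condition decides whether one or all items must satisfy all criteria. A sketch of that contract, with make_payload_lookup as a hypothetical stand-in for the real PayloadLookup:

class UnrecognizedConditionError(Exception):
    pass


def make_payload_lookup(item):
    # Hypothetical stand-in: resolves "item.<field>" keys against the
    # current payload item and returns the value wrapped in a list.
    class _Lookup:
        def get_value(self, key):
            return [item[key.split(".", 1)[1]]]

    return _Lookup()


def search(payload, criteria_pattern, condition, check):
    if condition not in ("any", "all"):
        raise UnrecognizedConditionError(condition)

    item_results = []
    for item in payload:
        lookup = make_payload_lookup(item)
        # Evaluate every criterion for every item -- no short-circuiting.
        per_criterion = [
            check(criterion_k, criterion_v, lookup)
            for criterion_k, criterion_v in criteria_pattern.items()
        ]
        item_results.append(all(per_criterion))

    return any(item_results) if condition == "any" else all(item_results)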
class OperatorTest(unittest2.TestCase):
def test_matchwildcard(self):
- op = operators.get_operator('matchwildcard')
- self.assertTrue(op('v1', 'v1'), 'Failed matchwildcard.')
+ op = operators.get_operator("matchwildcard")
+ self.assertTrue(op("v1", "v1"), "Failed matchwildcard.")
- self.assertFalse(op('test foo test', 'foo'), 'Passed matchwildcard.')
- self.assertTrue(op('test foo test', '*foo*'), 'Failed matchwildcard.')
- self.assertTrue(op('bar', 'b*r'), 'Failed matchwildcard.')
- self.assertTrue(op('bar', 'b?r'), 'Failed matchwildcard.')
+ self.assertFalse(op("test foo test", "foo"), "Passed matchwildcard.")
+ self.assertTrue(op("test foo test", "*foo*"), "Failed matchwildcard.")
+ self.assertTrue(op("bar", "b*r"), "Failed matchwildcard.")
+ self.assertTrue(op("bar", "b?r"), "Failed matchwildcard.")
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'bar', 'b?r'), 'Failed matchwildcard.')
- self.assertTrue(op('bar', b'b?r'), 'Failed matchwildcard.')
- self.assertTrue(op(b'bar', b'b?r'), 'Failed matchwildcard.')
- self.assertTrue(op(u'bar', b'b?r'), 'Failed matchwildcard.')
- self.assertTrue(op(u'bar', u'b?r'), 'Failed matchwildcard.')
+ self.assertTrue(op(b"bar", "b?r"), "Failed matchwildcard.")
+ self.assertTrue(op("bar", b"b?r"), "Failed matchwildcard.")
+ self.assertTrue(op(b"bar", b"b?r"), "Failed matchwildcard.")
+ self.assertTrue(op("bar", b"b?r"), "Failed matchwildcard.")
+ self.assertTrue(op("bar", "b?r"), "Failed matchwildcard.")
- self.assertFalse(op('1', None), 'Passed matchwildcard with None as criteria_pattern.')
+ self.assertFalse(
+ op("1", None), "Passed matchwildcard with None as criteria_pattern."
+ )
def test_matchregex(self):
- op = operators.get_operator('matchregex')
- self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.')
+ op = operators.get_operator("matchregex")
+ self.assertTrue(op("v1", "v1$"), "Failed matchregex.")
# Multi line string, make sure re.DOTALL is used
- string = '''ponies
+ string = """ponies
moar
foo
bar
yeah!
- '''
- self.assertTrue(op(string, '.*bar.*'), 'Failed matchregex.')
+ """
+ self.assertTrue(op(string, ".*bar.*"), "Failed matchregex.")
- string = 'foo\r\nponies\nbar\nfooooo'
- self.assertTrue(op(string, '.*ponies.*'), 'Failed matchregex.')
+ string = "foo\r\nponies\nbar\nfooooo"
+ self.assertTrue(op(string, ".*ponies.*"), "Failed matchregex.")
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'foo ponies bar', '.*ponies.*'), 'Failed matchregex.')
- self.assertTrue(op('foo ponies bar', b'.*ponies.*'), 'Failed matchregex.')
- self.assertTrue(op(b'foo ponies bar', b'.*ponies.*'), 'Failed matchregex.')
- self.assertTrue(op(b'foo ponies bar', u'.*ponies.*'), 'Failed matchregex.')
- self.assertTrue(op(u'foo ponies bar', u'.*ponies.*'), 'Failed matchregex.')
+ self.assertTrue(op(b"foo ponies bar", ".*ponies.*"), "Failed matchregex.")
+ self.assertTrue(op("foo ponies bar", b".*ponies.*"), "Failed matchregex.")
+ self.assertTrue(op(b"foo ponies bar", b".*ponies.*"), "Failed matchregex.")
+ self.assertTrue(op(b"foo ponies bar", ".*ponies.*"), "Failed matchregex.")
+ self.assertTrue(op("foo ponies bar", ".*ponies.*"), "Failed matchregex.")
def test_iregex(self):
- op = operators.get_operator('iregex')
- self.assertTrue(op('V1', 'v1$'), 'Failed iregex.')
+ op = operators.get_operator("iregex")
+ self.assertTrue(op("V1", "v1$"), "Failed iregex.")
- string = 'fooPONIESbarfooooo'
- self.assertTrue(op(string, 'ponies'), 'Failed iregex.')
+ string = "fooPONIESbarfooooo"
+ self.assertTrue(op(string, "ponies"), "Failed iregex.")
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'fooPONIESbarfooooo', 'ponies'), 'Failed iregex.')
- self.assertTrue(op('fooPONIESbarfooooo', b'ponies'), 'Failed iregex.')
- self.assertTrue(op(b'fooPONIESbarfooooo', b'ponies'), 'Failed iregex.')
- self.assertTrue(op(b'fooPONIESbarfooooo', u'ponies'), 'Failed iregex.')
- self.assertTrue(op(u'fooPONIESbarfooooo', u'ponies'), 'Failed iregex.')
+ self.assertTrue(op(b"fooPONIESbarfooooo", "ponies"), "Failed iregex.")
+ self.assertTrue(op("fooPONIESbarfooooo", b"ponies"), "Failed iregex.")
+ self.assertTrue(op(b"fooPONIESbarfooooo", b"ponies"), "Failed iregex.")
+ self.assertTrue(op(b"fooPONIESbarfooooo", "ponies"), "Failed iregex.")
+ self.assertTrue(op("fooPONIESbarfooooo", "ponies"), "Failed iregex.")
def test_iregex_fail(self):
- op = operators.get_operator('iregex')
- self.assertFalse(op('V1_foo', 'v1$'), 'Passed iregex.')
- self.assertFalse(op('1', None), 'Passed iregex with None as criteria_pattern.')
+ op = operators.get_operator("iregex")
+ self.assertFalse(op("V1_foo", "v1$"), "Passed iregex.")
+ self.assertFalse(op("1", None), "Passed iregex with None as criteria_pattern.")
def test_regex(self):
- op = operators.get_operator('regex')
- self.assertTrue(op('v1', 'v1$'), 'Failed regex.')
+ op = operators.get_operator("regex")
+ self.assertTrue(op("v1", "v1$"), "Failed regex.")
- string = 'fooponiesbarfooooo'
- self.assertTrue(op(string, 'ponies'), 'Failed regex.')
+ string = "fooponiesbarfooooo"
+ self.assertTrue(op(string, "ponies"), "Failed regex.")
# Example with | modifier
- string = 'apple ponies oranges'
- self.assertTrue(op(string, '(ponies|unicorns)'), 'Failed regex.')
+ string = "apple ponies oranges"
+ self.assertTrue(op(string, "(ponies|unicorns)"), "Failed regex.")
- string = 'apple unicorns oranges'
- self.assertTrue(op(string, '(ponies|unicorns)'), 'Failed regex.')
+ string = "apple unicorns oranges"
+ self.assertTrue(op(string, "(ponies|unicorns)"), "Failed regex.")
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'apples unicorns oranges', '(ponies|unicorns)'), 'Failed regex.')
- self.assertTrue(op('apples unicorns oranges', b'(ponies|unicorns)'), 'Failed regex.')
- self.assertTrue(op(b'apples unicorns oranges', b'(ponies|unicorns)'), 'Failed regex.')
- self.assertTrue(op(b'apples unicorns oranges', u'(ponies|unicorns)'), 'Failed regex.')
- self.assertTrue(op(u'apples unicorns oranges', u'(ponies|unicorns)'), 'Failed regex.')
-
- string = 'apple unicorns oranges'
- self.assertFalse(op(string, '(pikachu|snorlax|charmander)'), 'Passed regex.')
+ self.assertTrue(
+ op(b"apples unicorns oranges", "(ponies|unicorns)"), "Failed regex."
+ )
+ self.assertTrue(
+ op("apples unicorns oranges", b"(ponies|unicorns)"), "Failed regex."
+ )
+ self.assertTrue(
+ op(b"apples unicorns oranges", b"(ponies|unicorns)"), "Failed regex."
+ )
+ self.assertTrue(
+ op(b"apples unicorns oranges", "(ponies|unicorns)"), "Failed regex."
+ )
+ self.assertTrue(
+ op("apples unicorns oranges", "(ponies|unicorns)"), "Failed regex."
+ )
+
+ string = "apple unicorns oranges"
+ self.assertFalse(op(string, "(pikachu|snorlax|charmander)"), "Passed regex.")
def test_regex_fail(self):
- op = operators.get_operator('regex')
- self.assertFalse(op('v1_foo', 'v1$'), 'Passed regex.')
+ op = operators.get_operator("regex")
+ self.assertFalse(op("v1_foo", "v1$"), "Passed regex.")
- string = 'fooPONIESbarfooooo'
- self.assertFalse(op(string, 'ponies'), 'Passed regex.')
+ string = "fooPONIESbarfooooo"
+ self.assertFalse(op(string, "ponies"), "Passed regex.")
- self.assertFalse(op('1', None), 'Passed regex with None as criteria_pattern.')
+ self.assertFalse(op("1", None), "Passed regex with None as criteria_pattern.")
def test_matchregex_case_variants(self):
- op = operators.get_operator('MATCHREGEX')
- self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.')
- op = operators.get_operator('MATCHregex')
- self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.')
+ op = operators.get_operator("MATCHREGEX")
+ self.assertTrue(op("v1", "v1$"), "Failed matchregex.")
+ op = operators.get_operator("MATCHregex")
+ self.assertTrue(op("v1", "v1$"), "Failed matchregex.")
def test_matchregex_fail(self):
- op = operators.get_operator('matchregex')
- self.assertFalse(op('v1_foo', 'v1$'), 'Passed matchregex.')
- self.assertFalse(op('1', None), 'Passed matchregex with None as criteria_pattern.')
+ op = operators.get_operator("matchregex")
+ self.assertFalse(op("v1_foo", "v1$"), "Passed matchregex.")
+ self.assertFalse(
+ op("1", None), "Passed matchregex with None as criteria_pattern."
+ )
def test_equals_numeric(self):
- op = operators.get_operator('equals')
- self.assertTrue(op(1, 1), 'Failed equals.')
+ op = operators.get_operator("equals")
+ self.assertTrue(op(1, 1), "Failed equals.")
def test_equals_string(self):
- op = operators.get_operator('equals')
- self.assertTrue(op('1', '1'), 'Failed equals.')
- self.assertTrue(op('', ''), 'Failed equals.')
+ op = operators.get_operator("equals")
+ self.assertTrue(op("1", "1"), "Failed equals.")
+ self.assertTrue(op("", ""), "Failed equals.")
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'1', '1'), 'Failed equals.')
- self.assertTrue(op('1', b'1'), 'Failed equals.')
- self.assertTrue(op(b'1', b'1'), 'Failed equals.')
- self.assertTrue(op(b'1', u'1'), 'Failed equals.')
- self.assertTrue(op(u'1', u'1'), 'Failed equals.')
+ self.assertTrue(op(b"1", "1"), "Failed equals.")
+ self.assertTrue(op("1", b"1"), "Failed equals.")
+ self.assertTrue(op(b"1", b"1"), "Failed equals.")
+ self.assertTrue(op(b"1", "1"), "Failed equals.")
+ self.assertTrue(op("1", "1"), "Failed equals.")
def test_equals_fail(self):
- op = operators.get_operator('equals')
- self.assertFalse(op('1', '2'), 'Passed equals.')
- self.assertFalse(op('1', None), 'Passed equals with None as criteria_pattern.')
+ op = operators.get_operator("equals")
+ self.assertFalse(op("1", "2"), "Passed equals.")
+ self.assertFalse(op("1", None), "Passed equals with None as criteria_pattern.")
def test_nequals(self):
- op = operators.get_operator('nequals')
- self.assertTrue(op('foo', 'bar'))
- self.assertTrue(op('foo', 'foo1'))
- self.assertTrue(op('foo', 'FOO'))
- self.assertTrue(op('True', True))
- self.assertTrue(op('None', None))
-
- self.assertFalse(op('True', 'True'))
+ op = operators.get_operator("nequals")
+ self.assertTrue(op("foo", "bar"))
+ self.assertTrue(op("foo", "foo1"))
+ self.assertTrue(op("foo", "FOO"))
+ self.assertTrue(op("True", True))
+ self.assertTrue(op("None", None))
+
+ self.assertFalse(op("True", "True"))
self.assertFalse(op(None, None))
def test_iequals(self):
- op = operators.get_operator('iequals')
- self.assertTrue(op('ABC', 'ABC'), 'Failed iequals.')
- self.assertTrue(op('ABC', 'abc'), 'Failed iequals.')
- self.assertTrue(op('AbC', 'aBc'), 'Failed iequals.')
+ op = operators.get_operator("iequals")
+ self.assertTrue(op("ABC", "ABC"), "Failed iequals.")
+ self.assertTrue(op("ABC", "abc"), "Failed iequals.")
+ self.assertTrue(op("AbC", "aBc"), "Failed iequals.")
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'AbC', 'aBc'), 'Failed iequals.')
- self.assertTrue(op('AbC', b'aBc'), 'Failed iequals.')
- self.assertTrue(op(b'AbC', b'aBc'), 'Failed iequals.')
- self.assertTrue(op(b'AbC', u'aBc'), 'Failed iequals.')
- self.assertTrue(op(u'AbC', u'aBc'), 'Failed iequals.')
+ self.assertTrue(op(b"AbC", "aBc"), "Failed iequals.")
+ self.assertTrue(op("AbC", b"aBc"), "Failed iequals.")
+ self.assertTrue(op(b"AbC", b"aBc"), "Failed iequals.")
+ self.assertTrue(op(b"AbC", "aBc"), "Failed iequals.")
+ self.assertTrue(op("AbC", "aBc"), "Failed iequals.")
def test_iequals_fail(self):
- op = operators.get_operator('iequals')
- self.assertFalse(op('ABC', 'BCA'), 'Passed iequals.')
- self.assertFalse(op('1', None), 'Passed iequals with None as criteria_pattern.')
+ op = operators.get_operator("iequals")
+ self.assertFalse(op("ABC", "BCA"), "Passed iequals.")
+ self.assertFalse(op("1", None), "Passed iequals with None as criteria_pattern.")
def test_contains(self):
- op = operators.get_operator('contains')
- self.assertTrue(op('hasystack needle haystack', 'needle'))
- self.assertTrue(op('needle', 'needle'))
- self.assertTrue(op('needlehaystack', 'needle'))
- self.assertTrue(op('needle haystack', 'needle'))
- self.assertTrue(op('haystackneedle', 'needle'))
- self.assertTrue(op('haystack needle', 'needle'))
+ op = operators.get_operator("contains")
+ self.assertTrue(op("hasystack needle haystack", "needle"))
+ self.assertTrue(op("needle", "needle"))
+ self.assertTrue(op("needlehaystack", "needle"))
+ self.assertTrue(op("needle haystack", "needle"))
+ self.assertTrue(op("haystackneedle", "needle"))
+ self.assertTrue(op("haystack needle", "needle"))
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'haystack needle', 'needle'))
- self.assertTrue(op('haystack needle', b'needle'))
- self.assertTrue(op(b'haystack needle', b'needle'))
- self.assertTrue(op(b'haystack needle', u'needle'))
- self.assertTrue(op(u'haystack needle', b'needle'))
+ self.assertTrue(op(b"haystack needle", "needle"))
+ self.assertTrue(op("haystack needle", b"needle"))
+ self.assertTrue(op(b"haystack needle", b"needle"))
+ self.assertTrue(op(b"haystack needle", "needle"))
+ self.assertTrue(op("haystack needle", b"needle"))
def test_contains_fail(self):
- op = operators.get_operator('contains')
- self.assertFalse(op('hasystack needl haystack', 'needle'))
- self.assertFalse(op('needla', 'needle'))
- self.assertFalse(op('1', None), 'Passed contains with None as criteria_pattern.')
+ op = operators.get_operator("contains")
+ self.assertFalse(op("hasystack needl haystack", "needle"))
+ self.assertFalse(op("needla", "needle"))
+ self.assertFalse(
+ op("1", None), "Passed contains with None as criteria_pattern."
+ )
def test_icontains(self):
- op = operators.get_operator('icontains')
- self.assertTrue(op('hasystack nEEdle haystack', 'needle'))
- self.assertTrue(op('neeDle', 'NeedlE'))
- self.assertTrue(op('needlehaystack', 'needle'))
- self.assertTrue(op('NEEDLE haystack', 'NEEDLE'))
- self.assertTrue(op('haystackNEEDLE', 'needle'))
- self.assertTrue(op('haystack needle', 'NEEDLE'))
+ op = operators.get_operator("icontains")
+ self.assertTrue(op("hasystack nEEdle haystack", "needle"))
+ self.assertTrue(op("neeDle", "NeedlE"))
+ self.assertTrue(op("needlehaystack", "needle"))
+ self.assertTrue(op("NEEDLE haystack", "NEEDLE"))
+ self.assertTrue(op("haystackNEEDLE", "needle"))
+ self.assertTrue(op("haystack needle", "NEEDLE"))
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'haystack needle', 'NEEDLE'))
- self.assertTrue(op('haystack needle', b'NEEDLE'))
- self.assertTrue(op(b'haystack needle', b'NEEDLE'))
- self.assertTrue(op(b'haystack needle', u'NEEDLE'))
- self.assertTrue(op(u'haystack needle', b'NEEDLE'))
+ self.assertTrue(op(b"haystack needle", "NEEDLE"))
+ self.assertTrue(op("haystack needle", b"NEEDLE"))
+ self.assertTrue(op(b"haystack needle", b"NEEDLE"))
+ self.assertTrue(op(b"haystack needle", "NEEDLE"))
+ self.assertTrue(op("haystack needle", b"NEEDLE"))
def test_icontains_fail(self):
- op = operators.get_operator('icontains')
- self.assertFalse(op('hasystack needl haystack', 'needle'))
- self.assertFalse(op('needla', 'needle'))
- self.assertFalse(op('1', None), 'Passed icontains with None as criteria_pattern.')
+ op = operators.get_operator("icontains")
+ self.assertFalse(op("hasystack needl haystack", "needle"))
+ self.assertFalse(op("needla", "needle"))
+ self.assertFalse(
+ op("1", None), "Passed icontains with None as criteria_pattern."
+ )
def test_ncontains(self):
- op = operators.get_operator('ncontains')
- self.assertTrue(op('hasystack needle haystack', 'foo'))
- self.assertTrue(op('needle', 'foo'))
- self.assertTrue(op('needlehaystack', 'needlex'))
- self.assertTrue(op('needle haystack', 'needlex'))
- self.assertTrue(op('haystackneedle', 'needlex'))
- self.assertTrue(op('haystack needle', 'needlex'))
+ op = operators.get_operator("ncontains")
+ self.assertTrue(op("hasystack needle haystack", "foo"))
+ self.assertTrue(op("needle", "foo"))
+ self.assertTrue(op("needlehaystack", "needlex"))
+ self.assertTrue(op("needle haystack", "needlex"))
+ self.assertTrue(op("haystackneedle", "needlex"))
+ self.assertTrue(op("haystack needle", "needlex"))
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'haystack needle', 'needlex'))
- self.assertTrue(op('haystack needle', b'needlex'))
- self.assertTrue(op(b'haystack needle', b'needlex'))
- self.assertTrue(op(b'haystack needle', u'needlex'))
- self.assertTrue(op(u'haystack needle', b'needlex'))
+ self.assertTrue(op(b"haystack needle", "needlex"))
+ self.assertTrue(op("haystack needle", b"needlex"))
+ self.assertTrue(op(b"haystack needle", b"needlex"))
+ self.assertTrue(op(b"haystack needle", "needlex"))
+ self.assertTrue(op("haystack needle", b"needlex"))
def test_ncontains_fail(self):
- op = operators.get_operator('ncontains')
- self.assertFalse(op('hasystack needle haystack', 'needle'))
- self.assertFalse(op('needla', 'needla'))
- self.assertFalse(op('1', None), 'Passed ncontains with None as criteria_pattern.')
+ op = operators.get_operator("ncontains")
+ self.assertFalse(op("hasystack needle haystack", "needle"))
+ self.assertFalse(op("needla", "needla"))
+ self.assertFalse(
+ op("1", None), "Passed ncontains with None as criteria_pattern."
+ )
def test_incontains(self):
- op = operators.get_operator('incontains')
- self.assertTrue(op('hasystack needle haystack', 'FOO'))
- self.assertTrue(op('needle', 'FOO'))
- self.assertTrue(op('needlehaystack', 'needlex'))
- self.assertTrue(op('needle haystack', 'needlex'))
- self.assertTrue(op('haystackneedle', 'needlex'))
- self.assertTrue(op('haystack needle', 'needlex'))
+ op = operators.get_operator("incontains")
+ self.assertTrue(op("hasystack needle haystack", "FOO"))
+ self.assertTrue(op("needle", "FOO"))
+ self.assertTrue(op("needlehaystack", "needlex"))
+ self.assertTrue(op("needle haystack", "needlex"))
+ self.assertTrue(op("haystackneedle", "needlex"))
+ self.assertTrue(op("haystack needle", "needlex"))
def test_incontains_fail(self):
- op = operators.get_operator('incontains')
- self.assertFalse(op('hasystack needle haystack', 'nEeDle'))
- self.assertFalse(op('needlA', 'needlA'))
- self.assertFalse(op('1', None), 'Passed incontains with None as criteria_pattern.')
+ op = operators.get_operator("incontains")
+ self.assertFalse(op("hasystack needle haystack", "nEeDle"))
+ self.assertFalse(op("needlA", "needlA"))
+ self.assertFalse(
+ op("1", None), "Passed incontains with None as criteria_pattern."
+ )
def test_startswith(self):
- op = operators.get_operator('startswith')
- self.assertTrue(op('hasystack needle haystack', 'hasystack'))
- self.assertTrue(op('a hasystack needle haystack', 'a '))
+ op = operators.get_operator("startswith")
+ self.assertTrue(op("hasystack needle haystack", "hasystack"))
+ self.assertTrue(op("a hasystack needle haystack", "a "))
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'haystack needle', 'haystack'))
- self.assertTrue(op('haystack needle', b'haystack'))
- self.assertTrue(op(b'haystack needle', b'haystack'))
- self.assertTrue(op(b'haystack needle', u'haystack'))
- self.assertTrue(op(u'haystack needle', b'haystack'))
+ self.assertTrue(op(b"haystack needle", "haystack"))
+ self.assertTrue(op("haystack needle", b"haystack"))
+ self.assertTrue(op(b"haystack needle", b"haystack"))
+ self.assertTrue(op(b"haystack needle", "haystack"))
+ self.assertTrue(op("haystack needle", b"haystack"))
def test_startswith_fail(self):
- op = operators.get_operator('startswith')
- self.assertFalse(op('hasystack needle haystack', 'needle'))
- self.assertFalse(op('a hasystack needle haystack', 'haystack'))
- self.assertFalse(op('1', None), 'Passed startswith with None as criteria_pattern.')
+ op = operators.get_operator("startswith")
+ self.assertFalse(op("hasystack needle haystack", "needle"))
+ self.assertFalse(op("a hasystack needle haystack", "haystack"))
+ self.assertFalse(
+ op("1", None), "Passed startswith with None as criteria_pattern."
+ )
def test_istartswith(self):
- op = operators.get_operator('istartswith')
- self.assertTrue(op('haystack needle haystack', 'HAYstack'))
- self.assertTrue(op('HAYSTACK needle haystack', 'haystack'))
+ op = operators.get_operator("istartswith")
+ self.assertTrue(op("haystack needle haystack", "HAYstack"))
+ self.assertTrue(op("HAYSTACK needle haystack", "haystack"))
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'HAYSTACK needle haystack', 'haystack'))
- self.assertTrue(op('HAYSTACK needle haystack', b'haystack'))
- self.assertTrue(op(b'HAYSTACK needle haystack', b'haystack'))
- self.assertTrue(op(b'HAYSTACK needle haystack', u'haystack'))
- self.assertTrue(op(u'HAYSTACK needle haystack', b'haystack'))
+ self.assertTrue(op(b"HAYSTACK needle haystack", "haystack"))
+ self.assertTrue(op("HAYSTACK needle haystack", b"haystack"))
+ self.assertTrue(op(b"HAYSTACK needle haystack", b"haystack"))
+ self.assertTrue(op(b"HAYSTACK needle haystack", "haystack"))
+ self.assertTrue(op("HAYSTACK needle haystack", b"haystack"))
def test_istartswith_fail(self):
- op = operators.get_operator('istartswith')
- self.assertFalse(op('hasystack needle haystack', 'NEEDLE'))
- self.assertFalse(op('a hasystack needle haystack', 'haystack'))
- self.assertFalse(op('1', None), 'Passed istartswith with None as criteria_pattern.')
+ op = operators.get_operator("istartswith")
+ self.assertFalse(op("hasystack needle haystack", "NEEDLE"))
+ self.assertFalse(op("a hasystack needle haystack", "haystack"))
+ self.assertFalse(
+ op("1", None), "Passed istartswith with None as criteria_pattern."
+ )
def test_endswith(self):
- op = operators.get_operator('endswith')
- self.assertTrue(op('hasystack needle haystackend', 'haystackend'))
- self.assertTrue(op('a hasystack needle haystack b', 'b'))
+ op = operators.get_operator("endswith")
+ self.assertTrue(op("hasystack needle haystackend", "haystackend"))
+ self.assertTrue(op("a hasystack needle haystack b", "b"))
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'a hasystack needle haystack b', 'b'))
- self.assertTrue(op('a hasystack needle haystack b', b'b'))
- self.assertTrue(op(b'a hasystack needle haystack b', b'b'))
- self.assertTrue(op(b'a hasystack needle haystack b', u'b'))
- self.assertTrue(op(u'a hasystack needle haystack b', b'b'))
+ self.assertTrue(op(b"a hasystack needle haystack b", "b"))
+ self.assertTrue(op("a hasystack needle haystack b", b"b"))
+ self.assertTrue(op(b"a hasystack needle haystack b", b"b"))
+ self.assertTrue(op(b"a hasystack needle haystack b", "b"))
+ self.assertTrue(op("a hasystack needle haystack b", b"b"))
def test_endswith_fail(self):
- op = operators.get_operator('endswith')
- self.assertFalse(op('hasystack needle haystackend', 'haystack'))
- self.assertFalse(op('a hasystack needle haystack', 'a'))
- self.assertFalse(op('1', None), 'Passed endswith with None as criteria_pattern.')
+ op = operators.get_operator("endswith")
+ self.assertFalse(op("hasystack needle haystackend", "haystack"))
+ self.assertFalse(op("a hasystack needle haystack", "a"))
+ self.assertFalse(
+ op("1", None), "Passed endswith with None as criteria_pattern."
+ )
def test_iendswith(self):
- op = operators.get_operator('iendswith')
- self.assertTrue(op('haystack needle haystackEND', 'HAYstackend'))
- self.assertTrue(op('HAYSTACK needle haystackend', 'haystackEND'))
+ op = operators.get_operator("iendswith")
+ self.assertTrue(op("haystack needle haystackEND", "HAYstackend"))
+ self.assertTrue(op("HAYSTACK needle haystackend", "haystackEND"))
def test_iendswith_fail(self):
- op = operators.get_operator('iendswith')
- self.assertFalse(op('hasystack needle haystack', 'NEEDLE'))
- self.assertFalse(op('a hasystack needle haystack', 'a '))
- self.assertFalse(op('1', None), 'Passed iendswith with None as criteria_pattern.')
+ op = operators.get_operator("iendswith")
+ self.assertFalse(op("hasystack needle haystack", "NEEDLE"))
+ self.assertFalse(op("a hasystack needle haystack", "a "))
+ self.assertFalse(
+ op("1", None), "Passed iendswith with None as criteria_pattern."
+ )
def test_lt(self):
- op = operators.get_operator('lessthan')
- self.assertTrue(op(1, 2), 'Failed lessthan.')
+ op = operators.get_operator("lessthan")
+ self.assertTrue(op(1, 2), "Failed lessthan.")
def test_lt_char(self):
- op = operators.get_operator('lessthan')
- self.assertTrue(op('a', 'b'), 'Failed lessthan.')
+ op = operators.get_operator("lessthan")
+ self.assertTrue(op("a", "b"), "Failed lessthan.")
def test_lt_fail(self):
- op = operators.get_operator('lessthan')
- self.assertFalse(op(1, 1), 'Passed lessthan.')
- self.assertFalse(op('1', None), 'Passed lessthan with None as criteria_pattern.')
+ op = operators.get_operator("lessthan")
+ self.assertFalse(op(1, 1), "Passed lessthan.")
+ self.assertFalse(
+ op("1", None), "Passed lessthan with None as criteria_pattern."
+ )
def test_gt(self):
- op = operators.get_operator('greaterthan')
- self.assertTrue(op(2, 1), 'Failed greaterthan.')
+ op = operators.get_operator("greaterthan")
+ self.assertTrue(op(2, 1), "Failed greaterthan.")
def test_gt_str(self):
- op = operators.get_operator('lessthan')
- self.assertTrue(op('aba', 'bcb'), 'Failed greaterthan.')
+ op = operators.get_operator("greaterthan")
+ self.assertTrue(op("bcb", "aba"), "Failed greaterthan.")
def test_gt_fail(self):
- op = operators.get_operator('greaterthan')
- self.assertFalse(op(2, 3), 'Passed greaterthan.')
- self.assertFalse(op('1', None), 'Passed greaterthan with None as criteria_pattern.')
+ op = operators.get_operator("greaterthan")
+ self.assertFalse(op(2, 3), "Passed greaterthan.")
+ self.assertFalse(
+ op("1", None), "Passed greaterthan with None as criteria_pattern."
+ )
def test_timediff_lt(self):
- op = operators.get_operator('timediff_lt')
- self.assertTrue(op(date_utils.get_datetime_utc_now().isoformat(), 10),
- 'Failed test_timediff_lt.')
+ op = operators.get_operator("timediff_lt")
+ self.assertTrue(
+ op(date_utils.get_datetime_utc_now().isoformat(), 10),
+ "Failed test_timediff_lt.",
+ )
def test_timediff_lt_fail(self):
- op = operators.get_operator('timediff_lt')
- self.assertFalse(op('2014-07-01T00:01:01.000000', 10),
- 'Passed test_timediff_lt.')
- self.assertFalse(op('2014-07-01T00:01:01.000000', None),
- 'Passed test_timediff_lt with None as criteria_pattern.')
+ op = operators.get_operator("timediff_lt")
+ self.assertFalse(
+ op("2014-07-01T00:01:01.000000", 10), "Passed test_timediff_lt."
+ )
+ self.assertFalse(
+ op("2014-07-01T00:01:01.000000", None),
+ "Passed test_timediff_lt with None as criteria_pattern.",
+ )
def test_timediff_gt(self):
- op = operators.get_operator('timediff_gt')
- self.assertTrue(op('2014-07-01T00:01:01.000000', 1),
- 'Failed test_timediff_gt.')
+ op = operators.get_operator("timediff_gt")
+ self.assertTrue(op("2014-07-01T00:01:01.000000", 1), "Failed test_timediff_gt.")
def test_timediff_gt_fail(self):
- op = operators.get_operator('timediff_gt')
- self.assertFalse(op(date_utils.get_datetime_utc_now().isoformat(), 10),
- 'Passed test_timediff_gt.')
- self.assertFalse(op('2014-07-01T00:01:01.000000', None),
- 'Passed test_timediff_gt with None as criteria_pattern.')
+ op = operators.get_operator("timediff_gt")
+ self.assertFalse(
+ op(date_utils.get_datetime_utc_now().isoformat(), 10),
+ "Passed test_timediff_gt.",
+ )
+ self.assertFalse(
+ op("2014-07-01T00:01:01.000000", None),
+ "Passed test_timediff_gt with None as criteria_pattern.",
+ )
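The timediff operators compare the age of an ISO-8601 timestamp from the payload against a threshold in seconds. A rough sketch of that semantics, assuming UTC-aware parsing; the real operators go through st2common's date_utils and handle timezones more carefully:

from datetime import datetime, timezone


def _age_in_seconds(payload_ts):
    # Assumes a naive ISO-8601 timestamp that represents UTC.
    then = datetime.fromisoformat(payload_ts).replace(tzinfo=timezone.utc)
    return (datetime.now(timezone.utc) - then).total_seconds()


def timediff_lt(payload_ts, seconds):
    # None as criteria_pattern never matches, as the _fail tests require.
    return seconds is not None and _age_in_seconds(payload_ts) < seconds


def timediff_gt(payload_ts, seconds):
    return seconds is not None and _age_in_seconds(payload_ts) > seconds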
def test_exists(self):
- op = operators.get_operator('exists')
- self.assertTrue(op(False, None), 'Should return True')
- self.assertTrue(op(1, None), 'Should return True')
- self.assertTrue(op('foo', None), 'Should return True')
- self.assertFalse(op(None, None), 'Should return False')
+ op = operators.get_operator("exists")
+ self.assertTrue(op(False, None), "Should return True")
+ self.assertTrue(op(1, None), "Should return True")
+ self.assertTrue(op("foo", None), "Should return True")
+ self.assertFalse(op(None, None), "Should return False")
def test_nexists(self):
- op = operators.get_operator('nexists')
- self.assertFalse(op(False, None), 'Should return False')
- self.assertFalse(op(1, None), 'Should return False')
- self.assertFalse(op('foo', None), 'Should return False')
- self.assertTrue(op(None, None), 'Should return True')
+ op = operators.get_operator("nexists")
+ self.assertFalse(op(False, None), "Should return False")
+ self.assertFalse(op(1, None), "Should return False")
+ self.assertFalse(op("foo", None), "Should return False")
+ self.assertTrue(op(None, None), "Should return True")
def test_inside(self):
- op = operators.get_operator('inside')
- self.assertFalse(op('a', None), 'Should return False')
- self.assertFalse(op('a', 'bcd'), 'Should return False')
- self.assertTrue(op('a', 'abc'), 'Should return True')
+ op = operators.get_operator("inside")
+ self.assertFalse(op("a", None), "Should return False")
+ self.assertFalse(op("a", "bcd"), "Should return False")
+ self.assertTrue(op("a", "abc"), "Should return True")
# Mixing bytes and strings / unicode should still work
- self.assertTrue(op(b'a', 'abc'), 'Should return True')
- self.assertTrue(op('a', b'abc'), 'Should return True')
- self.assertTrue(op(b'a', b'abc'), 'Should return True')
+ self.assertTrue(op(b"a", "abc"), "Should return True")
+ self.assertTrue(op("a", b"abc"), "Should return True")
+ self.assertTrue(op(b"a", b"abc"), "Should return True")
def test_ninside(self):
- op = operators.get_operator('ninside')
- self.assertFalse(op('a', None), 'Should return False')
- self.assertFalse(op('a', 'abc'), 'Should return False')
- self.assertTrue(op('a', 'bcd'), 'Should return True')
+ op = operators.get_operator("ninside")
+ self.assertFalse(op("a", None), "Should return False")
+ self.assertFalse(op("a", "abc"), "Should return False")
+ self.assertTrue(op("a", "bcd"), "Should return True")
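The recurring "Mixing bytes and strings / unicode should still work" blocks in this class all rely on the operators coercing both operands to text before comparing. A minimal sketch of that coercion step, shown here for contains (illustrative, not st2common's actual helper):

def ensure_text(value):
    # Decode bytes so b"needle" and "needle" compare equal.
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value


def contains(value, criteria_pattern):
    # None as criteria_pattern never matches, matching the _fail tests.
    if criteria_pattern is None:
        return False
    return ensure_text(criteria_pattern) in ensure_text(value)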
class GetOperatorsTest(unittest2.TestCase):
def test_get_operator(self):
- self.assertTrue(operators.get_operator('equals'))
- self.assertTrue(operators.get_operator('EQUALS'))
+ self.assertTrue(operators.get_operator("equals"))
+ self.assertTrue(operators.get_operator("EQUALS"))
def test_get_operator_returns_same_operator_with_different_cases(self):
- equals = operators.get_operator('equals')
- EQUALS = operators.get_operator('EQUALS')
- Equals = operators.get_operator('Equals')
+ equals = operators.get_operator("equals")
+ EQUALS = operators.get_operator("EQUALS")
+ Equals = operators.get_operator("Equals")
self.assertEqual(equals, EQUALS)
self.assertEqual(equals, Equals)
def test_get_operator_with_nonexistent_operator(self):
with self.assertRaises(Exception):
- operators.get_operator('weird')
+ operators.get_operator("weird")
def test_get_allowed_operators(self):
# This test will need to change as operators are deprecated
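GetOperatorsTest pins down one more behaviour worth calling out: operator lookup is case-insensitive and returns the identical callable for every casing. A sketch of a registry with that property (names and the exception type are illustrative):

_OPERATORS = {
    "EQUALS": lambda value, pattern: pattern is not None and value == pattern,
    # ... the remaining operators are registered the same way
}


def get_operator(op):
    # Upper-case the lookup key so "equals", "EQUALS", and "Equals" all
    # resolve to the very same function object.
    key = op.upper()
    if key not in _OPERATORS:
        raise Exception("Invalid operator: %s" % op)
    return _OPERATORS[key]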
diff --git a/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py b/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py
index 355e680c85..630b14ed37 100644
--- a/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py
+++ b/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py
@@ -23,111 +23,117 @@
from st2common.exceptions.content import ParseException
from st2common.models.db.actionalias import ActionAliasDB
-__all__ = [
- 'PackActionAliasUnitTestUtils'
-]
+__all__ = ["PackActionAliasUnitTestUtils"]
-PACK_PATH_1 = os.path.join(get_fixtures_base_path(), 'packs/pack_dir_name_doesnt_match_ref')
+PACK_PATH_1 = os.path.join(
+ get_fixtures_base_path(), "packs/pack_dir_name_doesnt_match_ref"
+)
class PackActionAliasUnitTestUtils(BaseActionAliasTestCase):
- action_alias_name = 'mock'
+ action_alias_name = "mock"
mock_get_action_alias_db_by_name = True
def test_assertExtractedParametersMatch_success(self):
format_string = self.action_alias_db.formats[0]
- command = 'show last 3 metrics for my.host'
- expected_parameters = {
- 'count': '3',
- 'server': 'my.host'
- }
- self.assertExtractedParametersMatch(format_string=format_string,
- command=command,
- parameters=expected_parameters)
+ command = "show last 3 metrics for my.host"
+ expected_parameters = {"count": "3", "server": "my.host"}
+ self.assertExtractedParametersMatch(
+ format_string=format_string, command=command, parameters=expected_parameters
+ )
format_string = self.action_alias_db.formats[0]
- command = 'show last 10 metrics for my.host.example'
- expected_parameters = {
- 'count': '10',
- 'server': 'my.host.example'
- }
- self.assertExtractedParametersMatch(format_string=format_string,
- command=command,
- parameters=expected_parameters)
+ command = "show last 10 metrics for my.host.example"
+ expected_parameters = {"count": "10", "server": "my.host.example"}
+ self.assertExtractedParametersMatch(
+ format_string=format_string, command=command, parameters=expected_parameters
+ )
def test_assertExtractedParametersMatch_command_doesnt_match_format_string(self):
format_string = self.action_alias_db.formats[0]
- command = 'show last foo'
+ command = "show last foo"
expected_parameters = {}
- expected_msg = ('Command "show last foo" doesn\'t match format string '
- '"show last {{count}} metrics for {{server}}"')
-
- self.assertRaisesRegexp(ParseException, expected_msg,
- self.assertExtractedParametersMatch,
- format_string=format_string,
- command=command,
- parameters=expected_parameters)
+ expected_msg = (
+ 'Command "show last foo" doesn\'t match format string '
+ '"show last {{count}} metrics for {{server}}"'
+ )
+
+ self.assertRaisesRegexp(
+ ParseException,
+ expected_msg,
+ self.assertExtractedParametersMatch,
+ format_string=format_string,
+ command=command,
+ parameters=expected_parameters,
+ )
def test_assertCommandMatchesExactlyOneFormatString(self):
# Matches single format string
- format_strings = [
- 'foo bar {{bar}}',
- 'foo bar {{baz}} baz'
- ]
- command = 'foo bar a test=1'
- self.assertCommandMatchesExactlyOneFormatString(format_strings=format_strings,
- command=command)
+ format_strings = ["foo bar {{bar}}", "foo bar {{baz}} baz"]
+ command = "foo bar a test=1"
+ self.assertCommandMatchesExactlyOneFormatString(
+ format_strings=format_strings, command=command
+ )
# Matches multiple format strings
- format_strings = [
- 'foo bar {{bar}}',
- 'foo bar {{baz}}'
- ]
- command = 'foo bar a test=1'
-
- expected_msg = ('Command "foo bar a test=1" matched multiple format '
- 'strings: foo bar {{bar}}, foo bar {{baz}}')
- self.assertRaisesRegexp(AssertionError, expected_msg,
- self.assertCommandMatchesExactlyOneFormatString,
- format_strings=format_strings,
- command=command)
+ format_strings = ["foo bar {{bar}}", "foo bar {{baz}}"]
+ command = "foo bar a test=1"
+
+ expected_msg = (
+ 'Command "foo bar a test=1" matched multiple format '
+ "strings: foo bar {{bar}}, foo bar {{baz}}"
+ )
+ self.assertRaisesRegexp(
+ AssertionError,
+ expected_msg,
+ self.assertCommandMatchesExactlyOneFormatString,
+ format_strings=format_strings,
+ command=command,
+ )
         # Doesn't match any format strings
- format_strings = [
- 'foo bar {{bar}}',
- 'foo bar {{baz}}'
- ]
- command = 'does not match foo'
-
- expected_msg = ('Command "does not match foo" didn\'t match any of the provided format '
- 'strings')
- self.assertRaisesRegexp(AssertionError, expected_msg,
- self.assertCommandMatchesExactlyOneFormatString,
- format_strings=format_strings,
- command=command)
-
- @mock.patch.object(BaseActionAliasTestCase, '_get_base_pack_path',
- mock.Mock(return_value=PACK_PATH_1))
+ format_strings = ["foo bar {{bar}}", "foo bar {{baz}}"]
+ command = "does not match foo"
+
+ expected_msg = (
+ 'Command "does not match foo" didn\'t match any of the provided format '
+ "strings"
+ )
+ self.assertRaisesRegexp(
+ AssertionError,
+ expected_msg,
+ self.assertCommandMatchesExactlyOneFormatString,
+ format_strings=format_strings,
+ command=command,
+ )
+
+ @mock.patch.object(
+ BaseActionAliasTestCase,
+ "_get_base_pack_path",
+ mock.Mock(return_value=PACK_PATH_1),
+ )
def test_base_class_works_when_pack_directory_name_doesnt_match_pack_name(self):
         # Verify that the alias can still be successfully loaded from disk even if the pack directory
# name doesn't match "pack" resource attribute (aka pack ref)
self.mock_get_action_alias_db_by_name = False
- action_alias_db = self._get_action_alias_db_by_name(name='alias1')
- self.assertEqual(action_alias_db.name, 'alias1')
- self.assertEqual(action_alias_db.pack, 'pack_name_not_the_same_as_dir_name')
+ action_alias_db = self._get_action_alias_db_by_name(name="alias1")
+ self.assertEqual(action_alias_db.name, "alias1")
+ self.assertEqual(action_alias_db.pack, "pack_name_not_the_same_as_dir_name")
# Note: We mock the original method to make testing of all the edge cases easier
def _get_action_alias_db_by_name(self, name):
if not self.mock_get_action_alias_db_by_name:
- return super(PackActionAliasUnitTestUtils, self)._get_action_alias_db_by_name(name)
+ return super(
+ PackActionAliasUnitTestUtils, self
+ )._get_action_alias_db_by_name(name)
values = {
- 'name': self.action_alias_name,
- 'pack': 'mock',
- 'formats': [
- 'show last {{count}} metrics for {{server}}',
- ]
+ "name": self.action_alias_name,
+ "pack": "mock",
+ "formats": [
+ "show last {{count}} metrics for {{server}}",
+ ],
}
action_alias_db = ActionAliasDB(**values)
return action_alias_db
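
The alias tests above all revolve around matching a command such as "show last 3
metrics for my.host" against a format string such as "show last {{count}} metrics for
{{server}}" and extracting {"count": "3", "server": "my.host"}. A rough sketch of such
extraction, assuming a regex translation of the {{param}} placeholders (st2's real
parser handles more syntax than this):

    import re

    def extract_parameters(format_string, command):
        # Turn each {{param}} placeholder into a named group and escape the
        # literal text between placeholders.
        pattern, last = "", 0
        for m in re.finditer(r"\{\{\s*(\w+)\s*\}\}", format_string):
            pattern += re.escape(format_string[last : m.start()])
            pattern += "(?P<%s>.+?)" % m.group(1)
            last = m.end()
        pattern += re.escape(format_string[last:]) + "$"
        match = re.match(pattern, command)
        if not match:
            raise ValueError(
                'Command "%s" doesn\'t match format string "%s"'
                % (command, format_string)
            )
        return match.groupdict()

    # extract_parameters("show last {{count}} metrics for {{server}}",
    #                    "show last 3 metrics for my.host")
    # -> {"count": "3", "server": "my.host"}
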
diff --git a/st2common/tests/unit/test_pack_management.py b/st2common/tests/unit/test_pack_management.py
index abc0498489..b350c7d98f 100644
--- a/st2common/tests/unit/test_pack_management.py
+++ b/st2common/tests/unit/test_pack_management.py
@@ -21,37 +21,35 @@
import unittest2
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-PACK_ACTIONS_DIR = os.path.join(BASE_DIR, '../../../contrib/packs/actions')
+PACK_ACTIONS_DIR = os.path.join(BASE_DIR, "../../../contrib/packs/actions")
PACK_ACTIONS_DIR = os.path.abspath(PACK_ACTIONS_DIR)
sys.path.insert(0, PACK_ACTIONS_DIR)
from st2common.util.monkey_patch import use_select_poll_workaround
+
use_select_poll_workaround()
from st2common.util.pack_management import eval_repo_url
-__all__ = [
- 'InstallPackTestCase'
-]
+__all__ = ["InstallPackTestCase"]
class InstallPackTestCase(unittest2.TestCase):
-
def test_eval_repo(self):
- result = eval_repo_url('stackstorm/st2contrib')
- self.assertEqual(result, 'https://github.com/stackstorm/st2contrib')
+ result = eval_repo_url("stackstorm/st2contrib")
+ self.assertEqual(result, "https://github.com/stackstorm/st2contrib")
- result = eval_repo_url('git@github.com:StackStorm/st2contrib.git')
- self.assertEqual(result, 'git@github.com:StackStorm/st2contrib.git')
+ result = eval_repo_url("git@github.com:StackStorm/st2contrib.git")
+ self.assertEqual(result, "git@github.com:StackStorm/st2contrib.git")
- result = eval_repo_url('gitlab@gitlab.com:StackStorm/st2contrib.git')
- self.assertEqual(result, 'gitlab@gitlab.com:StackStorm/st2contrib.git')
+ result = eval_repo_url("gitlab@gitlab.com:StackStorm/st2contrib.git")
+ self.assertEqual(result, "gitlab@gitlab.com:StackStorm/st2contrib.git")
- repo_url = 'https://github.com/StackStorm/st2contrib.git'
+ repo_url = "https://github.com/StackStorm/st2contrib.git"
result = eval_repo_url(repo_url)
self.assertEqual(result, repo_url)
- repo_url = 'https://git-wip-us.apache.org/repos/asf/libcloud.git'
+ repo_url = "https://git-wip-us.apache.org/repos/asf/libcloud.git"
result = eval_repo_url(repo_url)
self.assertEqual(result, repo_url)
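
The eval_repo_url assertions above capture the intended behaviour: a bare "org/repo"
shorthand expands to a GitHub HTTPS URL, while full URLs and SSH-style remotes pass
through unchanged. A minimal sketch consistent with those cases (illustrative, not the
actual st2common implementation):

    def eval_repo_url(repo):
        # SSH remotes contain "@" and full URLs contain "://"; only the bare
        # "org/repo" shorthand gets rewritten.
        if "://" not in repo and "@" not in repo and repo.count("/") == 1:
            return "https://github.com/%s" % repo
        return repo
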
diff --git a/st2common/tests/unit/test_param_utils.py b/st2common/tests/unit/test_param_utils.py
index 695d17f448..c2e5810815 100644
--- a/st2common/tests/unit/test_param_utils.py
+++ b/st2common/tests/unit/test_param_utils.py
@@ -36,30 +36,31 @@
from st2tests.fixturesloader import FixturesLoader
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
TEST_MODELS = {
- 'actions': ['action_4_action_context_param.yaml', 'action_system_default.yaml'],
- 'runners': ['testrunner1.yaml']
+ "actions": ["action_4_action_context_param.yaml", "action_system_default.yaml"],
+ "runners": ["testrunner1.yaml"],
}
-FIXTURES = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS)
+FIXTURES = FixturesLoader().load_models(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS
+)
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class ParamsUtilsTest(DbTestCase):
- action_db = FIXTURES['actions']['action_4_action_context_param.yaml']
- action_system_default_db = FIXTURES['actions']['action_system_default.yaml']
- runnertype_db = FIXTURES['runners']['testrunner1.yaml']
+ action_db = FIXTURES["actions"]["action_4_action_context_param.yaml"]
+ action_system_default_db = FIXTURES["actions"]["action_system_default.yaml"]
+ runnertype_db = FIXTURES["runners"]["testrunner1.yaml"]
def test_get_finalized_params(self):
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555,
- 'runnerimmutable': 'failed_override',
- 'actionimmutable': 'failed_override'
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
+ "runnerimmutable": "failed_override",
+ "actionimmutable": "failed_override",
}
liveaction_db = self._get_liveaction_model(params)
@@ -67,289 +68,320 @@ def test_get_finalized_params(self):
ParamsUtilsTest.runnertype_db.runner_parameters,
ParamsUtilsTest.action_db.parameters,
liveaction_db.parameters,
- liveaction_db.context)
+ liveaction_db.context,
+ )
# Asserts for runner params.
# Assert that default values for runner params are resolved.
- self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo')
+ self.assertEqual(runner_params.get("runnerstr"), "defaultfoo")
# Assert that a runner param from action exec is picked up.
- self.assertEqual(runner_params.get('runnerint'), 555)
+ self.assertEqual(runner_params.get("runnerint"), 555)
# Assert that a runner param can be overridden by action param default.
- self.assertEqual(runner_params.get('runnerdummy'), 'actiondummy')
+ self.assertEqual(runner_params.get("runnerdummy"), "actiondummy")
# Assert that a runner param default can be overridden by 'falsey' action param default,
# (timeout: 0 case).
- self.assertEqual(runner_params.get('runnerdefaultint'), 0)
+ self.assertEqual(runner_params.get("runnerdefaultint"), 0)
# Assert that an immutable param cannot be overridden by action param or execution param.
- self.assertEqual(runner_params.get('runnerimmutable'), 'runnerimmutable')
+ self.assertEqual(runner_params.get("runnerimmutable"), "runnerimmutable")
# Asserts for action params.
- self.assertEqual(action_params.get('actionstr'), 'foo')
+ self.assertEqual(action_params.get("actionstr"), "foo")
# Assert that a param that is provided in action exec that isn't in action or runner params
# isn't in resolved params.
- self.assertEqual(action_params.get('some_key_that_aint_exist_in_action_or_runner'), None)
+ self.assertEqual(
+ action_params.get("some_key_that_aint_exist_in_action_or_runner"), None
+ )
# Assert that an immutable param cannot be overridden by execution param.
- self.assertEqual(action_params.get('actionimmutable'), 'actionimmutable')
+ self.assertEqual(action_params.get("actionimmutable"), "actionimmutable")
# Assert that an action context param is set correctly.
- self.assertEqual(action_params.get('action_api_user'), 'noob')
+ self.assertEqual(action_params.get("action_api_user"), "noob")
# Assert that none of runner params are present in action_params.
for k in action_params:
- self.assertNotIn(k, runner_params, 'Param ' + k + ' is a runner param.')
+ self.assertNotIn(k, runner_params, "Param " + k + " is a runner param.")
def test_get_finalized_params_system_values(self):
- KeyValuePair.add_or_update(KeyValuePairDB(name='actionstr', value='foo'))
- KeyValuePair.add_or_update(KeyValuePairDB(name='actionnumber', value='1.0'))
- params = {
- 'runnerint': 555
- }
+ KeyValuePair.add_or_update(KeyValuePairDB(name="actionstr", value="foo"))
+ KeyValuePair.add_or_update(KeyValuePairDB(name="actionnumber", value="1.0"))
+ params = {"runnerint": 555}
liveaction_db = self._get_liveaction_model(params)
runner_params, action_params = param_utils.get_finalized_params(
ParamsUtilsTest.runnertype_db.runner_parameters,
ParamsUtilsTest.action_system_default_db.parameters,
liveaction_db.parameters,
- liveaction_db.context)
+ liveaction_db.context,
+ )
# Asserts for runner params.
# Assert that default values for runner params are resolved.
- self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo')
+ self.assertEqual(runner_params.get("runnerstr"), "defaultfoo")
# Assert that a runner param from action exec is picked up.
- self.assertEqual(runner_params.get('runnerint'), 555)
+ self.assertEqual(runner_params.get("runnerint"), 555)
# Assert that an immutable param cannot be overridden by action param or execution param.
- self.assertEqual(runner_params.get('runnerimmutable'), 'runnerimmutable')
+ self.assertEqual(runner_params.get("runnerimmutable"), "runnerimmutable")
# Asserts for action params.
- self.assertEqual(action_params.get('actionstr'), 'foo')
- self.assertEqual(action_params.get('actionnumber'), 1.0)
+ self.assertEqual(action_params.get("actionstr"), "foo")
+ self.assertEqual(action_params.get("actionnumber"), 1.0)
def test_get_finalized_params_action_immutable(self):
params = {
- 'actionstr': 'foo',
- 'some_key_that_aint_exist_in_action_or_runner': 'bar',
- 'runnerint': 555,
- 'actionimmutable': 'failed_override'
+ "actionstr": "foo",
+ "some_key_that_aint_exist_in_action_or_runner": "bar",
+ "runnerint": 555,
+ "actionimmutable": "failed_override",
}
liveaction_db = self._get_liveaction_model(params)
- action_context = {'api_user': None}
+ action_context = {"api_user": None}
runner_params, action_params = param_utils.get_finalized_params(
ParamsUtilsTest.runnertype_db.runner_parameters,
ParamsUtilsTest.action_db.parameters,
liveaction_db.parameters,
- action_context)
+ action_context,
+ )
# Asserts for runner params.
# Assert that default values for runner params are resolved.
- self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo')
+ self.assertEqual(runner_params.get("runnerstr"), "defaultfoo")
# Assert that a runner param from action exec is picked up.
- self.assertEqual(runner_params.get('runnerint'), 555)
+ self.assertEqual(runner_params.get("runnerint"), 555)
# Assert that a runner param can be overridden by action param default.
- self.assertEqual(runner_params.get('runnerdummy'), 'actiondummy')
+ self.assertEqual(runner_params.get("runnerdummy"), "actiondummy")
# Asserts for action params.
- self.assertEqual(action_params.get('actionstr'), 'foo')
+ self.assertEqual(action_params.get("actionstr"), "foo")
# Assert that a param that is provided in action exec that isn't in action or runner params
# isn't in resolved params.
- self.assertEqual(action_params.get('some_key_that_aint_exist_in_action_or_runner'), None)
+ self.assertEqual(
+ action_params.get("some_key_that_aint_exist_in_action_or_runner"), None
+ )
def test_get_finalized_params_empty(self):
params = {}
runner_param_info = {}
action_param_info = {}
- action_context = {'user': None}
+ action_context = {"user": None}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
+ runner_param_info, action_param_info, params, action_context
+ )
self.assertEqual(r_runner_params, params)
self.assertEqual(r_action_params, params)
def test_get_finalized_params_none(self):
- params = {
- 'r1': None,
- 'a1': None
- }
- runner_param_info = {'r1': {}}
- action_param_info = {'a1': {}}
- action_context = {'api_user': None}
+ params = {"r1": None, "a1": None}
+ runner_param_info = {"r1": {}}
+ action_param_info = {"a1": {}}
+ action_context = {"api_user": None}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': None})
- self.assertEqual(r_action_params, {'a1': None})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": None})
+ self.assertEqual(r_action_params, {"a1": None})
def test_get_finalized_params_no_cast(self):
params = {
- 'r1': '{{r2}}',
- 'r2': 1,
- 'a1': True,
- 'a2': '{{r1}} {{a1}}',
- 'a3': '{{action_context.api_user}}'
- }
- runner_param_info = {'r1': {}, 'r2': {}}
- action_param_info = {'a1': {}, 'a2': {}, 'a3': {}}
- action_context = {'api_user': 'noob'}
+ "r1": "{{r2}}",
+ "r2": 1,
+ "a1": True,
+ "a2": "{{r1}} {{a1}}",
+ "a3": "{{action_context.api_user}}",
+ }
+ runner_param_info = {"r1": {}, "r2": {}}
+ action_param_info = {"a1": {}, "a2": {}, "a3": {}}
+ action_context = {"api_user": "noob"}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': u'1', 'r2': 1})
- self.assertEqual(r_action_params, {'a1': True, 'a2': u'1 True', 'a3': 'noob'})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": "1", "r2": 1})
+ self.assertEqual(r_action_params, {"a1": True, "a2": "1 True", "a3": "noob"})
def test_get_finalized_params_with_cast(self):
# Note : In this test runner_params.r1 has a string value. However per runner_param_info the
# type is an integer. The expected type is considered and cast is performed accordingly.
params = {
- 'r1': '{{r2}}',
- 'r2': 1,
- 'a1': True,
- 'a2': '{{a1}}',
- 'a3': '{{action_context.api_user}}'
+ "r1": "{{r2}}",
+ "r2": 1,
+ "a1": True,
+ "a2": "{{a1}}",
+ "a3": "{{action_context.api_user}}",
}
- runner_param_info = {'r1': {'type': 'integer'}, 'r2': {'type': 'integer'}}
- action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'boolean'}, 'a3': {}}
- action_context = {'api_user': 'noob'}
+ runner_param_info = {"r1": {"type": "integer"}, "r2": {"type": "integer"}}
+ action_param_info = {
+ "a1": {"type": "boolean"},
+ "a2": {"type": "boolean"},
+ "a3": {},
+ }
+ action_context = {"api_user": "noob"}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': 1, 'r2': 1})
- self.assertEqual(r_action_params, {'a1': True, 'a2': True, 'a3': 'noob'})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": 1, "r2": 1})
+ self.assertEqual(r_action_params, {"a1": True, "a2": True, "a3": "noob"})
def test_get_finalized_params_with_cast_overriden(self):
params = {
- 'r1': '{{r2}}',
- 'r2': 1,
- 'a1': '{{r1}}',
- 'a2': '{{r1}}',
- 'a3': '{{r1}}'
+ "r1": "{{r2}}",
+ "r2": 1,
+ "a1": "{{r1}}",
+ "a2": "{{r1}}",
+ "a3": "{{r1}}",
}
- runner_param_info = {'r1': {'type': 'integer'}, 'r2': {'type': 'integer'}}
- action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'string'},
- 'a3': {'type': 'integer'}, 'r1': {'type': 'string'}}
- action_context = {'api_user': 'noob'}
+ runner_param_info = {"r1": {"type": "integer"}, "r2": {"type": "integer"}}
+ action_param_info = {
+ "a1": {"type": "boolean"},
+ "a2": {"type": "string"},
+ "a3": {"type": "integer"},
+ "r1": {"type": "string"},
+ }
+ action_context = {"api_user": "noob"}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': 1, 'r2': 1})
- self.assertEqual(r_action_params, {'a1': 1, 'a2': u'1', 'a3': 1})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": 1, "r2": 1})
+ self.assertEqual(r_action_params, {"a1": 1, "a2": "1", "a3": 1})
def test_get_finalized_params_cross_talk_no_cast(self):
params = {
- 'r1': '{{a1}}',
- 'r2': 1,
- 'a1': True,
- 'a2': '{{r1}} {{a1}}',
- 'a3': '{{action_context.api_user}}'
- }
- runner_param_info = {'r1': {}, 'r2': {}}
- action_param_info = {'a1': {}, 'a2': {}, 'a3': {}}
- action_context = {'api_user': 'noob'}
+ "r1": "{{a1}}",
+ "r2": 1,
+ "a1": True,
+ "a2": "{{r1}} {{a1}}",
+ "a3": "{{action_context.api_user}}",
+ }
+ runner_param_info = {"r1": {}, "r2": {}}
+ action_param_info = {"a1": {}, "a2": {}, "a3": {}}
+ action_context = {"api_user": "noob"}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': u'True', 'r2': 1})
- self.assertEqual(r_action_params, {'a1': True, 'a2': u'True True', 'a3': 'noob'})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": "True", "r2": 1})
+ self.assertEqual(r_action_params, {"a1": True, "a2": "True True", "a3": "noob"})
def test_get_finalized_params_cross_talk_with_cast(self):
params = {
- 'r1': '{{a1}}',
- 'r2': 1,
- 'r3': 1,
- 'a1': True,
- 'a2': '{{r1}},{{a1}},{{a3}},{{r3}}',
- 'a3': '{{a1}}'
+ "r1": "{{a1}}",
+ "r2": 1,
+ "r3": 1,
+ "a1": True,
+ "a2": "{{r1}},{{a1}},{{a3}},{{r3}}",
+ "a3": "{{a1}}",
}
- runner_param_info = {'r1': {'type': 'boolean'}, 'r2': {'type': 'integer'}, 'r3': {}}
- action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'array'}, 'a3': {}}
- action_context = {'user': None}
+ runner_param_info = {
+ "r1": {"type": "boolean"},
+ "r2": {"type": "integer"},
+ "r3": {},
+ }
+ action_param_info = {
+ "a1": {"type": "boolean"},
+ "a2": {"type": "array"},
+ "a3": {},
+ }
+ action_context = {"user": None}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': True, 'r2': 1, 'r3': 1})
- self.assertEqual(r_action_params, {'a1': True, 'a2': (True, True, True, 1), 'a3': u'True'})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": True, "r2": 1, "r3": 1})
+ self.assertEqual(
+ r_action_params, {"a1": True, "a2": (True, True, True, 1), "a3": "True"}
+ )
def test_get_finalized_params_order(self):
- params = {
- 'r1': 'p1',
- 'r2': 'p2',
- 'r3': 'p3',
- 'a1': 'p4',
- 'a2': 'p5'
- }
- runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {'default': 'r3'}}
- action_param_info = {'a1': {}, 'a2': {'default': 'a2'}, 'r3': {'default': 'a3'}}
- action_context = {'api_user': 'noob'}
+ params = {"r1": "p1", "r2": "p2", "r3": "p3", "a1": "p4", "a2": "p5"}
+ runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {"default": "r3"}}
+ action_param_info = {"a1": {}, "a2": {"default": "a2"}, "r3": {"default": "a3"}}
+ action_context = {"api_user": "noob"}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': u'p1', 'r2': u'p2', 'r3': u'p3'})
- self.assertEqual(r_action_params, {'a1': u'p4', 'a2': u'p5'})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": "p1", "r2": "p2", "r3": "p3"})
+ self.assertEqual(r_action_params, {"a1": "p4", "a2": "p5"})
params = {}
- runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {'default': 'r3'}}
- action_param_info = {'a1': {}, 'a2': {'default': 'a2'}, 'r3': {'default': 'a3'}}
- action_context = {'api_user': 'noob'}
+ runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {"default": "r3"}}
+ action_param_info = {"a1": {}, "a2": {"default": "a2"}, "r3": {"default": "a3"}}
+ action_context = {"api_user": "noob"}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': None, 'r2': u'r2', 'r3': u'a3'})
- self.assertEqual(r_action_params, {'a1': None, 'a2': u'a2'})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": None, "r2": "r2", "r3": "a3"})
+ self.assertEqual(r_action_params, {"a1": None, "a2": "a2"})
params = {}
- runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {}}
- action_param_info = {'r1': {}, 'r2': {}, 'r3': {'default': 'a3'}}
- action_context = {'api_user': 'noob'}
+ runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {}}
+ action_param_info = {"r1": {}, "r2": {}, "r3": {"default": "a3"}}
+ action_context = {"api_user": "noob"}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': None, 'r2': u'r2', 'r3': u'a3'})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": None, "r2": "r2", "r3": "a3"})
def test_get_finalized_params_non_existent_template_key_in_action_context(self):
params = {
- 'r1': 'foo',
- 'r2': 2,
- 'a1': 'i love tests',
- 'a2': '{{action_context.lorem_ipsum}}'
- }
- runner_param_info = {'r1': {'type': 'string'}, 'r2': {'type': 'integer'}}
- action_param_info = {'a1': {'type': 'string'}, 'a2': {'type': 'string'}}
- action_context = {'api_user': 'noob', 'source_channel': 'reddit'}
+ "r1": "foo",
+ "r2": 2,
+ "a1": "i love tests",
+ "a2": "{{action_context.lorem_ipsum}}",
+ }
+ runner_param_info = {"r1": {"type": "string"}, "r2": {"type": "integer"}}
+ action_param_info = {"a1": {"type": "string"}, "a2": {"type": "string"}}
+ action_context = {"api_user": "noob", "source_channel": "reddit"}
try:
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.fail('This should have thrown because we are trying to deref a key in ' +
- 'action context that ain\'t exist.')
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.fail(
+ "This should have thrown because we are trying to deref a key in "
+ + "action context that ain't exist."
+ )
except ParamException as e:
- error_msg = 'Failed to render parameter "a2": \'dict object\' ' + \
- 'has no attribute \'lorem_ipsum\''
+ error_msg = (
+ "Failed to render parameter \"a2\": 'dict object' "
+ + "has no attribute 'lorem_ipsum'"
+ )
self.assertIn(error_msg, six.text_type(e))
pass
def test_unicode_value_casting(self):
- rendered = {'a1': 'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2'}
- parameter_schemas = {'a1': {'type': 'string'}}
+ rendered = {"a1": "unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2"}
+ parameter_schemas = {"a1": {"type": "string"}}
- result = param_utils._cast_params(rendered=rendered,
- parameter_schemas=parameter_schemas)
+ result = param_utils._cast_params(
+ rendered=rendered, parameter_schemas=parameter_schemas
+ )
if six.PY3:
- expected = {
- 'a1': (u'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2')
- }
+ expected = {"a1": ("unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2")}
else:
expected = {
- 'a1': (u'unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc'
- u'\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2')
+ "a1": (
+ "unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc"
+ "\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2"
+ )
}
self.assertEqual(result, expected)
def test_get_finalized_params_with_casting_unicode_values(self):
- params = {'a1': 'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2'}
+ params = {"a1": "unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2"}
runner_param_info = {}
- action_param_info = {'a1': {'type': 'string'}}
+ action_param_info = {"a1": {"type": "string"}}
- action_context = {'user': None}
+ action_context = {"user": None}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
+ runner_param_info, action_param_info, params, action_context
+ )
if six.PY3:
- expected_action_params = {
- 'a1': (u'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2')
- }
+ expected_action_params = {"a1": ("unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2")}
else:
expected_action_params = {
- 'a1': (u'unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc'
- u'\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2')
+ "a1": (
+ "unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc"
+ "\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2"
+ )
}
self.assertEqual(r_runner_params, {})
@@ -359,59 +391,53 @@ def test_get_finalized_params_with_dict(self):
         # Note: In this test runner_params.r1 has a string value. However, per runner_param_info the
         # type is an object. The expected type is considered and cast is performed accordingly.
params = {
- 'r1': '{{r2}}',
- 'r2': {'r2.1': 1},
- 'a1': True,
- 'a2': '{{a1}}',
- 'a3': {
- 'test': '{{a1}}',
- 'test1': '{{a4}}',
- 'test2': '{{a5}}',
+ "r1": "{{r2}}",
+ "r2": {"r2.1": 1},
+ "a1": True,
+ "a2": "{{a1}}",
+ "a3": {
+ "test": "{{a1}}",
+ "test1": "{{a4}}",
+ "test2": "{{a5}}",
},
- 'a4': 3,
- 'a5': ['1', '{{a1}}']
+ "a4": 3,
+ "a5": ["1", "{{a1}}"],
}
- runner_param_info = {'r1': {'type': 'object'}, 'r2': {'type': 'object'}}
+ runner_param_info = {"r1": {"type": "object"}, "r2": {"type": "object"}}
action_param_info = {
- 'a1': {
- 'type': 'boolean',
+ "a1": {
+ "type": "boolean",
},
- 'a2': {
- 'type': 'boolean',
+ "a2": {
+ "type": "boolean",
},
- 'a3': {
- 'type': 'object',
+ "a3": {
+ "type": "object",
},
- 'a4': {
- 'type': 'integer',
+ "a4": {
+ "type": "integer",
},
- 'a5': {
- 'type': 'array',
+ "a5": {
+ "type": "array",
},
}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, {'user': None})
- self.assertEqual(
- r_runner_params, {'r1': {'r2.1': 1}, 'r2': {'r2.1': 1}})
+ runner_param_info, action_param_info, params, {"user": None}
+ )
+ self.assertEqual(r_runner_params, {"r1": {"r2.1": 1}, "r2": {"r2.1": 1}})
self.assertEqual(
r_action_params,
{
- 'a1': True,
- 'a2': True,
- 'a3': {
- 'test': True,
- 'test1': 3,
- 'test2': [
- '1',
- True
- ],
+ "a1": True,
+ "a2": True,
+ "a3": {
+ "test": True,
+ "test1": 3,
+ "test2": ["1", True],
},
- 'a4': 3,
- 'a5': [
- '1',
- True
- ],
- }
+ "a4": 3,
+ "a5": ["1", True],
+ },
)
def test_get_finalized_params_with_list(self):
@@ -419,183 +445,177 @@ def test_get_finalized_params_with_list(self):
         # type is an array. The expected type is considered and cast is performed accordingly.
self.maxDiff = None
params = {
- 'r1': '{{r2}}',
- 'r2': ['1', '2'],
- 'a1': True,
- 'a2': 'Test',
- 'a3': 'Test2',
- 'a4': '{{a1}}',
- 'a5': ['{{a2}}', '{{a3}}'],
- 'a6': [
- ['{{r2}}', '{{a2}}'],
- ['{{a3}}', '{{a1}}'],
+ "r1": "{{r2}}",
+ "r2": ["1", "2"],
+ "a1": True,
+ "a2": "Test",
+ "a3": "Test2",
+ "a4": "{{a1}}",
+ "a5": ["{{a2}}", "{{a3}}"],
+ "a6": [
+ ["{{r2}}", "{{a2}}"],
+ ["{{a3}}", "{{a1}}"],
[
- '{{a7}}',
- 'This should be rendered as a string {{a1}}',
- '{{a1}} This, too, should be rendered as a string {{a1}}',
- ]
+ "{{a7}}",
+ "This should be rendered as a string {{a1}}",
+ "{{a1}} This, too, should be rendered as a string {{a1}}",
+ ],
],
- 'a7': 5,
+ "a7": 5,
}
- runner_param_info = {'r1': {'type': 'array'}, 'r2': {'type': 'array'}}
+ runner_param_info = {"r1": {"type": "array"}, "r2": {"type": "array"}}
action_param_info = {
- 'a1': {'type': 'boolean'},
- 'a2': {'type': 'string'},
- 'a3': {'type': 'string'},
- 'a4': {'type': 'boolean'},
- 'a5': {'type': 'array'},
- 'a6': {'type': 'array'},
- 'a7': {'type': 'integer'},
+ "a1": {"type": "boolean"},
+ "a2": {"type": "string"},
+ "a3": {"type": "string"},
+ "a4": {"type": "boolean"},
+ "a5": {"type": "array"},
+ "a6": {"type": "array"},
+ "a7": {"type": "integer"},
}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, {'user': None})
- self.assertEqual(r_runner_params, {'r1': ['1', '2'], 'r2': ['1', '2']})
+ runner_param_info, action_param_info, params, {"user": None}
+ )
+ self.assertEqual(r_runner_params, {"r1": ["1", "2"], "r2": ["1", "2"]})
self.assertEqual(
r_action_params,
{
- 'a1': True,
- 'a2': 'Test',
- 'a3': 'Test2',
- 'a4': True,
- 'a5': ['Test', 'Test2'],
- 'a6': [
- [['1', '2'], 'Test'],
- ['Test2', True],
+ "a1": True,
+ "a2": "Test",
+ "a3": "Test2",
+ "a4": True,
+ "a5": ["Test", "Test2"],
+ "a6": [
+ [["1", "2"], "Test"],
+ ["Test2", True],
[
5,
- u'This should be rendered as a string True',
- u'True This, too, should be rendered as a string True'
- ]
+ "This should be rendered as a string True",
+ "True This, too, should be rendered as a string True",
+ ],
],
- 'a7': 5,
- }
+ "a7": 5,
+ },
)
def test_get_finalized_params_with_cyclic_dependency(self):
- params = {'r1': '{{r2}}', 'r2': '{{r1}}'}
- runner_param_info = {'r1': {}, 'r2': {}}
+ params = {"r1": "{{r2}}", "r2": "{{r1}}"}
+ runner_param_info = {"r1": {}, "r2": {}}
action_param_info = {}
test_pass = True
try:
- param_utils.get_finalized_params(runner_param_info,
- action_param_info,
- params,
- {'user': None})
+ param_utils.get_finalized_params(
+ runner_param_info, action_param_info, params, {"user": None}
+ )
test_pass = False
except ParamException as e:
- test_pass = six.text_type(e).find('Cyclic') == 0
+ test_pass = six.text_type(e).find("Cyclic") == 0
self.assertTrue(test_pass)
def test_get_finalized_params_with_missing_dependency(self):
- params = {'r1': '{{r3}}', 'r2': '{{r3}}'}
- runner_param_info = {'r1': {}, 'r2': {}}
+ params = {"r1": "{{r3}}", "r2": "{{r3}}"}
+ runner_param_info = {"r1": {}, "r2": {}}
action_param_info = {}
test_pass = True
try:
- param_utils.get_finalized_params(runner_param_info,
- action_param_info,
- params,
- {'user': None})
+ param_utils.get_finalized_params(
+ runner_param_info, action_param_info, params, {"user": None}
+ )
test_pass = False
except ParamException as e:
- test_pass = six.text_type(e).find('Dependency') == 0
+ test_pass = six.text_type(e).find("Dependency") == 0
self.assertTrue(test_pass)
params = {}
- runner_param_info = {'r1': {'default': '{{r3}}'}, 'r2': {'default': '{{r3}}'}}
+ runner_param_info = {"r1": {"default": "{{r3}}"}, "r2": {"default": "{{r3}}"}}
action_param_info = {}
test_pass = True
try:
- param_utils.get_finalized_params(runner_param_info,
- action_param_info,
- params,
- {'user': None})
+ param_utils.get_finalized_params(
+ runner_param_info, action_param_info, params, {"user": None}
+ )
test_pass = False
except ParamException as e:
- test_pass = six.text_type(e).find('Dependency') == 0
+ test_pass = six.text_type(e).find("Dependency") == 0
self.assertTrue(test_pass)
def test_get_finalized_params_no_double_rendering(self):
- params = {
- 'r1': '{{ action_context.h1 }}{{ action_context.h2 }}'
- }
- runner_param_info = {'r1': {}}
+ params = {"r1": "{{ action_context.h1 }}{{ action_context.h2 }}"}
+ runner_param_info = {"r1": {}}
action_param_info = {}
- action_context = {
- 'h1': '{',
- 'h2': '{ missing }}',
- 'user': None
- }
+ action_context = {"h1": "{", "h2": "{ missing }}", "user": None}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
- self.assertEqual(r_runner_params, {'r1': '{{ missing }}'})
+ runner_param_info, action_param_info, params, action_context
+ )
+ self.assertEqual(r_runner_params, {"r1": "{{ missing }}"})
self.assertEqual(r_action_params, {})
def test_get_finalized_params_jinja_filters(self):
- params = {'cmd': 'echo {{"1.6.0" | version_bump_minor}}'}
- runner_param_info = {'r1': {}}
- action_param_info = {'cmd': {}}
- action_context = {'user': None}
+ params = {"cmd": 'echo {{"1.6.0" | version_bump_minor}}'}
+ runner_param_info = {"r1": {}}
+ action_param_info = {"cmd": {}}
+ action_context = {"user": None}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
+ runner_param_info, action_param_info, params, action_context
+ )
- self.assertEqual(r_action_params['cmd'], "echo 1.7.0")
+ self.assertEqual(r_action_params["cmd"], "echo 1.7.0")
def test_get_finalized_params_param_rendering_failure(self):
- params = {'cmd': '{{a2.foo}}', 'a2': 'test'}
- action_param_info = {'cmd': {}, 'a2': {}}
+ params = {"cmd": "{{a2.foo}}", "a2": "test"}
+ action_param_info = {"cmd": {}, "a2": {}}
expected_msg = 'Failed to render parameter "cmd": .*'
- self.assertRaisesRegexp(ParamException,
- expected_msg,
- param_utils.get_finalized_params,
- runnertype_parameter_info={},
- action_parameter_info=action_param_info,
- liveaction_parameters=params,
- action_context={'user': None})
+ self.assertRaisesRegexp(
+ ParamException,
+ expected_msg,
+ param_utils.get_finalized_params,
+ runnertype_parameter_info={},
+ action_parameter_info=action_param_info,
+ liveaction_parameters=params,
+ action_context={"user": None},
+ )
def test_get_finalized_param_object_contains_template_notation_in_the_value(self):
- runner_param_info = {'r1': {}}
+ runner_param_info = {"r1": {}}
action_param_info = {
- 'params': {
- 'type': 'object',
- 'default': {
- 'host': '{{host}}',
- 'port': '{{port}}',
- 'path': '/bar'}
+ "params": {
+ "type": "object",
+ "default": {"host": "{{host}}", "port": "{{port}}", "path": "/bar"},
}
}
- params = {
- 'host': 'lolcathost',
- 'port': 5555
- }
- action_context = {'user': None}
+ params = {"host": "lolcathost", "port": 5555}
+ action_context = {"user": None}
r_runner_params, r_action_params = param_utils.get_finalized_params(
- runner_param_info, action_param_info, params, action_context)
+ runner_param_info, action_param_info, params, action_context
+ )
- expected_params = {
- 'host': 'lolcathost',
- 'port': 5555,
- 'path': '/bar'
- }
- self.assertEqual(r_action_params['params'], expected_params)
+ expected_params = {"host": "lolcathost", "port": 5555, "path": "/bar"}
+ self.assertEqual(r_action_params["params"], expected_params)
def test_cast_param_referenced_action_doesnt_exist(self):
         # Make sure the function throws if the action doesn't exist
expected_msg = 'Action with ref "foo.doesntexist" doesn\'t exist'
- self.assertRaisesRegexp(ValueError, expected_msg, action_param_utils.cast_params,
- action_ref='foo.doesntexist', params={})
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ action_param_utils.cast_params,
+ action_ref="foo.doesntexist",
+ params={},
+ )
def test_get_finalized_params_with_config(self):
- with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader:
+ with mock.patch(
+ "st2common.util.config_loader.ContentPackConfigLoader"
+ ) as config_loader:
config_loader().get_config.return_value = {
- 'generic_config_param': 'So generic'
+ "generic_config_param": "So generic"
}
params = {
- 'config_param': '{{config_context.generic_config_param}}',
+ "config_param": "{{config_context.generic_config_param}}",
}
liveaction_db = self._get_liveaction_model(params, True)
@@ -603,369 +623,327 @@ def test_get_finalized_params_with_config(self):
ParamsUtilsTest.runnertype_db.runner_parameters,
ParamsUtilsTest.action_db.parameters,
liveaction_db.parameters,
- liveaction_db.context)
- self.assertEqual(
- action_params.get('config_param'),
- 'So generic'
+ liveaction_db.context,
)
+ self.assertEqual(action_params.get("config_param"), "So generic")
def test_get_config(self):
- with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader:
- mock_config_return = {
- 'generic_config_param': 'So generic'
- }
+ with mock.patch(
+ "st2common.util.config_loader.ContentPackConfigLoader"
+ ) as config_loader:
+ mock_config_return = {"generic_config_param": "So generic"}
config_loader().get_config.return_value = mock_config_return
self.assertEqual(get_config(None, None), {})
- self.assertEqual(get_config('pack', None), {})
- self.assertEqual(get_config(None, 'user'), {})
- self.assertEqual(
- get_config('pack', 'user'), mock_config_return
- )
+ self.assertEqual(get_config("pack", None), {})
+ self.assertEqual(get_config(None, "user"), {})
+ self.assertEqual(get_config("pack", "user"), mock_config_return)
- config_loader.assert_called_with(pack_name='pack', user='user')
+ config_loader.assert_called_with(pack_name="pack", user="user")
config_loader().get_config.assert_called_once()
def _get_liveaction_model(self, params, with_config_context=False):
- status = 'initializing'
+ status = "initializing"
start_timestamp = date_utils.get_datetime_utc_now()
- action_ref = ResourceReference(name=ParamsUtilsTest.action_db.name,
- pack=ParamsUtilsTest.action_db.pack).ref
- liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp,
- action=action_ref, parameters=params)
+ action_ref = ResourceReference(
+ name=ParamsUtilsTest.action_db.name, pack=ParamsUtilsTest.action_db.pack
+ ).ref
+ liveaction_db = LiveActionDB(
+ status=status,
+ start_timestamp=start_timestamp,
+ action=action_ref,
+ parameters=params,
+ )
liveaction_db.context = {
- 'api_user': 'noob',
- 'source_channel': 'reddit',
+ "api_user": "noob",
+ "source_channel": "reddit",
}
if with_config_context:
- liveaction_db.context.update(
- {
- 'pack': 'generic',
- 'user': 'st2admin'
- }
- )
+ liveaction_db.context.update({"pack": "generic", "user": "st2admin"})
return liveaction_db
def test_get_value_from_datastore_through_render_live_params(self):
         # Register the datastore values referred to by this test case
register_kwargs = [
- {'name': 'test_key', 'value': 'foo'},
- {'name': 'user1:test_key', 'value': 'bar', 'scope': FULL_USER_SCOPE},
- {'name': '%s:test_key' % cfg.CONF.system_user.user, 'value': 'baz',
- 'scope': FULL_USER_SCOPE},
+ {"name": "test_key", "value": "foo"},
+ {"name": "user1:test_key", "value": "bar", "scope": FULL_USER_SCOPE},
+ {
+ "name": "%s:test_key" % cfg.CONF.system_user.user,
+ "value": "baz",
+ "scope": FULL_USER_SCOPE,
+ },
]
for kwargs in register_kwargs:
KeyValuePair.add_or_update(KeyValuePairDB(**kwargs))
         # Assert that datastore values can be retrieved via Jinja expressions from individual scopes.
- context = {'user': 'user1'}
+ context = {"user": "user1"}
param = {
- 'system_value': {'default': '{{ st2kv.system.test_key }}'},
- 'user_value': {'default': '{{ st2kv.user.test_key }}'},
+ "system_value": {"default": "{{ st2kv.system.test_key }}"},
+ "user_value": {"default": "{{ st2kv.user.test_key }}"},
}
- live_params = param_utils.render_live_params(runner_parameters={},
- action_parameters=param,
- params={},
- action_context=context)
+ live_params = param_utils.render_live_params(
+ runner_parameters={},
+ action_parameters=param,
+ params={},
+ action_context=context,
+ )
- self.assertEqual(live_params['system_value'], 'foo')
- self.assertEqual(live_params['user_value'], 'bar')
+ self.assertEqual(live_params["system_value"], "foo")
+ self.assertEqual(live_params["user_value"], "bar")
         # Assert that a user-scoped datastore value registered by user1
         # cannot be retrieved by an operation run as user2.
- context = {'user': 'user2'}
- param = {'user_value': {'default': '{{ st2kv.user.test_key }}'}}
- live_params = param_utils.render_live_params(runner_parameters={},
- action_parameters=param,
- params={},
- action_context=context)
+ context = {"user": "user2"}
+ param = {"user_value": {"default": "{{ st2kv.user.test_key }}"}}
+ live_params = param_utils.render_live_params(
+ runner_parameters={},
+ action_parameters=param,
+ params={},
+ action_context=context,
+ )
- self.assertEqual(live_params['user_value'], '')
+ self.assertEqual(live_params["user_value"], "")
         # Assert that the system user's scope is selected when neither user nor
         # api_user is specified in the action context
context = {}
- param = {'user_value': {'default': '{{ st2kv.user.test_key }}'}}
- live_params = param_utils.render_live_params(runner_parameters={},
- action_parameters=param,
- params={},
- action_context=context)
+ param = {"user_value": {"default": "{{ st2kv.user.test_key }}"}}
+ live_params = param_utils.render_live_params(
+ runner_parameters={},
+ action_parameters=param,
+ params={},
+ action_context=context,
+ )
- self.assertEqual(live_params['user_value'], 'baz')
+ self.assertEqual(live_params["user_value"], "baz")
def test_get_live_params_with_additional_context(self):
- runner_param_info = {
- 'r1': {
- 'default': 'some'
- }
- }
- action_param_info = {
- 'r2': {
- 'default': '{{ r1 }}'
- }
- }
- params = {
- 'r3': 'lolcathost',
- 'r1': '{{ additional.stuff }}'
- }
- action_context = {'user': None}
- additional_contexts = {
- 'additional': {
- 'stuff': 'generic'
- }
- }
+ runner_param_info = {"r1": {"default": "some"}}
+ action_param_info = {"r2": {"default": "{{ r1 }}"}}
+ params = {"r3": "lolcathost", "r1": "{{ additional.stuff }}"}
+ action_context = {"user": None}
+ additional_contexts = {"additional": {"stuff": "generic"}}
live_params = param_utils.render_live_params(
- runner_param_info, action_param_info, params, action_context, additional_contexts)
+ runner_param_info,
+ action_param_info,
+ params,
+ action_context,
+ additional_contexts,
+ )
- expected_params = {
- 'r1': 'generic',
- 'r2': 'generic',
- 'r3': 'lolcathost'
- }
+ expected_params = {"r1": "generic", "r2": "generic", "r3": "lolcathost"}
self.assertEqual(live_params, expected_params)
def test_cyclic_dependency_friendly_error_message(self):
runner_param_info = {
- 'r1': {
- 'default': 'some',
- 'cyclic': 'cyclic value',
- 'morecyclic': 'cyclic value'
- }
- }
- action_param_info = {
- 'r2': {
- 'default': '{{ r1 }}'
+ "r1": {
+ "default": "some",
+ "cyclic": "cyclic value",
+ "morecyclic": "cyclic value",
}
}
+ action_param_info = {"r2": {"default": "{{ r1 }}"}}
params = {
- 'r3': 'lolcathost',
- 'cyclic': '{{ cyclic }}',
- 'morecyclic': '{{ morecyclic }}'
+ "r3": "lolcathost",
+ "cyclic": "{{ cyclic }}",
+ "morecyclic": "{{ morecyclic }}",
}
- action_context = {'user': None}
+ action_context = {"user": None}
- expected_msg = 'Cyclic dependency found in the following variables: cyclic, morecyclic'
- self.assertRaisesRegexp(ParamException, expected_msg, param_utils.render_live_params,
- runner_param_info, action_param_info, params, action_context)
+ expected_msg = (
+ "Cyclic dependency found in the following variables: cyclic, morecyclic"
+ )
+ self.assertRaisesRegexp(
+ ParamException,
+ expected_msg,
+ param_utils.render_live_params,
+ runner_param_info,
+ action_param_info,
+ params,
+ action_context,
+ )
def test_unsatisfied_dependency_friendly_error_message(self):
runner_param_info = {
- 'r1': {
- 'default': 'some',
- }
- }
- action_param_info = {
- 'r2': {
- 'default': '{{ r1 }}'
+ "r1": {
+ "default": "some",
}
}
+ action_param_info = {"r2": {"default": "{{ r1 }}"}}
params = {
- 'r3': 'lolcathost',
- 'r4': '{{ variable_not_defined }}',
+ "r3": "lolcathost",
+ "r4": "{{ variable_not_defined }}",
}
- action_context = {'user': None}
+ action_context = {"user": None}
expected_msg = 'Dependency unsatisfied in variable "variable_not_defined"'
- self.assertRaisesRegexp(ParamException, expected_msg, param_utils.render_live_params,
- runner_param_info, action_param_info, params, action_context)
+ self.assertRaisesRegexp(
+ ParamException,
+ expected_msg,
+ param_utils.render_live_params,
+ runner_param_info,
+ action_param_info,
+ params,
+ action_context,
+ )
def test_add_default_templates_to_live_params(self):
- """Test addition of template values in defaults to live params
- """
+ """Test addition of template values in defaults to live params"""
# Ensure parameter is skipped if the parameter has immutable set to true in schema
schemas = [
{
- 'templateparam': {
- 'default': '{{ 3 | int }}',
- 'type': 'integer',
- 'immutable': True
+ "templateparam": {
+ "default": "{{ 3 | int }}",
+ "type": "integer",
+ "immutable": True,
}
}
]
- context = {
- 'templateparam': '3'
- }
+ context = {"templateparam": "3"}
result = param_utils._cast_params_from({}, context, schemas)
self.assertEqual(result, {})
# Test with no live params, and two parameters - one should make it through because
# it was a template, and the other shouldn't because its default wasn't a template
- schemas = [
- {
- 'templateparam': {
- 'default': '{{ 3 | int }}',
- 'type': 'integer'
- }
- }
- ]
- context = {
- 'templateparam': '3'
- }
+ schemas = [{"templateparam": {"default": "{{ 3 | int }}", "type": "integer"}}]
+ context = {"templateparam": "3"}
result = param_utils._cast_params_from({}, context, schemas)
- self.assertEqual(result, {'templateparam': 3})
+ self.assertEqual(result, {"templateparam": 3})
# Ensure parameter is skipped if the value in context is identical to default
- schemas = [
- {
- 'nottemplateparam': {
- 'default': '4',
- 'type': 'integer'
- }
- }
- ]
+ schemas = [{"nottemplateparam": {"default": "4", "type": "integer"}}]
context = {
- 'nottemplateparam': '4',
+ "nottemplateparam": "4",
}
result = param_utils._cast_params_from({}, context, schemas)
self.assertEqual(result, {})
# Ensure parameter is skipped if the parameter doesn't have a default
- schemas = [
- {
- 'nottemplateparam': {
- 'type': 'integer'
- }
- }
- ]
+ schemas = [{"nottemplateparam": {"type": "integer"}}]
context = {
- 'nottemplateparam': '4',
+ "nottemplateparam": "4",
}
result = param_utils._cast_params_from({}, context, schemas)
self.assertEqual(result, {})
# Skip if the default value isn't a Jinja expression
- schemas = [
- {
- 'nottemplateparam': {
- 'default': '5',
- 'type': 'integer'
- }
- }
- ]
+ schemas = [{"nottemplateparam": {"default": "5", "type": "integer"}}]
context = {
- 'nottemplateparam': '4',
+ "nottemplateparam": "4",
}
result = param_utils._cast_params_from({}, context, schemas)
self.assertEqual(result, {})
# Ensure parameter is skipped if the parameter is being overridden
- schemas = [
- {
- 'templateparam': {
- 'default': '{{ 3 | int }}',
- 'type': 'integer'
- }
- }
- ]
+ schemas = [{"templateparam": {"default": "{{ 3 | int }}", "type": "integer"}}]
context = {
- 'templateparam': '4',
+ "templateparam": "4",
}
- result = param_utils._cast_params_from({'templateparam': '4'}, context, schemas)
- self.assertEqual(result, {'templateparam': 4})
+ result = param_utils._cast_params_from({"templateparam": "4"}, context, schemas)
+ self.assertEqual(result, {"templateparam": 4})
def test_render_final_params_and_shell_script_action_command_strings(self):
runner_parameters = {}
action_db_parameters = {
- 'project': {
- 'type': 'string',
- 'default': 'st2',
- 'position': 0,
+ "project": {
+ "type": "string",
+ "default": "st2",
+ "position": 0,
},
- 'version': {
- 'type': 'string',
- 'position': 1,
- 'required': True
+ "version": {"type": "string", "position": 1, "required": True},
+ "fork": {
+ "type": "string",
+ "position": 2,
+ "default": "StackStorm",
},
- 'fork': {
- 'type': 'string',
- 'position': 2,
- 'default': 'StackStorm',
+ "branch": {
+ "type": "string",
+ "position": 3,
+ "default": "master",
},
- 'branch': {
- 'type': 'string',
- 'position': 3,
- 'default': 'master',
+ "update_changelog": {"type": "boolean", "position": 4, "default": False},
+ "local_repo": {
+ "type": "string",
+ "position": 5,
},
- 'update_changelog': {
- 'type': 'boolean',
- 'position': 4,
- 'default': False
- },
- 'local_repo': {
- 'type': 'string',
- 'position': 5,
- }
}
context = {}
# 1. All default values used
live_action_db_parameters = {
- 'project': 'st2flow',
- 'version': '3.0.0',
- 'fork': 'StackStorm',
- 'local_repo': '/tmp/repo'
+ "project": "st2flow",
+ "version": "3.0.0",
+ "fork": "StackStorm",
+ "local_repo": "/tmp/repo",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
- self.assertDictEqual(action_params, {
- 'project': 'st2flow',
- 'version': '3.0.0',
- 'fork': 'StackStorm',
- 'branch': 'master', # default value used
- 'update_changelog': False, # default value used
- 'local_repo': '/tmp/repo'
- })
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2flow",
+ "version": "3.0.0",
+ "fork": "StackStorm",
+ "branch": "master", # default value used
+ "update_changelog": False, # default value used
+ "local_repo": "/tmp/repo",
+ },
+ )
# 2. Some default values used
live_action_db_parameters = {
- 'project': 'st2web',
- 'version': '3.1.0',
- 'fork': 'StackStorm1',
- 'update_changelog': True,
- 'local_repo': '/tmp/repob'
+ "project": "st2web",
+ "version": "3.1.0",
+ "fork": "StackStorm1",
+ "update_changelog": True,
+ "local_repo": "/tmp/repob",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
- self.assertDictEqual(action_params, {
- 'project': 'st2web',
- 'version': '3.1.0',
- 'fork': 'StackStorm1',
- 'branch': 'master', # default value used
- 'update_changelog': True, # default value used
- 'local_repo': '/tmp/repob'
- })
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2web",
+ "version": "3.1.0",
+ "fork": "StackStorm1",
+ "branch": "master", # default value used
+ "update_changelog": True, # default value used
+ "local_repo": "/tmp/repob",
+ },
+ )
# 3. None is specified for a boolean parameter, should use a default
live_action_db_parameters = {
- 'project': 'st2rbac',
- 'version': '3.2.0',
- 'fork': 'StackStorm2',
- 'update_changelog': None,
- 'local_repo': '/tmp/repoc'
+ "project": "st2rbac",
+ "version": "3.2.0",
+ "fork": "StackStorm2",
+ "update_changelog": None,
+ "local_repo": "/tmp/repoc",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
-
- self.assertDictEqual(action_params, {
- 'project': 'st2rbac',
- 'version': '3.2.0',
- 'fork': 'StackStorm2',
- 'branch': 'master', # default value used
- 'update_changelog': False, # default value used
- 'local_repo': '/tmp/repoc'
- })
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
+
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2rbac",
+ "version": "3.2.0",
+ "fork": "StackStorm2",
+ "branch": "master", # default value used
+ "update_changelog": False, # default value used
+ "local_repo": "/tmp/repoc",
+ },
+ )
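
Many of the param_utils cases above (test_get_finalized_params_order in particular)
exercise a single precedence rule: live execution parameters override action defaults,
which override runner defaults, and a parameter with no value anywhere resolves to
None. A simplified sketch of that resolution order, assuming rendering has already
happened (the real param_utils also performs Jinja rendering, casting and immutability
checks):

    def resolve_param(name, live_params, action_param_info, runner_param_info):
        # Live (execution) parameters win ...
        if name in live_params:
            return live_params[name]
        # ... then action-level defaults, which also shadow runner defaults
        # for the same name ...
        if "default" in action_param_info.get(name, {}):
            return action_param_info[name]["default"]
        # ... then runner-level defaults; otherwise None.
        if "default" in runner_param_info.get(name, {}):
            return runner_param_info[name]["default"]
        return None
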
diff --git a/st2common/tests/unit/test_paramiko_command_action_model.py b/st2common/tests/unit/test_paramiko_command_action_model.py
index 2ce7bbfed3..0d023d4f8a 100644
--- a/st2common/tests/unit/test_paramiko_command_action_model.py
+++ b/st2common/tests/unit/test_paramiko_command_action_model.py
@@ -18,76 +18,84 @@
from st2common.models.system.paramiko_command_action import ParamikoRemoteCommandAction
-__all__ = [
- 'ParamikoRemoteCommandActionTestCase'
-]
+__all__ = ["ParamikoRemoteCommandActionTestCase"]
class ParamikoRemoteCommandActionTestCase(unittest2.TestCase):
-
def test_get_command_string_no_env_vars(self):
cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action(
- 'echo boo bah baz')
- ex = 'cd /tmp && echo boo bah baz'
+ "echo boo bah baz"
+ )
+ ex = "cd /tmp && echo boo bah baz"
self.assertEqual(cmd_action.get_full_command_string(), ex)
# With sudo
cmd_action.sudo = True
- ex = 'sudo -E -- bash -c \'cd /tmp && echo boo bah baz\''
+ ex = "sudo -E -- bash -c 'cd /tmp && echo boo bah baz'"
self.assertEqual(cmd_action.get_full_command_string(), ex)
# Executing a path command requires user to provide an escaped input.
# E.g. st2 run core.remote hosts=localhost cmd='"/tmp/space stuff.sh"'
cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action(
- '"/t/space stuff.sh"')
+ '"/t/space stuff.sh"'
+ )
ex = 'cd /tmp && "/t/space stuff.sh"'
self.assertEqual(cmd_action.get_full_command_string(), ex)
# sudo_password provided
cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action(
- 'echo boo bah baz')
+ "echo boo bah baz"
+ )
cmd_action.sudo = True
- cmd_action.sudo_password = 'sudo pass'
+ cmd_action.sudo_password = "sudo pass"
- ex = ('set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- '
- 'bash -c \'cd /tmp && echo boo bah baz\'')
+ ex = (
+ "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- "
+ "bash -c 'cd /tmp && echo boo bah baz'"
+ )
self.assertEqual(cmd_action.get_full_command_string(), ex)
def test_get_command_string_with_env_vars(self):
cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action(
- 'echo boo bah baz')
- cmd_action.env_vars = {'FOO': 'BAR', 'BAR': 'BEET CAFE'}
- ex = 'export BAR=\'BEET CAFE\' ' + \
- 'FOO=BAR' + \
- ' && cd /tmp && echo boo bah baz'
+ "echo boo bah baz"
+ )
+ cmd_action.env_vars = {"FOO": "BAR", "BAR": "BEET CAFE"}
+ ex = "export BAR='BEET CAFE' " + "FOO=BAR" + " && cd /tmp && echo boo bah baz"
self.assertEqual(cmd_action.get_full_command_string(), ex)
# With sudo
cmd_action.sudo = True
- ex = 'sudo -E -- bash -c ' + \
- '\'export FOO=BAR ' + \
- 'BAR=\'"\'"\'BEET CAFE\'"\'"\'' + \
- ' && cd /tmp && echo boo bah baz\''
- ex = 'sudo -E -- bash -c ' + \
- '\'export BAR=\'"\'"\'BEET CAFE\'"\'"\' ' + \
- 'FOO=BAR' + \
- ' && cd /tmp && echo boo bah baz\''
+ ex = (
+ "sudo -E -- bash -c "
+ + "'export FOO=BAR "
+ + "BAR='\"'\"'BEET CAFE'\"'\"'"
+ + " && cd /tmp && echo boo bah baz'"
+ )
+ ex = (
+ "sudo -E -- bash -c "
+ + "'export BAR='\"'\"'BEET CAFE'\"'\"' "
+ + "FOO=BAR"
+ + " && cd /tmp && echo boo bah baz'"
+ )
self.assertEqual(cmd_action.get_full_command_string(), ex)
# with sudo_password
cmd_action.sudo = True
- cmd_action.sudo_password = 'sudo pass'
- ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \
- '\'export BAR=\'"\'"\'BEET CAFE\'"\'"\' ' + \
- 'FOO=BAR HISTFILE=/dev/null HISTSIZE=0' + \
- ' && cd /tmp && echo boo bah baz\''
+ cmd_action.sudo_password = "sudo pass"
+ ex = (
+ "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c "
+ + "'export BAR='\"'\"'BEET CAFE'\"'\"' "
+ + "FOO=BAR HISTFILE=/dev/null HISTSIZE=0"
+ + " && cd /tmp && echo boo bah baz'"
+ )
self.assertEqual(cmd_action.get_full_command_string(), ex)
def test_get_command_string_no_user(self):
cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action(
- 'echo boo bah baz')
+ "echo boo bah baz"
+ )
cmd_action.user = None
- ex = 'cd /tmp && echo boo bah baz'
+ ex = "cd /tmp && echo boo bah baz"
self.assertEqual(cmd_action.get_full_command_string(), ex)
        # Executing a path command requires the user to provide escaped input.
@@ -99,25 +107,28 @@ def test_get_command_string_no_user(self):
def test_get_command_string_no_user_env_vars(self):
cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action(
- 'echo boo bah baz')
+ "echo boo bah baz"
+ )
cmd_action.user = None
- cmd_action.env_vars = {'FOO': 'BAR'}
- ex = 'export FOO=BAR && cd /tmp && echo boo bah baz'
+ cmd_action.env_vars = {"FOO": "BAR"}
+ ex = "export FOO=BAR && cd /tmp && echo boo bah baz"
self.assertEqual(cmd_action.get_full_command_string(), ex)
@staticmethod
def _get_test_command_action(command):
- cmd_action = ParamikoRemoteCommandAction('fixtures.remote_command',
- '55ce39d532ed3543aecbe71d',
- command=command,
- env_vars={},
- on_behalf_user='svetlana',
- user='estee',
- password=None,
- private_key='---PRIVATE-KEY---',
- hosts='127.0.0.1',
- parallel=True,
- sudo=False,
- timeout=None,
- cwd='/tmp')
+ cmd_action = ParamikoRemoteCommandAction(
+ "fixtures.remote_command",
+ "55ce39d532ed3543aecbe71d",
+ command=command,
+ env_vars={},
+ on_behalf_user="svetlana",
+ user="estee",
+ password=None,
+ private_key="---PRIVATE-KEY---",
+ hosts="127.0.0.1",
+ parallel=True,
+ sudo=False,
+ timeout=None,
+ cwd="/tmp",
+ )
return cmd_action
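
The '"'"' runs in the expected strings above are the standard shell idiom for nesting a single quote inside a single-quoted word: close the quote, emit a double-quoted quote, reopen. Python's shlex.quote produces exactly this escaping, which makes for a quick standalone sanity check (st2 builds these commands with its own quoting helpers, so this only illustrates the idiom):

    import shlex

    inner = "cd /tmp && echo BEET CAFE's"
    print("sudo -E -- bash -c %s" % shlex.quote(inner))
    # sudo -E -- bash -c 'cd /tmp && echo BEET CAFE'"'"'s'
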
diff --git a/st2common/tests/unit/test_paramiko_script_action_model.py b/st2common/tests/unit/test_paramiko_script_action_model.py
index e05350e46d..3efae1053f 100644
--- a/st2common/tests/unit/test_paramiko_script_action_model.py
+++ b/st2common/tests/unit/test_paramiko_script_action_model.py
@@ -18,75 +18,81 @@
from st2common.models.system.paramiko_script_action import ParamikoRemoteScriptAction
-__all__ = [
- 'ParamikoRemoteScriptActionTestCase'
-]
+__all__ = ["ParamikoRemoteScriptActionTestCase"]
class ParamikoRemoteScriptActionTestCase(unittest2.TestCase):
-
def test_get_command_string_no_env_vars(self):
script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action()
- ex = 'cd /tmp && /tmp/remote_script.sh song=\'b s\' \'taylor swift\''
+ ex = "cd /tmp && /tmp/remote_script.sh song='b s' 'taylor swift'"
self.assertEqual(script_action.get_full_command_string(), ex)
# Test with sudo
script_action.sudo = True
- ex = 'sudo -E -- bash -c ' + \
- '\'cd /tmp && ' + \
- '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\''
+ ex = (
+ "sudo -E -- bash -c "
+ + "'cd /tmp && "
+ + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
# with sudo password
script_action.sudo = True
- script_action.sudo_password = 'sudo pass'
- ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \
- '\'cd /tmp && ' + \
- '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\''
+ script_action.sudo_password = "sudo pass"
+ ex = (
+ "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c "
+ + "'cd /tmp && "
+ + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
def test_get_command_string_with_env_vars(self):
script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action()
script_action.env_vars = {
- 'ST2_ACTION_EXECUTION_ID': '55ce39d532ed3543aecbe71d',
- 'FOO': 'BAR BAZ BOOZ'
+ "ST2_ACTION_EXECUTION_ID": "55ce39d532ed3543aecbe71d",
+ "FOO": "BAR BAZ BOOZ",
}
- ex = 'export FOO=\'BAR BAZ BOOZ\' ' + \
- 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \
- 'cd /tmp && /tmp/remote_script.sh song=\'b s\' \'taylor swift\''
+ ex = (
+ "export FOO='BAR BAZ BOOZ' "
+ + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && "
+ + "cd /tmp && /tmp/remote_script.sh song='b s' 'taylor swift'"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
# Test with sudo
script_action.sudo = True
- ex = 'sudo -E -- bash -c ' + \
- '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' ' + \
- 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \
- 'cd /tmp && ' + \
- '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\''
+ ex = (
+ "sudo -E -- bash -c "
+ + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' "
+ + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && "
+ + "cd /tmp && "
+ + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
# with sudo password
script_action.sudo = True
- script_action.sudo_password = 'sudo pass'
+ script_action.sudo_password = "sudo pass"
- ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \
- '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' HISTFILE=/dev/null HISTSIZE=0 ' + \
- 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \
- 'cd /tmp && ' + \
- '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\''
+ ex = (
+ "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c "
+ + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' HISTFILE=/dev/null HISTSIZE=0 "
+ + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && "
+ + "cd /tmp && "
+ + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
def test_get_command_string_no_script_args_no_env_args(self):
script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action()
script_action.named_args = {}
script_action.positional_args = []
- ex = 'cd /tmp && /tmp/remote_script.sh'
+ ex = "cd /tmp && /tmp/remote_script.sh"
self.assertEqual(script_action.get_full_command_string(), ex)
# Test with sudo
script_action.sudo = True
- ex = 'sudo -E -- bash -c ' + \
- '\'cd /tmp && /tmp/remote_script.sh\''
+ ex = "sudo -E -- bash -c " + "'cd /tmp && /tmp/remote_script.sh'"
self.assertEqual(script_action.get_full_command_string(), ex)
def test_get_command_string_no_script_args_with_env_args(self):
@@ -94,88 +100,100 @@ def test_get_command_string_no_script_args_with_env_args(self):
script_action.named_args = {}
script_action.positional_args = []
script_action.env_vars = {
- 'ST2_ACTION_EXECUTION_ID': '55ce39d532ed3543aecbe71d',
- 'FOO': 'BAR BAZ BOOZ'
+ "ST2_ACTION_EXECUTION_ID": "55ce39d532ed3543aecbe71d",
+ "FOO": "BAR BAZ BOOZ",
}
- ex = 'export FOO=\'BAR BAZ BOOZ\' ' + \
- 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \
- 'cd /tmp && /tmp/remote_script.sh'
+ ex = (
+ "export FOO='BAR BAZ BOOZ' "
+ + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && "
+ + "cd /tmp && /tmp/remote_script.sh"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
# Test with sudo
script_action.sudo = True
- ex = 'sudo -E -- bash -c ' + \
- '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' ' + \
- 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \
- 'cd /tmp && ' + \
- '/tmp/remote_script.sh\''
+ ex = (
+ "sudo -E -- bash -c "
+ + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' "
+ + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && "
+ + "cd /tmp && "
+ + "/tmp/remote_script.sh'"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
def test_script_path_shell_injection_safe(self):
script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action()
- test_path = '/tmp/remote script.sh'
+ test_path = "/tmp/remote script.sh"
script_action.remote_script = test_path
script_action.named_args = {}
script_action.positional_args = []
- ex = 'cd /tmp && \'/tmp/remote script.sh\''
+ ex = "cd /tmp && '/tmp/remote script.sh'"
self.assertEqual(script_action.get_full_command_string(), ex)
# Test with sudo
script_action.sudo = True
- ex = 'sudo -E -- bash -c ' + \
- '\'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\''
+ ex = "sudo -E -- bash -c " + "'cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''"
self.assertEqual(script_action.get_full_command_string(), ex)
# With sudo_password
script_action.sudo = True
- script_action.sudo_password = 'sudo pass'
+ script_action.sudo_password = "sudo pass"
- ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \
- '\'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\''
+ ex = (
+ "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c "
+ + "'cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
def test_script_path_shell_injection_safe_with_env_vars(self):
script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action()
- test_path = '/tmp/remote script.sh'
+ test_path = "/tmp/remote script.sh"
script_action.remote_script = test_path
script_action.named_args = {}
script_action.positional_args = []
- script_action.env_vars = {'FOO': 'BAR'}
- ex = 'export FOO=BAR && cd /tmp && \'/tmp/remote script.sh\''
+ script_action.env_vars = {"FOO": "BAR"}
+ ex = "export FOO=BAR && cd /tmp && '/tmp/remote script.sh'"
self.assertEqual(script_action.get_full_command_string(), ex)
# Test with sudo
script_action.sudo = True
- ex = 'sudo -E -- bash -c ' + \
- '\'export FOO=BAR && ' + \
- 'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\''
+ ex = (
+ "sudo -E -- bash -c "
+ + "'export FOO=BAR && "
+ + "cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
# With sudo_password
script_action.sudo = True
- script_action.sudo_password = 'sudo pass'
+ script_action.sudo_password = "sudo pass"
- ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \
- '\'export FOO=BAR HISTFILE=/dev/null HISTSIZE=0 && ' + \
- 'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\''
+ ex = (
+ "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c "
+ + "'export FOO=BAR HISTFILE=/dev/null HISTSIZE=0 && "
+ + "cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''"
+ )
self.assertEqual(script_action.get_full_command_string(), ex)
@staticmethod
def _get_test_script_action():
- local_script_path = '/opt/stackstorm/packs/fixtures/actions/remote_script.sh'
- script_action = ParamikoRemoteScriptAction('fixtures.remote_script',
- '55ce39d532ed3543aecbe71d',
- local_script_path,
- '/opt/stackstorm/packs/fixtures/actions/lib/',
- named_args={'song': 'b s'},
- positional_args=['taylor swift'],
- env_vars={},
- on_behalf_user='stanley',
- user='vagrant',
- private_key='/home/vagrant/.ssh/stanley_rsa',
- remote_dir='/tmp',
- hosts=['127.0.0.1'],
- parallel=True,
- sudo=False,
- timeout=60, cwd='/tmp')
+ local_script_path = "/opt/stackstorm/packs/fixtures/actions/remote_script.sh"
+ script_action = ParamikoRemoteScriptAction(
+ "fixtures.remote_script",
+ "55ce39d532ed3543aecbe71d",
+ local_script_path,
+ "/opt/stackstorm/packs/fixtures/actions/lib/",
+ named_args={"song": "b s"},
+ positional_args=["taylor swift"],
+ env_vars={},
+ on_behalf_user="stanley",
+ user="vagrant",
+ private_key="/home/vagrant/.ssh/stanley_rsa",
+ remote_dir="/tmp",
+ hosts=["127.0.0.1"],
+ parallel=True,
+ sudo=False,
+ timeout=60,
+ cwd="/tmp",
+ )
return script_action
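
The sudo_password cases asserted above share one shape: disable shell history, pipe the password into sudo -S on stdin, and run the payload under bash -c (with HISTFILE/HISTSIZE exported in the env-var variants so the password never lands in history). A rough sketch of composing that prefix, assuming an already-escaped command string; the real builder lives in the ParamikoRemote*Action models:

    def with_sudo_password(escaped_command, sudo_password):
        # 'set +o history' plus 'sudo -S' reading the password from stdin keeps
        # the password out of shell history and off the interactive prompt.
        return (
            "set +o history ; echo -e '%s\n' | sudo -S -E -- bash -c '%s'"
            % (sudo_password, escaped_command)
        )

    print(with_sudo_password("cd /tmp && echo boo bah baz", "sudo pass"))
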
diff --git a/st2common/tests/unit/test_persistence.py b/st2common/tests/unit/test_persistence.py
index 14f25731ff..6fce36c18d 100644
--- a/st2common/tests/unit/test_persistence.py
+++ b/st2common/tests/unit/test_persistence.py
@@ -27,7 +27,6 @@
class TestPersistence(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(TestPersistence, cls).setUpClass()
@@ -38,7 +37,7 @@ def tearDown(self):
super(TestPersistence, self).tearDown()
def test_crud(self):
- obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'a': 1})
+ obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"a": 1})
obj1 = self.access.add_or_update(obj1)
obj2 = self.access.get(name=obj1.name)
self.assertIsNotNone(obj2)
@@ -59,16 +58,16 @@ def test_crud(self):
self.assertIsNone(obj2)
def test_count(self):
- obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'})
+ obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"})
obj1 = self.access.add_or_update(obj1)
- obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'})
+ obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"})
obj2 = self.access.add_or_update(obj2)
self.assertEqual(self.access.count(), 2)
def test_get_all(self):
- obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'})
+ obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"})
obj1 = self.access.add_or_update(obj1)
- obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'})
+ obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"})
obj2 = self.access.add_or_update(obj2)
objs = self.access.get_all()
self.assertIsNotNone(objs)
@@ -76,33 +75,35 @@ def test_get_all(self):
self.assertListEqual(list(objs), [obj1, obj2])
def test_query_by_id(self):
- obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'})
+ obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"})
obj1 = self.access.add_or_update(obj1)
obj2 = self.access.get_by_id(str(obj1.id))
self.assertIsNotNone(obj2)
self.assertEqual(obj1.id, obj2.id)
self.assertEqual(obj1.name, obj2.name)
self.assertDictEqual(obj1.context, obj2.context)
- self.assertRaises(StackStormDBObjectNotFoundError,
- self.access.get_by_id, str(bson.ObjectId()))
+ self.assertRaises(
+ StackStormDBObjectNotFoundError, self.access.get_by_id, str(bson.ObjectId())
+ )
def test_query_by_name(self):
- obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'})
+ obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"})
obj1 = self.access.add_or_update(obj1)
obj2 = self.access.get_by_name(obj1.name)
self.assertIsNotNone(obj2)
self.assertEqual(obj1.id, obj2.id)
self.assertEqual(obj1.name, obj2.name)
self.assertDictEqual(obj1.context, obj2.context)
- self.assertRaises(StackStormDBObjectNotFoundError, self.access.get_by_name,
- uuid.uuid4().hex)
+ self.assertRaises(
+ StackStormDBObjectNotFoundError, self.access.get_by_name, uuid.uuid4().hex
+ )
def test_query_filter(self):
- obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'})
+ obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"})
obj1 = self.access.add_or_update(obj1)
- obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'})
+ obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"})
obj2 = self.access.add_or_update(obj2)
- objs = self.access.query(context__user='system')
+ objs = self.access.query(context__user="system")
self.assertIsNotNone(objs)
self.assertGreater(len(objs), 0)
self.assertEqual(obj1.id, objs[0].id)
@@ -113,17 +114,17 @@ def test_null_filter(self):
obj1 = FakeModelDB(name=uuid.uuid4().hex)
obj1 = self.access.add_or_update(obj1)
- objs = self.access.query(index='null')
+ objs = self.access.query(index="null")
self.assertEqual(len(objs), 1)
self.assertEqual(obj1.id, objs[0].id)
self.assertEqual(obj1.name, objs[0].name)
- self.assertIsNone(getattr(obj1, 'index', None))
+ self.assertIsNone(getattr(obj1, "index", None))
objs = self.access.query(index=None)
self.assertEqual(len(objs), 1)
self.assertEqual(obj1.id, objs[0].id)
self.assertEqual(obj1.name, objs[0].name)
- self.assertIsNone(getattr(obj1, 'index', None))
+ self.assertIsNone(getattr(obj1, "index", None))
def test_datetime_range(self):
base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0))
@@ -132,12 +133,12 @@ def test_datetime_range(self):
obj = FakeModelDB(name=uuid.uuid4().hex, timestamp=timestamp)
self.access.add_or_update(obj)
- dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z'
+ dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z"
objs = self.access.query(timestamp=dt_range)
self.assertEqual(len(objs), 10)
self.assertLess(objs[0].timestamp, objs[9].timestamp)
- dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z'
+ dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z"
objs = self.access.query(timestamp=dt_range)
self.assertEqual(len(objs), 10)
self.assertLess(objs[9].timestamp, objs[0].timestamp)
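
The dt_range strings above use st2's '..'-separated datetime-range query syntax, where a reversed range also reverses the sort order of the results. A toy parser capturing just that convention (the real parsing sits in st2common's db query layer; this is illustrative):

    def parse_range(dt_range):
        start, end = dt_range.split("..")
        # Same-format ISO-8601 strings compare lexicographically, i.e. chronologically.
        descending = start > end
        return min(start, end), max(start, end), descending

    assert parse_range("2014-12-25T00:00:10Z..2014-12-25T00:00:19Z")[2] is False
    assert parse_range("2014-12-25T00:00:19Z..2014-12-25T00:00:10Z")[2] is True
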
@@ -146,52 +147,61 @@ def test_pagination(self):
count = 100
page_size = 25
pages = int(count / page_size)
- users = ['Peter', 'Susan', 'Edmund', 'Lucy']
+ users = ["Peter", "Susan", "Edmund", "Lucy"]
for user in users:
- context = {'user': user}
+ context = {"user": user}
for i in range(count):
- self.access.add_or_update(FakeModelDB(name=uuid.uuid4().hex,
- context=context, index=i))
+ self.access.add_or_update(
+ FakeModelDB(name=uuid.uuid4().hex, context=context, index=i)
+ )
self.assertEqual(self.access.count(), len(users) * count)
for user in users:
for i in range(pages):
offset = i * page_size
- objs = self.access.query(context__user=user, order_by=['index'],
- offset=offset, limit=page_size)
+ objs = self.access.query(
+ context__user=user,
+ order_by=["index"],
+ offset=offset,
+ limit=page_size,
+ )
self.assertEqual(len(objs), page_size)
for j in range(page_size):
- self.assertEqual(objs[j].context['user'], user)
+ self.assertEqual(objs[j].context["user"], user)
self.assertEqual(objs[j].index, (i * page_size) + j)
def test_sort_multiple(self):
count = 60
base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0))
for i in range(count):
- category = 'type1' if i % 2 else 'type2'
+ category = "type1" if i % 2 else "type2"
timestamp = base + datetime.timedelta(seconds=i)
- obj = FakeModelDB(name=uuid.uuid4().hex, timestamp=timestamp, category=category)
+ obj = FakeModelDB(
+ name=uuid.uuid4().hex, timestamp=timestamp, category=category
+ )
self.access.add_or_update(obj)
- objs = self.access.query(order_by=['category', 'timestamp'])
+ objs = self.access.query(order_by=["category", "timestamp"])
self.assertEqual(len(objs), count)
for i in range(count):
- category = 'type1' if i < count / 2 else 'type2'
+ category = "type1" if i < count / 2 else "type2"
self.assertEqual(objs[i].category, category)
self.assertLess(objs[0].timestamp, objs[(int(count / 2)) - 1].timestamp)
- self.assertLess(objs[int(count / 2)].timestamp, objs[(int(count / 2)) - 1].timestamp)
+ self.assertLess(
+ objs[int(count / 2)].timestamp, objs[(int(count / 2)) - 1].timestamp
+ )
self.assertLess(objs[int(count / 2)].timestamp, objs[count - 1].timestamp)
def test_escaped_field(self):
- context = {'a.b.c': 'abc'}
+ context = {"a.b.c": "abc"}
obj1 = FakeModelDB(name=uuid.uuid4().hex, context=context)
obj2 = self.access.add_or_update(obj1)
# Check that the original dict has not been altered.
- self.assertIn('a.b.c', list(context.keys()))
- self.assertNotIn('a\uff0eb\uff0ec', list(context.keys()))
+ self.assertIn("a.b.c", list(context.keys()))
+ self.assertNotIn("a\uff0eb\uff0ec", list(context.keys()))
# Check to_python has run and context is not left escaped.
self.assertDictEqual(obj2.context, context)
@@ -206,26 +216,26 @@ def test_query_only_fields(self):
count = 5
ts = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0))
for i in range(count):
- category = 'type1'
- obj = FakeModelDB(name='test-%s' % (i), timestamp=ts, category=category)
+ category = "type1"
+ obj = FakeModelDB(name="test-%s" % (i), timestamp=ts, category=category)
self.access.add_or_update(obj)
model_dbs = FakeModel.query()
- self.assertEqual(model_dbs[0].name, 'test-0')
+ self.assertEqual(model_dbs[0].name, "test-0")
self.assertEqual(model_dbs[0].timestamp, ts)
- self.assertEqual(model_dbs[0].category, 'type1')
+ self.assertEqual(model_dbs[0].category, "type1")
# only id
- model_dbs = FakeModel.query(only_fields=['id'])
+ model_dbs = FakeModel.query(only_fields=["id"])
self.assertTrue(model_dbs[0].id)
self.assertEqual(model_dbs[0].name, None)
self.assertEqual(model_dbs[0].timestamp, None)
self.assertEqual(model_dbs[0].category, None)
# only name - note: id is always included
- model_dbs = FakeModel.query(only_fields=['name'])
+ model_dbs = FakeModel.query(only_fields=["name"])
self.assertTrue(model_dbs[0].id)
- self.assertEqual(model_dbs[0].name, 'test-0')
+ self.assertEqual(model_dbs[0].name, "test-0")
self.assertEqual(model_dbs[0].timestamp, None)
self.assertEqual(model_dbs[0].category, None)
@@ -233,28 +243,28 @@ def test_query_exclude_fields(self):
count = 5
ts = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0))
for i in range(count):
- category = 'type1'
- obj = FakeModelDB(name='test-2-%s' % (i), timestamp=ts, category=category)
+ category = "type1"
+ obj = FakeModelDB(name="test-2-%s" % (i), timestamp=ts, category=category)
self.access.add_or_update(obj)
model_dbs = FakeModel.query()
- self.assertEqual(model_dbs[0].name, 'test-2-0')
+ self.assertEqual(model_dbs[0].name, "test-2-0")
self.assertEqual(model_dbs[0].timestamp, ts)
- self.assertEqual(model_dbs[0].category, 'type1')
+ self.assertEqual(model_dbs[0].category, "type1")
- model_dbs = FakeModel.query(exclude_fields=['name'])
+ model_dbs = FakeModel.query(exclude_fields=["name"])
self.assertTrue(model_dbs[0].id)
self.assertEqual(model_dbs[0].name, None)
self.assertEqual(model_dbs[0].timestamp, ts)
- self.assertEqual(model_dbs[0].category, 'type1')
+ self.assertEqual(model_dbs[0].category, "type1")
- model_dbs = FakeModel.query(exclude_fields=['name', 'timestamp'])
+ model_dbs = FakeModel.query(exclude_fields=["name", "timestamp"])
self.assertTrue(model_dbs[0].id)
self.assertEqual(model_dbs[0].name, None)
self.assertEqual(model_dbs[0].timestamp, None)
- self.assertEqual(model_dbs[0].category, 'type1')
+ self.assertEqual(model_dbs[0].category, "type1")
- model_dbs = FakeModel.query(exclude_fields=['name', 'timestamp', 'category'])
+ model_dbs = FakeModel.query(exclude_fields=["name", "timestamp", "category"])
self.assertTrue(model_dbs[0].id)
self.assertEqual(model_dbs[0].name, None)
self.assertEqual(model_dbs[0].timestamp, None)
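
test_pagination above walks pages with offset/limit; the paging arithmetic it asserts is easy to mirror over a plain list (the actual query goes through the MongoEngine-backed access layer):

    def paginate(items, offset, limit):
        return items[offset:offset + limit]

    count, page_size = 100, 25
    items = list(range(count))
    pages = [paginate(items, i * page_size, page_size) for i in range(count // page_size)]
    assert all(len(p) == page_size for p in pages)
    assert pages[2][0] == 2 * page_size  # first index on page i is i * page_size
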
diff --git a/st2common/tests/unit/test_persistence_change_revision.py b/st2common/tests/unit/test_persistence_change_revision.py
index f9e31e1c73..c268fa86b5 100644
--- a/st2common/tests/unit/test_persistence_change_revision.py
+++ b/st2common/tests/unit/test_persistence_change_revision.py
@@ -24,7 +24,6 @@
class TestChangeRevision(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(TestChangeRevision, cls).setUpClass()
@@ -35,7 +34,7 @@ def tearDown(self):
super(TestChangeRevision, self).tearDown()
def test_crud(self):
- initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={'a': 1})
+ initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={"a": 1})
# Test create
created = self.access.add_or_update(initial)
@@ -47,14 +46,14 @@ def test_crud(self):
self.assertDictEqual(created.context, retrieved.context)
# Test update
- retrieved = self.access.update(retrieved, context={'a': 2})
+ retrieved = self.access.update(retrieved, context={"a": 2})
updated = self.access.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
self.assertEqual(retrieved.rev, updated.rev)
self.assertDictEqual(retrieved.context, updated.context)
# Test add or update
- retrieved.context = {'a': 1, 'b': 2}
+ retrieved.context = {"a": 1, "b": 2}
retrieved = self.access.add_or_update(retrieved)
updated = self.access.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
@@ -65,13 +64,11 @@ def test_crud(self):
created.delete()
self.assertRaises(
- db_exc.StackStormDBObjectNotFoundError,
- self.access.get_by_id,
- doc_id
+ db_exc.StackStormDBObjectNotFoundError, self.access.get_by_id, doc_id
)
def test_write_conflict(self):
- initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={'a': 1})
+ initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={"a": 1})
# Prep record
created = self.access.add_or_update(initial)
@@ -83,7 +80,7 @@ def test_write_conflict(self):
retrieved2 = self.access.get_by_id(doc_id)
# Test update on instance 1, expect success
- retrieved1 = self.access.update(retrieved1, context={'a': 2})
+ retrieved1 = self.access.update(retrieved1, context={"a": 2})
updated = self.access.get_by_id(doc_id)
self.assertNotEqual(created.rev, updated.rev)
self.assertEqual(retrieved1.rev, updated.rev)
@@ -94,5 +91,5 @@ def test_write_conflict(self):
db_exc.StackStormDBObjectWriteConflictError,
self.access.update,
retrieved2,
- context={'a': 1, 'b': 2}
+ context={"a": 1, "b": 2},
)
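
test_write_conflict above pins down optimistic concurrency control: each successful update bumps a revision, and an update issued against a stale revision raises a write-conflict error instead of silently clobbering. A toy version of the protocol (st2 implements this in its change-revision persistence layer; names here are illustrative):

    class WriteConflictError(Exception):
        pass

    class Record:
        def __init__(self, context):
            self.rev = 1
            self.context = context

    def update(record, seen_rev, new_context):
        if record.rev != seen_rev:
            raise WriteConflictError(
                "stale rev %s, current rev is %s" % (seen_rev, record.rev)
            )
        record.context = new_context
        record.rev += 1

    rec = Record({"a": 1})
    update(rec, seen_rev=1, new_context={"a": 2})      # first writer succeeds
    try:
        update(rec, seen_rev=1, new_context={"b": 2})  # second writer is stale
    except WriteConflictError:
        pass
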
diff --git a/st2common/tests/unit/test_plugin_loader.py b/st2common/tests/unit/test_plugin_loader.py
index 4b78b6b4cc..4641b66e9c 100644
--- a/st2common/tests/unit/test_plugin_loader.py
+++ b/st2common/tests/unit/test_plugin_loader.py
@@ -24,8 +24,8 @@
import st2common.util.loader as plugin_loader
-PLUGIN_FOLDER = 'loadableplugin'
-SRC_RELATIVE = os.path.join('../resources', PLUGIN_FOLDER)
+PLUGIN_FOLDER = "loadableplugin"
+SRC_RELATIVE = os.path.join("../resources", PLUGIN_FOLDER)
SRC_ROOT = os.path.join(os.path.abspath(os.path.dirname(__file__)), SRC_RELATIVE)
@@ -51,64 +51,71 @@ def tearDown(self):
sys.path = LoaderTest.sys_path
def test_module_load_from_file(self):
- plugin_path = os.path.join(SRC_ROOT, 'plugin/standaloneplugin.py')
+ plugin_path = os.path.join(SRC_ROOT, "plugin/standaloneplugin.py")
plugin_classes = plugin_loader.register_plugin(
- LoaderTest.DummyPlugin, plugin_path)
+ LoaderTest.DummyPlugin, plugin_path
+ )
# Even though there are two classes in that file, only one
        # matches the spec of the DummyPlugin class.
self.assertEqual(1, len(plugin_classes))
# Validate sys.path now contains the plugin directory.
- self.assertIn(os.path.abspath(os.path.join(SRC_ROOT, 'plugin')), sys.path)
+ self.assertIn(os.path.abspath(os.path.join(SRC_ROOT, "plugin")), sys.path)
# Validate the individual plugins
for plugin_class in plugin_classes:
try:
plugin_instance = plugin_class()
ret_val = plugin_instance.do_work()
- self.assertIsNotNone(ret_val, 'Should be non-null.')
+ self.assertIsNotNone(ret_val, "Should be non-null.")
except:
pass
def test_module_load_from_file_fail(self):
try:
- plugin_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin.py')
+ plugin_path = os.path.join(SRC_ROOT, "plugin/sampleplugin.py")
plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_path)
- self.assertTrue(False, 'Import error is expected.')
+ self.assertTrue(False, "Import error is expected.")
except ImportError:
self.assertTrue(True)
def test_syspath_unchanged_load_multiple_plugins(self):
- plugin_1_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin.py')
+ plugin_1_path = os.path.join(SRC_ROOT, "plugin/sampleplugin.py")
try:
- plugin_loader.register_plugin(
- LoaderTest.DummyPlugin, plugin_1_path)
+ plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_1_path)
except ImportError:
pass
old_sys_path = copy.copy(sys.path)
- plugin_2_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin2.py')
+ plugin_2_path = os.path.join(SRC_ROOT, "plugin/sampleplugin2.py")
try:
- plugin_loader.register_plugin(
- LoaderTest.DummyPlugin, plugin_2_path)
+ plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_2_path)
except ImportError:
pass
- self.assertEqual(old_sys_path, sys.path, 'Should be equal.')
+ self.assertEqual(old_sys_path, sys.path, "Should be equal.")
def test_register_plugin_class_class_doesnt_exist(self):
- file_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin3.py')
+ file_path = os.path.join(SRC_ROOT, "plugin/sampleplugin3.py")
expected_msg = 'doesn\'t expose class named "SamplePluginNotExists"'
- self.assertRaisesRegexp(Exception, expected_msg,
- plugin_loader.register_plugin_class,
- base_class=LoaderTest.DummyPlugin,
- file_path=file_path,
- class_name='SamplePluginNotExists')
+ self.assertRaisesRegexp(
+ Exception,
+ expected_msg,
+ plugin_loader.register_plugin_class,
+ base_class=LoaderTest.DummyPlugin,
+ file_path=file_path,
+ class_name="SamplePluginNotExists",
+ )
def test_register_plugin_class_abstract_method_not_implemented(self):
- file_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin3.py')
-
- expected_msg = 'doesn\'t implement required "do_work" method from the base class'
- self.assertRaisesRegexp(plugin_loader.IncompatiblePluginException, expected_msg,
- plugin_loader.register_plugin_class,
- base_class=LoaderTest.DummyPlugin,
- file_path=file_path,
- class_name='SamplePlugin')
+ file_path = os.path.join(SRC_ROOT, "plugin/sampleplugin3.py")
+
+ expected_msg = (
+ 'doesn\'t implement required "do_work" method from the base class'
+ )
+ self.assertRaisesRegexp(
+ plugin_loader.IncompatiblePluginException,
+ expected_msg,
+ plugin_loader.register_plugin_class,
+ base_class=LoaderTest.DummyPlugin,
+ file_path=file_path,
+ class_name="SamplePlugin",
+ )
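
The loader tests above register plugin classes straight from file paths. The core of that technique with stdlib importlib, for orientation (st2's register_plugin additionally prepends the plugin directory to sys.path, filters on the base class, and rejects plugins that skip required abstract methods):

    import importlib.util
    import inspect

    def load_plugin_classes(base_class, file_path):
        spec = importlib.util.spec_from_file_location("plugin_module", file_path)
        module = importlib.util.module_from_spec(spec)
        # A broken plugin file raises here (ImportError, SyntaxError, ...).
        spec.loader.exec_module(module)
        return [
            obj
            for _, obj in inspect.getmembers(module, inspect.isclass)
            if issubclass(obj, base_class) and obj is not base_class
        ]
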
diff --git a/st2common/tests/unit/test_policies.py b/st2common/tests/unit/test_policies.py
index 5491e7482b..f6dd9a47de 100644
--- a/st2common/tests/unit/test_policies.py
+++ b/st2common/tests/unit/test_policies.py
@@ -19,55 +19,43 @@
from st2tests import DbTestCase
from st2tests.fixturesloader import FixturesLoader
-__all__ = [
- 'PolicyTestCase'
-]
+__all__ = ["PolicyTestCase"]
-PACK = 'generic'
+PACK = "generic"
TEST_FIXTURES = {
- 'runners': [
- 'testrunner1.yaml'
- ],
- 'actions': [
- 'action1.yaml'
- ],
- 'policytypes': [
- 'fake_policy_type_1.yaml',
- 'fake_policy_type_2.yaml'
- ],
- 'policies': [
- 'policy_1.yaml',
- 'policy_2.yaml'
- ]
+ "runners": ["testrunner1.yaml"],
+ "actions": ["action1.yaml"],
+ "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"],
+ "policies": ["policy_1.yaml", "policy_2.yaml"],
}
class PolicyTestCase(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(PolicyTestCase, cls).setUpClass()
loader = FixturesLoader()
- loader.save_fixtures_to_db(fixtures_pack=PACK,
- fixtures_dict=TEST_FIXTURES)
+ loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES)
def test_get_by_ref(self):
- policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency')
+ policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency")
self.assertIsNotNone(policy_db)
- self.assertEqual(policy_db.pack, 'wolfpack')
- self.assertEqual(policy_db.name, 'action-1.concurrency')
+ self.assertEqual(policy_db.pack, "wolfpack")
+ self.assertEqual(policy_db.name, "action-1.concurrency")
policy_type_db = PolicyType.get_by_ref(policy_db.policy_type)
self.assertIsNotNone(policy_type_db)
- self.assertEqual(policy_type_db.resource_type, 'action')
- self.assertEqual(policy_type_db.name, 'concurrency')
+ self.assertEqual(policy_type_db.resource_type, "action")
+ self.assertEqual(policy_type_db.name, "concurrency")
def test_get_driver(self):
- policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency')
- policy = get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters)
+ policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency")
+ policy = get_driver(
+ policy_db.ref, policy_db.policy_type, **policy_db.parameters
+ )
self.assertIsInstance(policy, ResourcePolicyApplicator)
self.assertEqual(policy._policy_ref, policy_db.ref)
self.assertEqual(policy._policy_type, policy_db.policy_type)
- self.assertTrue(hasattr(policy, 'threshold'))
+ self.assertTrue(hasattr(policy, "threshold"))
self.assertEqual(policy.threshold, 3)
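
test_get_driver above checks that the resolved policy applicator exposes the policy's parameters as attributes (policy.threshold). A skeletal version of that pattern (the actual get_driver also maps the policy_type to a driver module, which this sketch skips):

    class ResourcePolicyApplicator:
        def __init__(self, policy_ref, policy_type, **parameters):
            self._policy_ref = policy_ref
            self._policy_type = policy_type
            for name, value in parameters.items():
                setattr(self, name, value)  # e.g. exposes policy.threshold

    policy = ResourcePolicyApplicator(
        "wolfpack.action-1.concurrency", "action.concurrency", threshold=3
    )
    assert policy.threshold == 3
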
diff --git a/st2common/tests/unit/test_policies_registrar.py b/st2common/tests/unit/test_policies_registrar.py
index b46515e08a..85c1d34490 100644
--- a/st2common/tests/unit/test_policies_registrar.py
+++ b/st2common/tests/unit/test_policies_registrar.py
@@ -29,9 +29,7 @@
from st2tests.base import CleanDbTestCase
from st2tests.fixturesloader import get_fixtures_packs_base_path
-__all__ = [
- 'PoliciesRegistrarTestCase'
-]
+__all__ = ["PoliciesRegistrarTestCase"]
class PoliciesRegistrarTestCase(CleanDbTestCase):
@@ -44,13 +42,13 @@ def setUp(self):
def test_register_policy_types(self):
self.assertEqual(register_policy_types(st2tests), 2)
- type1 = PolicyType.get_by_ref('action.concurrency')
- self.assertEqual(type1.name, 'concurrency')
- self.assertEqual(type1.resource_type, 'action')
+ type1 = PolicyType.get_by_ref("action.concurrency")
+ self.assertEqual(type1.name, "concurrency")
+ self.assertEqual(type1.resource_type, "action")
- type2 = PolicyType.get_by_ref('action.mock_policy_error')
- self.assertEqual(type2.name, 'mock_policy_error')
- self.assertEqual(type2.resource_type, 'action')
+ type2 = PolicyType.get_by_ref("action.mock_policy_error")
+ self.assertEqual(type2.name, "mock_policy_error")
+ self.assertEqual(type2.resource_type, "action")
def test_register_all_policies(self):
policies_dbs = Policy.get_all()
@@ -64,38 +62,29 @@ def test_register_all_policies(self):
policies = {
policies_db.name: {
- 'pack': policies_db.pack,
- 'type': policies_db.policy_type,
- 'parameters': policies_db.parameters
+ "pack": policies_db.pack,
+ "type": policies_db.policy_type,
+ "parameters": policies_db.parameters,
}
for policies_db in policies_dbs
}
expected_policies = {
- 'test_policy_1': {
- 'pack': 'dummy_pack_1',
- 'type': 'action.concurrency',
- 'parameters': {
- 'action': 'delay',
- 'threshold': 3
- }
+ "test_policy_1": {
+ "pack": "dummy_pack_1",
+ "type": "action.concurrency",
+ "parameters": {"action": "delay", "threshold": 3},
},
- 'test_policy_3': {
- 'pack': 'dummy_pack_1',
- 'type': 'action.retry',
- 'parameters': {
- 'retry_on': 'timeout',
- 'max_retry_count': 5
- }
+ "test_policy_3": {
+ "pack": "dummy_pack_1",
+ "type": "action.retry",
+ "parameters": {"retry_on": "timeout", "max_retry_count": 5},
+ },
+ "sequential.retry_on_failure": {
+ "pack": "orquesta_tests",
+ "type": "action.retry",
+ "parameters": {"retry_on": "failure", "max_retry_count": 1},
},
- 'sequential.retry_on_failure': {
- 'pack': 'orquesta_tests',
- 'type': 'action.retry',
- 'parameters': {
- 'retry_on': 'failure',
- 'max_retry_count': 1
- }
- }
}
self.assertEqual(len(expected_policies), count)
@@ -103,39 +92,49 @@ def test_register_all_policies(self):
self.assertDictEqual(expected_policies, policies)
def test_register_policies_from_pack(self):
- pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1')
+ pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1")
self.assertEqual(register_policies(pack_dir=pack_dir), 2)
- p1 = Policy.get_by_ref('dummy_pack_1.test_policy_1')
- self.assertEqual(p1.name, 'test_policy_1')
- self.assertEqual(p1.pack, 'dummy_pack_1')
- self.assertEqual(p1.resource_ref, 'dummy_pack_1.local')
- self.assertEqual(p1.policy_type, 'action.concurrency')
+ p1 = Policy.get_by_ref("dummy_pack_1.test_policy_1")
+ self.assertEqual(p1.name, "test_policy_1")
+ self.assertEqual(p1.pack, "dummy_pack_1")
+ self.assertEqual(p1.resource_ref, "dummy_pack_1.local")
+ self.assertEqual(p1.policy_type, "action.concurrency")
        # Verify that a default value for the "action" parameter, which isn't provided in the file, is set
- self.assertEqual(p1.parameters['action'], 'delay')
- self.assertEqual(p1.metadata_file, 'policies/policy_1.yaml')
+ self.assertEqual(p1.parameters["action"], "delay")
+ self.assertEqual(p1.metadata_file, "policies/policy_1.yaml")
- p2 = Policy.get_by_ref('dummy_pack_1.test_policy_2')
+ p2 = Policy.get_by_ref("dummy_pack_1.test_policy_2")
self.assertEqual(p2, None)
def test_register_policy_invalid_policy_type_references(self):
        # Policy references an invalid (nonexistent) policy type
registrar = PolicyRegistrar()
- policy_path = os.path.join(get_fixtures_packs_base_path(),
- 'dummy_pack_1/policies/policy_2.yaml')
+ policy_path = os.path.join(
+ get_fixtures_packs_base_path(), "dummy_pack_1/policies/policy_2.yaml"
+ )
expected_msg = 'Referenced policy_type "action.mock_policy_error" doesnt exist'
- self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_policy,
- pack='dummy_pack_1', policy=policy_path)
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar._register_policy,
+ pack="dummy_pack_1",
+ policy=policy_path,
+ )
def test_make_sure_policy_parameters_are_validated_during_register(self):
# Policy where specified parameters fail schema validation
registrar = PolicyRegistrar()
- policy_path = os.path.join(get_fixtures_packs_base_path(),
- 'dummy_pack_2/policies/policy_3.yaml')
-
- expected_msg = '100 is greater than the maximum of 5'
- self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg,
- registrar._register_policy,
- pack='dummy_pack_2',
- policy=policy_path)
+ policy_path = os.path.join(
+ get_fixtures_packs_base_path(), "dummy_pack_2/policies/policy_3.yaml"
+ )
+
+ expected_msg = "100 is greater than the maximum of 5"
+ self.assertRaisesRegexp(
+ jsonschema.ValidationError,
+ expected_msg,
+ registrar._register_policy,
+ pack="dummy_pack_2",
+ policy=policy_path,
+ )
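
The "100 is greater than the maximum of 5" message asserted above is jsonschema's stock error for a violated "maximum" bound, so the validation step the registrar performs can be reproduced standalone:

    import jsonschema

    schema = {
        "type": "object",
        "properties": {"max_retry_count": {"type": "integer", "maximum": 5}},
    }
    try:
        jsonschema.validate({"max_retry_count": 100}, schema)
    except jsonschema.ValidationError as e:
        assert "100 is greater than the maximum of 5" in str(e)
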
diff --git a/st2common/tests/unit/test_purge_executions.py b/st2common/tests/unit/test_purge_executions.py
index 5362cc753e..64ee4cfa67 100644
--- a/st2common/tests/unit/test_purge_executions.py
+++ b/st2common/tests/unit/test_purge_executions.py
@@ -34,18 +34,10 @@
LOG = logging.getLogger(__name__)
-TEST_FIXTURES = {
- 'executions': [
- 'execution1.yaml'
- ],
- 'liveactions': [
- 'liveaction4.yaml'
- ]
-}
+TEST_FIXTURES = {"executions": ["execution1.yaml"], "liveactions": ["liveaction4.yaml"]}
class TestPurgeExecutions(CleanDbTestCase):
-
@classmethod
def setUpClass(cls):
CleanDbTestCase.setUpClass()
@@ -54,114 +46,128 @@ def setUpClass(cls):
def setUp(self):
super(TestPurgeExecutions, self).setUp()
fixtures_loader = FixturesLoader()
- self.models = fixtures_loader.load_models(fixtures_pack='generic',
- fixtures_dict=TEST_FIXTURES)
+ self.models = fixtures_loader.load_models(
+ fixtures_pack="generic", fixtures_dict=TEST_FIXTURES
+ )
def test_no_timestamp_doesnt_delete_things(self):
now = date_utils.get_datetime_utc_now()
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = now - timedelta(days=15)
- exec_model['end_timestamp'] = now - timedelta(days=14)
- exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED
- exec_model['id'] = bson.ObjectId()
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = now - timedelta(days=15)
+ exec_model["end_timestamp"] = now - timedelta(days=14)
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3)
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=3
+ )
execs = ActionExecution.get_all()
self.assertEqual(len(execs), 1)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 3)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 3)
- expected_msg = 'Specify a valid timestamp'
- self.assertRaisesRegexp(ValueError, expected_msg, purge_executions,
- logger=LOG, timestamp=None)
+ expected_msg = "Specify a valid timestamp"
+ self.assertRaisesRegexp(
+ ValueError, expected_msg, purge_executions, logger=LOG, timestamp=None
+ )
execs = ActionExecution.get_all()
self.assertEqual(len(execs), 1)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 3)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 3)
def test_purge_executions_with_action_ref(self):
now = date_utils.get_datetime_utc_now()
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = now - timedelta(days=15)
- exec_model['end_timestamp'] = now - timedelta(days=14)
- exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED
- exec_model['id'] = bson.ObjectId()
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = now - timedelta(days=15)
+ exec_model["end_timestamp"] = now - timedelta(days=14)
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3)
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=3
+ )
execs = ActionExecution.get_all()
self.assertEqual(len(execs), 1)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 3)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 3)
# Invalid action reference, nothing should be deleted
- purge_executions(logger=LOG, action_ref='core.localzzz', timestamp=now - timedelta(days=10))
+ purge_executions(
+ logger=LOG, action_ref="core.localzzz", timestamp=now - timedelta(days=10)
+ )
execs = ActionExecution.get_all()
self.assertEqual(len(execs), 1)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 3)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 3)
- purge_executions(logger=LOG, action_ref='core.local', timestamp=now - timedelta(days=10))
+ purge_executions(
+ logger=LOG, action_ref="core.local", timestamp=now - timedelta(days=10)
+ )
execs = ActionExecution.get_all()
self.assertEqual(len(execs), 0)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 0)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 0)
def test_purge_executions_with_timestamp(self):
now = date_utils.get_datetime_utc_now()
# Write one execution after cut-off threshold
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = now - timedelta(days=15)
- exec_model['end_timestamp'] = now - timedelta(days=14)
- exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED
- exec_model['id'] = bson.ObjectId()
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = now - timedelta(days=15)
+ exec_model["end_timestamp"] = now - timedelta(days=14)
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3)
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=3
+ )
# Write one execution before cut-off threshold
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = now - timedelta(days=22)
- exec_model['end_timestamp'] = now - timedelta(days=21)
- exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED
- exec_model['id'] = bson.ObjectId()
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = now - timedelta(days=22)
+ exec_model["end_timestamp"] = now - timedelta(days=21)
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3)
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=3
+ )
execs = ActionExecution.get_all()
self.assertEqual(len(execs), 2)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 6)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 6)
purge_executions(logger=LOG, timestamp=now - timedelta(days=20))
execs = ActionExecution.get_all()
self.assertEqual(len(execs), 1)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 3)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 3)
def test_liveaction_gets_deleted(self):
@@ -169,19 +175,19 @@ def test_liveaction_gets_deleted(self):
start_ts = now - timedelta(days=15)
end_ts = now - timedelta(days=14)
- liveaction_model = copy.deepcopy(self.models['liveactions']['liveaction4.yaml'])
- liveaction_model['start_timestamp'] = start_ts
- liveaction_model['end_timestamp'] = end_ts
- liveaction_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED
+ liveaction_model = copy.deepcopy(self.models["liveactions"]["liveaction4.yaml"])
+ liveaction_model["start_timestamp"] = start_ts
+ liveaction_model["end_timestamp"] = end_ts
+ liveaction_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED
liveaction = LiveAction.add_or_update(liveaction_model)
# Write one execution before cut-off threshold
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = start_ts
- exec_model['end_timestamp'] = end_ts
- exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED
- exec_model['id'] = bson.ObjectId()
- exec_model['liveaction']['id'] = str(liveaction.id)
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = start_ts
+ exec_model["end_timestamp"] = end_ts
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED
+ exec_model["id"] = bson.ObjectId()
+ exec_model["liveaction"]["id"] = str(liveaction.id)
ActionExecution.add_or_update(exec_model)
liveactions = LiveAction.get_all()
@@ -201,110 +207,143 @@ def test_purge_incomplete(self):
start_ts = now - timedelta(days=15)
# Write executions before cut-off threshold
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = start_ts
- exec_model['status'] = action_constants.LIVEACTION_STATUS_SCHEDULED
- exec_model['id'] = bson.ObjectId()
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = start_ts
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_SCHEDULED
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1)
-
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = start_ts
- exec_model['status'] = action_constants.LIVEACTION_STATUS_RUNNING
- exec_model['id'] = bson.ObjectId()
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=1
+ )
+
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = start_ts
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_RUNNING
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1)
-
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = start_ts
- exec_model['status'] = action_constants.LIVEACTION_STATUS_DELAYED
- exec_model['id'] = bson.ObjectId()
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=1
+ )
+
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = start_ts
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_DELAYED
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1)
-
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = start_ts
- exec_model['status'] = action_constants.LIVEACTION_STATUS_CANCELING
- exec_model['id'] = bson.ObjectId()
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=1
+ )
+
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = start_ts
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_CANCELING
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1)
-
- exec_model = copy.deepcopy(self.models['executions']['execution1.yaml'])
- exec_model['start_timestamp'] = start_ts
- exec_model['status'] = action_constants.LIVEACTION_STATUS_REQUESTED
- exec_model['id'] = bson.ObjectId()
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=1
+ )
+
+ exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"])
+ exec_model["start_timestamp"] = start_ts
+ exec_model["status"] = action_constants.LIVEACTION_STATUS_REQUESTED
+ exec_model["id"] = bson.ObjectId()
ActionExecution.add_or_update(exec_model)
# Insert corresponding stdout and stderr db mock models
- self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1)
+ self._insert_mock_stdout_and_stderr_objects_for_execution(
+ exec_model["id"], count=1
+ )
self.assertEqual(len(ActionExecution.get_all()), 5)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 5)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 5)
        # Incomplete executions shouldn't be purged
- purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=False)
+ purge_executions(
+ logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=False
+ )
self.assertEqual(len(ActionExecution.get_all()), 5)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 5)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 5)
- purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True)
+ purge_executions(
+ logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True
+ )
self.assertEqual(len(ActionExecution.get_all()), 0)
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), 0)
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), 0)
- @mock.patch('st2common.garbage_collection.executions.LiveAction')
- @mock.patch('st2common.garbage_collection.executions.ActionExecution')
- def test_purge_executions_whole_model_is_not_loaded_in_memory(self, mock_ActionExecution,
- mock_LiveAction):
+ @mock.patch("st2common.garbage_collection.executions.LiveAction")
+ @mock.patch("st2common.garbage_collection.executions.ActionExecution")
+ def test_purge_executions_whole_model_is_not_loaded_in_memory(
+ self, mock_ActionExecution, mock_LiveAction
+ ):
        # Verify that whole execution objects are not loaded into memory and only the
        # id field is retrieved
self.assertEqual(mock_ActionExecution.query.call_count, 0)
self.assertEqual(mock_LiveAction.query.call_count, 0)
now = date_utils.get_datetime_utc_now()
- purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True)
+ purge_executions(
+ logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True
+ )
self.assertEqual(mock_ActionExecution.query.call_count, 2)
self.assertEqual(mock_LiveAction.query.call_count, 1)
- self.assertEqual(mock_ActionExecution.query.call_args_list[0][1]['only_fields'], ['id'])
- self.assertTrue(mock_ActionExecution.query.call_args_list[0][1]['no_dereference'])
- self.assertEqual(mock_ActionExecution.query.call_args_list[1][1]['only_fields'], ['id'])
- self.assertTrue(mock_ActionExecution.query.call_args_list[1][1]['no_dereference'])
- self.assertEqual(mock_LiveAction.query.call_args_list[0][1]['only_fields'], ['id'])
- self.assertTrue(mock_LiveAction.query.call_args_list[0][1]['no_dereference'])
-
- def _insert_mock_stdout_and_stderr_objects_for_execution(self, execution_id, count=5):
+ self.assertEqual(
+ mock_ActionExecution.query.call_args_list[0][1]["only_fields"], ["id"]
+ )
+ self.assertTrue(
+ mock_ActionExecution.query.call_args_list[0][1]["no_dereference"]
+ )
+ self.assertEqual(
+ mock_ActionExecution.query.call_args_list[1][1]["only_fields"], ["id"]
+ )
+ self.assertTrue(
+ mock_ActionExecution.query.call_args_list[1][1]["no_dereference"]
+ )
+ self.assertEqual(
+ mock_LiveAction.query.call_args_list[0][1]["only_fields"], ["id"]
+ )
+ self.assertTrue(mock_LiveAction.query.call_args_list[0][1]["no_dereference"])
+
+ def _insert_mock_stdout_and_stderr_objects_for_execution(
+ self, execution_id, count=5
+ ):
execution_id = str(execution_id)
stdout_dbs, stderr_dbs = [], []
for i in range(0, count):
- stdout_db = ActionExecutionOutputDB(execution_id=execution_id,
- action_ref='dummy.pack',
- runner_ref='dummy',
- output_type='stdout',
- data='stdout %s' % (i))
+ stdout_db = ActionExecutionOutputDB(
+ execution_id=execution_id,
+ action_ref="dummy.pack",
+ runner_ref="dummy",
+ output_type="stdout",
+ data="stdout %s" % (i),
+ )
ActionExecutionOutput.add_or_update(stdout_db)
- stderr_db = ActionExecutionOutputDB(execution_id=execution_id,
- action_ref='dummy.pack',
- runner_ref='dummy',
- output_type='stderr',
- data='stderr%s' % (i))
+ stderr_db = ActionExecutionOutputDB(
+ execution_id=execution_id,
+ action_ref="dummy.pack",
+ runner_ref="dummy",
+ output_type="stderr",
+ data="stderr%s" % (i),
+ )
ActionExecutionOutput.add_or_update(stderr_db)
return stdout_dbs, stderr_dbs
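
The mock-based test above locks in an efficiency property of the purge path: executions and live actions are queried with only_fields=['id'] and no_dereference=True, so whole documents are never pulled into memory just to collect ids. The shape of that call against st2's persistence wrapper, sketched with an assumed timestamp filter kwarg:

    def ids_to_purge(access, cutoff):
        # Fetch only the id field and skip dereferencing related documents.
        matching = access.query(
            end_timestamp__lt=cutoff,  # illustrative filter; the real one differs
            only_fields=["id"],
            no_dereference=True,
        )
        return [str(doc.id) for doc in matching]
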
diff --git a/st2common/tests/unit/test_purge_trigger_instances.py b/st2common/tests/unit/test_purge_trigger_instances.py
index 2cc9f6ffed..515c4040c3 100644
--- a/st2common/tests/unit/test_purge_trigger_instances.py
+++ b/st2common/tests/unit/test_purge_trigger_instances.py
@@ -28,7 +28,6 @@
class TestPurgeTriggerInstances(CleanDbTestCase):
-
@classmethod
def setUpClass(cls):
CleanDbTestCase.setUpClass()
@@ -40,32 +39,42 @@ def setUp(self):
def test_no_timestamp_doesnt_delete(self):
now = date_utils.get_datetime_utc_now()
- instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger',
- payload={'hola': 'hi', 'kuraci': 'chicken'},
- occurrence_time=now - timedelta(days=20),
- status=TRIGGER_INSTANCE_PROCESSED)
+ instance_db = TriggerInstanceDB(
+ trigger="purge_tool.dummy.trigger",
+ payload={"hola": "hi", "kuraci": "chicken"},
+ occurrence_time=now - timedelta(days=20),
+ status=TRIGGER_INSTANCE_PROCESSED,
+ )
TriggerInstance.add_or_update(instance_db)
self.assertEqual(len(TriggerInstance.get_all()), 1)
- expected_msg = 'Specify a valid timestamp'
- self.assertRaisesRegexp(ValueError, expected_msg,
- purge_trigger_instances,
- logger=LOG, timestamp=None)
+ expected_msg = "Specify a valid timestamp"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ purge_trigger_instances,
+ logger=LOG,
+ timestamp=None,
+ )
self.assertEqual(len(TriggerInstance.get_all()), 1)
def test_purge(self):
now = date_utils.get_datetime_utc_now()
- instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger',
- payload={'hola': 'hi', 'kuraci': 'chicken'},
- occurrence_time=now - timedelta(days=20),
- status=TRIGGER_INSTANCE_PROCESSED)
+ instance_db = TriggerInstanceDB(
+ trigger="purge_tool.dummy.trigger",
+ payload={"hola": "hi", "kuraci": "chicken"},
+ occurrence_time=now - timedelta(days=20),
+ status=TRIGGER_INSTANCE_PROCESSED,
+ )
TriggerInstance.add_or_update(instance_db)
- instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger',
- payload={'hola': 'hi', 'kuraci': 'chicken'},
- occurrence_time=now - timedelta(days=5),
- status=TRIGGER_INSTANCE_PROCESSED)
+ instance_db = TriggerInstanceDB(
+ trigger="purge_tool.dummy.trigger",
+ payload={"hola": "hi", "kuraci": "chicken"},
+ occurrence_time=now - timedelta(days=5),
+ status=TRIGGER_INSTANCE_PROCESSED,
+ )
TriggerInstance.add_or_update(instance_db)
self.assertEqual(len(TriggerInstance.get_all()), 2)
diff --git a/st2common/tests/unit/test_queue_consumer.py b/st2common/tests/unit/test_queue_consumer.py
index 4f54c325aa..463eb0def4 100644
--- a/st2common/tests/unit/test_queue_consumer.py
+++ b/st2common/tests/unit/test_queue_consumer.py
@@ -23,8 +23,8 @@
from tests.unit.base import FakeModelDB
-FAKE_XCHG = Exchange('st2.tests', type='topic')
-FAKE_WORK_Q = Queue('st2.tests.unit', FAKE_XCHG)
+FAKE_XCHG = Exchange("st2.tests", type="topic")
+FAKE_WORK_Q = Queue("st2.tests.unit", FAKE_XCHG)
class FakeMessageHandler(consumers.MessageHandler):
@@ -39,15 +39,14 @@ def get_handler():
class QueueConsumerTest(DbTestCase):
-
- @mock.patch.object(FakeMessageHandler, 'process', mock.MagicMock())
+ @mock.patch.object(FakeMessageHandler, "process", mock.MagicMock())
def test_process_message(self):
payload = FakeModelDB()
handler = get_handler()
handler._queue_consumer._process_message(payload)
FakeMessageHandler.process.assert_called_once_with(payload)
- @mock.patch.object(FakeMessageHandler, 'process', mock.MagicMock())
+ @mock.patch.object(FakeMessageHandler, "process", mock.MagicMock())
def test_process_message_wrong_payload_type(self):
payload = 100
handler = get_handler()
@@ -72,8 +71,7 @@ def get_staged_handler():
class StagedQueueConsumerTest(DbTestCase):
-
- @mock.patch.object(FakeStagedMessageHandler, 'pre_ack_process', mock.MagicMock())
+ @mock.patch.object(FakeStagedMessageHandler, "pre_ack_process", mock.MagicMock())
def test_process_message_pre_ack(self):
payload = FakeModelDB()
handler = get_staged_handler()
@@ -82,15 +80,16 @@ def test_process_message_pre_ack(self):
FakeStagedMessageHandler.pre_ack_process.assert_called_once_with(payload)
self.assertTrue(mock_message.ack.called)
- @mock.patch.object(BufferedDispatcher, 'dispatch', mock.MagicMock())
- @mock.patch.object(FakeStagedMessageHandler, 'process', mock.MagicMock())
+ @mock.patch.object(BufferedDispatcher, "dispatch", mock.MagicMock())
+ @mock.patch.object(FakeStagedMessageHandler, "process", mock.MagicMock())
def test_process_message(self):
payload = FakeModelDB()
handler = get_staged_handler()
mock_message = mock.MagicMock()
handler._queue_consumer.process(payload, mock_message)
BufferedDispatcher.dispatch.assert_called_once_with(
- handler._queue_consumer._process_message, payload)
+ handler._queue_consumer._process_message, payload
+ )
handler._queue_consumer._process_message(payload)
FakeStagedMessageHandler.process.assert_called_once_with(payload)
self.assertTrue(mock_message.ack.called)
@@ -104,13 +103,10 @@ def test_process_message_wrong_payload_type(self):
class FakeVariableMessageHandler(consumers.VariableMessageHandler):
-
def __init__(self, connection, queues):
super(FakeVariableMessageHandler, self).__init__(connection, queues)
- self.message_types = {
- FakeModelDB: self.handle_fake_model
- }
+ self.message_types = {FakeModelDB: self.handle_fake_model}
def process(self, message):
handler_function = self.message_types.get(type(message))
@@ -125,15 +121,16 @@ def get_variable_messages_handler():
class VariableMessageQueueConsumerTest(DbTestCase):
-
- @mock.patch.object(FakeVariableMessageHandler, 'handle_fake_model', mock.MagicMock())
+ @mock.patch.object(
+ FakeVariableMessageHandler, "handle_fake_model", mock.MagicMock()
+ )
def test_process_message(self):
payload = FakeModelDB()
handler = get_variable_messages_handler()
handler._queue_consumer._process_message(payload)
FakeVariableMessageHandler.handle_fake_model.assert_called_once_with(payload)
- @mock.patch.object(FakeVariableMessageHandler, 'process', mock.MagicMock())
+ @mock.patch.object(FakeVariableMessageHandler, "process", mock.MagicMock())
def test_process_message_wrong_payload_type(self):
payload = 100
handler = get_variable_messages_handler()
diff --git a/st2common/tests/unit/test_queue_utils.py b/st2common/tests/unit/test_queue_utils.py
index 52ad7a60dc..db77fc01c2 100644
--- a/st2common/tests/unit/test_queue_utils.py
+++ b/st2common/tests/unit/test_queue_utils.py
@@ -22,31 +22,42 @@
class TestQueueUtils(TestCase):
-
def test_get_queue_name(self):
- self.assertRaises(ValueError,
- queue_utils.get_queue_name,
- queue_name_base=None, queue_name_suffix=None)
- self.assertRaises(ValueError,
- queue_utils.get_queue_name,
- queue_name_base='', queue_name_suffix=None)
- self.assertEqual(queue_utils.get_queue_name(queue_name_base='st2.test.watch',
- queue_name_suffix=None),
- 'st2.test.watch')
- self.assertEqual(queue_utils.get_queue_name(queue_name_base='st2.test.watch',
- queue_name_suffix=''),
- 'st2.test.watch')
+ self.assertRaises(
+ ValueError,
+ queue_utils.get_queue_name,
+ queue_name_base=None,
+ queue_name_suffix=None,
+ )
+ self.assertRaises(
+ ValueError,
+ queue_utils.get_queue_name,
+ queue_name_base="",
+ queue_name_suffix=None,
+ )
+ self.assertEqual(
+ queue_utils.get_queue_name(
+ queue_name_base="st2.test.watch", queue_name_suffix=None
+ ),
+ "st2.test.watch",
+ )
+ self.assertEqual(
+ queue_utils.get_queue_name(
+ queue_name_base="st2.test.watch", queue_name_suffix=""
+ ),
+ "st2.test.watch",
+ )
queue_name = queue_utils.get_queue_name(
- queue_name_base='st2.test.watch',
- queue_name_suffix='foo',
- add_random_uuid_to_suffix=True
+ queue_name_base="st2.test.watch",
+ queue_name_suffix="foo",
+ add_random_uuid_to_suffix=True,
)
- pattern = re.compile(r'st2.test.watch.foo-\w')
+ pattern = re.compile(r"st2.test.watch.foo-\w")
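+        # Note: re.match() only anchors at the start of the string, so this merely
+        # checks for the "st2.test.watch.foo-" prefix followed by one word character.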
self.assertTrue(re.match(pattern, queue_name))
queue_name = queue_utils.get_queue_name(
- queue_name_base='st2.test.watch',
- queue_name_suffix='foo',
- add_random_uuid_to_suffix=False
+ queue_name_base="st2.test.watch",
+ queue_name_suffix="foo",
+ add_random_uuid_to_suffix=False,
)
- self.assertEqual(queue_name, 'st2.test.watch.foo')
+ self.assertEqual(queue_name, "st2.test.watch.foo")
diff --git a/st2common/tests/unit/test_rbac_types.py b/st2common/tests/unit/test_rbac_types.py
index d9d0a1dae8..03b5350cc9 100644
--- a/st2common/tests/unit/test_rbac_types.py
+++ b/st2common/tests/unit/test_rbac_types.py
@@ -22,158 +22,274 @@
class RBACPermissionTypeTestCase(TestCase):
-
def test_get_valid_permission_for_resource_type(self):
- valid_action_permissions = PermissionType.get_valid_permissions_for_resource_type(
- resource_type=ResourceType.ACTION)
+ valid_action_permissions = (
+ PermissionType.get_valid_permissions_for_resource_type(
+ resource_type=ResourceType.ACTION
+ )
+ )
for name in valid_action_permissions:
- self.assertTrue(name.startswith(ResourceType.ACTION + '_'))
+ self.assertTrue(name.startswith(ResourceType.ACTION + "_"))
valid_rule_permissions = PermissionType.get_valid_permissions_for_resource_type(
- resource_type=ResourceType.RULE)
+ resource_type=ResourceType.RULE
+ )
for name in valid_rule_permissions:
- self.assertTrue(name.startswith(ResourceType.RULE + '_'))
+ self.assertTrue(name.startswith(ResourceType.RULE + "_"))
def test_get_resource_type(self):
- self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_LIST),
- SystemType.PACK)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_VIEW),
- SystemType.PACK)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_CREATE),
- SystemType.PACK)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_MODIFY),
- SystemType.PACK)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_DELETE),
- SystemType.PACK)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_ALL),
- SystemType.PACK)
-
- self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_LIST),
- SystemType.SENSOR_TYPE)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_VIEW),
- SystemType.SENSOR_TYPE)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_MODIFY),
- SystemType.SENSOR_TYPE)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_ALL),
- SystemType.SENSOR_TYPE)
-
- self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_LIST),
- SystemType.ACTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_VIEW),
- SystemType.ACTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_CREATE),
- SystemType.ACTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_MODIFY),
- SystemType.ACTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_DELETE),
- SystemType.ACTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_EXECUTE),
- SystemType.ACTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_ALL),
- SystemType.ACTION)
-
- self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_LIST),
- SystemType.EXECUTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_VIEW),
- SystemType.EXECUTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_RE_RUN),
- SystemType.EXECUTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_STOP),
- SystemType.EXECUTION)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_ALL),
- SystemType.EXECUTION)
-
- self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_LIST),
- SystemType.RULE)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_VIEW),
- SystemType.RULE)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_CREATE),
- SystemType.RULE)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_MODIFY),
- SystemType.RULE)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_DELETE),
- SystemType.RULE)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ALL),
- SystemType.RULE)
-
- self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_LIST),
- SystemType.RULE_ENFORCEMENT)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_VIEW),
- SystemType.RULE_ENFORCEMENT)
-
- self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_VIEW),
- SystemType.KEY_VALUE_PAIR)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_SET),
- SystemType.KEY_VALUE_PAIR)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_DELETE),
- SystemType.KEY_VALUE_PAIR)
-
- self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_CREATE),
- SystemType.WEBHOOK)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_SEND),
- SystemType.WEBHOOK)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_DELETE),
- SystemType.WEBHOOK)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_ALL),
- SystemType.WEBHOOK)
-
- self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_LIST),
- SystemType.API_KEY)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_VIEW),
- SystemType.API_KEY)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_CREATE),
- SystemType.API_KEY)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_DELETE),
- SystemType.API_KEY)
- self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_ALL),
- SystemType.API_KEY)
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.PACK_LIST), SystemType.PACK
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.PACK_VIEW), SystemType.PACK
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.PACK_CREATE),
+ SystemType.PACK,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.PACK_MODIFY),
+ SystemType.PACK,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.PACK_DELETE),
+ SystemType.PACK,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.PACK_ALL), SystemType.PACK
+ )
+
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.SENSOR_LIST),
+ SystemType.SENSOR_TYPE,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.SENSOR_VIEW),
+ SystemType.SENSOR_TYPE,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.SENSOR_MODIFY),
+ SystemType.SENSOR_TYPE,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.SENSOR_ALL),
+ SystemType.SENSOR_TYPE,
+ )
+
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.ACTION_LIST),
+ SystemType.ACTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.ACTION_VIEW),
+ SystemType.ACTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.ACTION_CREATE),
+ SystemType.ACTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.ACTION_MODIFY),
+ SystemType.ACTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.ACTION_DELETE),
+ SystemType.ACTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.ACTION_EXECUTE),
+ SystemType.ACTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.ACTION_ALL),
+ SystemType.ACTION,
+ )
+
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.EXECUTION_LIST),
+ SystemType.EXECUTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.EXECUTION_VIEW),
+ SystemType.EXECUTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.EXECUTION_RE_RUN),
+ SystemType.EXECUTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.EXECUTION_STOP),
+ SystemType.EXECUTION,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.EXECUTION_ALL),
+ SystemType.EXECUTION,
+ )
+
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.RULE_LIST), SystemType.RULE
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.RULE_VIEW), SystemType.RULE
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.RULE_CREATE),
+ SystemType.RULE,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.RULE_MODIFY),
+ SystemType.RULE,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.RULE_DELETE),
+ SystemType.RULE,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.RULE_ALL), SystemType.RULE
+ )
+
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_LIST),
+ SystemType.RULE_ENFORCEMENT,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_VIEW),
+ SystemType.RULE_ENFORCEMENT,
+ )
+
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.KEY_VALUE_VIEW),
+ SystemType.KEY_VALUE_PAIR,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.KEY_VALUE_SET),
+ SystemType.KEY_VALUE_PAIR,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.KEY_VALUE_DELETE),
+ SystemType.KEY_VALUE_PAIR,
+ )
+
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.WEBHOOK_CREATE),
+ SystemType.WEBHOOK,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.WEBHOOK_SEND),
+ SystemType.WEBHOOK,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.WEBHOOK_DELETE),
+ SystemType.WEBHOOK,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.WEBHOOK_ALL),
+ SystemType.WEBHOOK,
+ )
+
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.API_KEY_LIST),
+ SystemType.API_KEY,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.API_KEY_VIEW),
+ SystemType.API_KEY,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.API_KEY_CREATE),
+ SystemType.API_KEY,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.API_KEY_DELETE),
+ SystemType.API_KEY,
+ )
+ self.assertEqual(
+ PermissionType.get_resource_type(PermissionType.API_KEY_ALL),
+ SystemType.API_KEY,
+ )
def test_get_permission_type(self):
- self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION,
- permission_name='view'),
- PermissionType.ACTION_VIEW)
- self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION,
- permission_name='all'),
- PermissionType.ACTION_ALL)
- self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION,
- permission_name='execute'),
- PermissionType.ACTION_EXECUTE)
- self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.RULE,
- permission_name='view'),
- PermissionType.RULE_VIEW)
- self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.RULE,
- permission_name='delete'),
- PermissionType.RULE_DELETE)
- self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR,
- permission_name='view'),
- PermissionType.SENSOR_VIEW)
- self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR,
- permission_name='all'),
- PermissionType.SENSOR_ALL)
- self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR,
- permission_name='modify'),
- PermissionType.SENSOR_MODIFY)
- self.assertEqual(
- PermissionType.get_permission_type(resource_type=ResourceType.RULE_ENFORCEMENT,
- permission_name='view'),
- PermissionType.RULE_ENFORCEMENT_VIEW)
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.ACTION, permission_name="view"
+ ),
+ PermissionType.ACTION_VIEW,
+ )
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.ACTION, permission_name="all"
+ ),
+ PermissionType.ACTION_ALL,
+ )
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.ACTION, permission_name="execute"
+ ),
+ PermissionType.ACTION_EXECUTE,
+ )
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.RULE, permission_name="view"
+ ),
+ PermissionType.RULE_VIEW,
+ )
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.RULE, permission_name="delete"
+ ),
+ PermissionType.RULE_DELETE,
+ )
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.SENSOR, permission_name="view"
+ ),
+ PermissionType.SENSOR_VIEW,
+ )
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.SENSOR, permission_name="all"
+ ),
+ PermissionType.SENSOR_ALL,
+ )
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.SENSOR, permission_name="modify"
+ ),
+ PermissionType.SENSOR_MODIFY,
+ )
+ self.assertEqual(
+ PermissionType.get_permission_type(
+ resource_type=ResourceType.RULE_ENFORCEMENT, permission_name="view"
+ ),
+ PermissionType.RULE_ENFORCEMENT_VIEW,
+ )
def test_get_permission_name(self):
- self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_LIST),
- 'list')
- self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_CREATE),
- 'create')
- self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_DELETE),
- 'delete')
- self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_ALL),
- 'all')
- self.assertEqual(PermissionType.get_permission_name(PermissionType.PACK_ALL),
- 'all')
- self.assertEqual(PermissionType.get_permission_name(PermissionType.SENSOR_MODIFY),
- 'modify')
- self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_EXECUTE),
- 'execute')
- self.assertEqual(PermissionType.get_permission_name(PermissionType.RULE_ENFORCEMENT_LIST),
- 'list')
+ self.assertEqual(
+ PermissionType.get_permission_name(PermissionType.ACTION_LIST), "list"
+ )
+ self.assertEqual(
+ PermissionType.get_permission_name(PermissionType.ACTION_CREATE), "create"
+ )
+ self.assertEqual(
+ PermissionType.get_permission_name(PermissionType.ACTION_DELETE), "delete"
+ )
+ self.assertEqual(
+ PermissionType.get_permission_name(PermissionType.ACTION_ALL), "all"
+ )
+ self.assertEqual(
+ PermissionType.get_permission_name(PermissionType.PACK_ALL), "all"
+ )
+ self.assertEqual(
+ PermissionType.get_permission_name(PermissionType.SENSOR_MODIFY), "modify"
+ )
+ self.assertEqual(
+ PermissionType.get_permission_name(PermissionType.ACTION_EXECUTE), "execute"
+ )
+ self.assertEqual(
+ PermissionType.get_permission_name(PermissionType.RULE_ENFORCEMENT_LIST),
+ "list",
+ )
diff --git a/st2common/tests/unit/test_reference.py b/st2common/tests/unit/test_reference.py
index f39800c2dd..ced486a867 100644
--- a/st2common/tests/unit/test_reference.py
+++ b/st2common/tests/unit/test_reference.py
@@ -26,35 +26,34 @@
from st2tests import DbTestCase
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class ReferenceTest(DbTestCase):
__model = None
__ref = None
@classmethod
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def setUpClass(cls):
super(ReferenceTest, cls).setUpClass()
- trigger = TriggerDB(pack='dummy_pack_1', name='trigger-1')
+ trigger = TriggerDB(pack="dummy_pack_1", name="trigger-1")
cls.__model = Trigger.add_or_update(trigger)
- cls.__ref = {'id': str(cls.__model.id),
- 'name': cls.__model.name}
+ cls.__ref = {"id": str(cls.__model.id), "name": cls.__model.name}
@classmethod
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def tearDownClass(cls):
Trigger.delete(cls.__model)
super(ReferenceTest, cls).tearDownClass()
def test_to_reference(self):
ref = reference.get_ref_from_model(self.__model)
- self.assertEqual(ref, self.__ref, 'Failed to generated equivalent ref.')
+        self.assertEqual(ref, self.__ref, "Failed to generate equivalent ref.")
def test_to_reference_no_model(self):
try:
reference.get_ref_from_model(None)
- self.assertTrue(False, 'Exception expected.')
+ self.assertTrue(False, "Exception expected.")
except ValueError:
self.assertTrue(True)
@@ -63,37 +62,37 @@ def test_to_reference_no_model_id(self):
model = copy.copy(self.__model)
model.id = None
reference.get_ref_from_model(model)
- self.assertTrue(False, 'Exception expected.')
+ self.assertTrue(False, "Exception expected.")
except db.StackStormDBObjectMalformedError:
self.assertTrue(True)
def test_to_model_with_id(self):
model = reference.get_model_from_ref(Trigger, self.__ref)
- self.assertEqual(model, self.__model, 'Failed to return correct model.')
+ self.assertEqual(model, self.__model, "Failed to return correct model.")
def test_to_model_with_name(self):
ref = copy.copy(self.__ref)
- ref['id'] = None
+ ref["id"] = None
model = reference.get_model_from_ref(Trigger, ref)
- self.assertEqual(model, self.__model, 'Failed to return correct model.')
+ self.assertEqual(model, self.__model, "Failed to return correct model.")
def test_to_model_no_name_no_id(self):
try:
reference.get_model_from_ref(Trigger, {})
- self.assertTrue(False, 'Exception expected.')
+ self.assertTrue(False, "Exception expected.")
except db.StackStormDBObjectNotFoundError:
self.assertTrue(True)
def test_to_model_unknown_id(self):
try:
- reference.get_model_from_ref(Trigger, {'id': '1'})
- self.assertTrue(False, 'Exception expected.')
+ reference.get_model_from_ref(Trigger, {"id": "1"})
+ self.assertTrue(False, "Exception expected.")
except mongoengine.ValidationError:
self.assertTrue(True)
def test_to_model_unknown_name(self):
try:
- reference.get_model_from_ref(Trigger, {'name': 'unknown'})
- self.assertTrue(False, 'Exception expected.')
+ reference.get_model_from_ref(Trigger, {"name": "unknown"})
+ self.assertTrue(False, "Exception expected.")
except db.StackStormDBObjectNotFoundError:
self.assertTrue(True)
diff --git a/st2common/tests/unit/test_register_internal_trigger.py b/st2common/tests/unit/test_register_internal_trigger.py
index dd4959611f..3d33e32548 100644
--- a/st2common/tests/unit/test_register_internal_trigger.py
+++ b/st2common/tests/unit/test_register_internal_trigger.py
@@ -20,7 +20,6 @@
class TestRegisterInternalTriggers(DbTestCase):
-
def test_register_internal_trigger_types(self):
registered_trigger_types_db = register_internal_trigger_types()
for trigger_type_db in registered_trigger_types_db:
@@ -31,4 +30,6 @@ def _validate_shadow_trigger(self, trigger_type_db):
return
trigger_type_ref = trigger_type_db.get_reference().ref
triggers = Trigger.query(type=trigger_type_ref)
- self.assertTrue(len(triggers) > 0, 'Shadow trigger not created for %s.' % trigger_type_ref)
+ self.assertTrue(
+ len(triggers) > 0, "Shadow trigger not created for %s." % trigger_type_ref
+ )
diff --git a/st2common/tests/unit/test_resource_reference.py b/st2common/tests/unit/test_resource_reference.py
index 04dfdc9357..95533022ed 100644
--- a/st2common/tests/unit/test_resource_reference.py
+++ b/st2common/tests/unit/test_resource_reference.py
@@ -22,45 +22,64 @@
class ResourceReferenceTestCase(unittest2.TestCase):
def test_resource_reference_success(self):
- value = 'pack1.name1'
+ value = "pack1.name1"
ref = ResourceReference.from_string_reference(ref=value)
- self.assertEqual(ref.pack, 'pack1')
- self.assertEqual(ref.name, 'name1')
+ self.assertEqual(ref.pack, "pack1")
+ self.assertEqual(ref.name, "name1")
self.assertEqual(ref.ref, value)
- ref = ResourceReference(pack='pack1', name='name1')
- self.assertEqual(ref.ref, 'pack1.name1')
+ ref = ResourceReference(pack="pack1", name="name1")
+ self.assertEqual(ref.ref, "pack1.name1")
- ref = ResourceReference(pack='pack1', name='name1.name2')
- self.assertEqual(ref.ref, 'pack1.name1.name2')
+ ref = ResourceReference(pack="pack1", name="name1.name2")
+ self.assertEqual(ref.ref, "pack1.name1.name2")
def test_resource_reference_failure(self):
- self.assertRaises(InvalidResourceReferenceError,
- ResourceReference.from_string_reference,
- ref='blah')
+ self.assertRaises(
+ InvalidResourceReferenceError,
+ ResourceReference.from_string_reference,
+ ref="blah",
+ )
- self.assertRaises(InvalidResourceReferenceError,
- ResourceReference.from_string_reference,
- ref=None)
+ self.assertRaises(
+ InvalidResourceReferenceError,
+ ResourceReference.from_string_reference,
+ ref=None,
+ )
def test_to_string_reference(self):
- ref = ResourceReference.to_string_reference(pack='mapack', name='moname')
- self.assertEqual(ref, 'mapack.moname')
+ ref = ResourceReference.to_string_reference(pack="mapack", name="moname")
+ self.assertEqual(ref, "mapack.moname")
expected_msg = r'Pack name should not contain "\."'
- self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference,
- pack='pack.invalid', name='bar')
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ ResourceReference.to_string_reference,
+ pack="pack.invalid",
+ name="bar",
+ )
- expected_msg = 'Both pack and name needed for building'
- self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference,
- pack='pack', name=None)
+ expected_msg = "Both pack and name needed for building"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ ResourceReference.to_string_reference,
+ pack="pack",
+ name=None,
+ )
- expected_msg = 'Both pack and name needed for building'
- self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference,
- pack=None, name='name')
+ expected_msg = "Both pack and name needed for building"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ ResourceReference.to_string_reference,
+ pack=None,
+ name="name",
+ )
def test_is_resource_reference(self):
- self.assertTrue(ResourceReference.is_resource_reference('foo.bar'))
- self.assertTrue(ResourceReference.is_resource_reference('foo.bar.ponies'))
- self.assertFalse(ResourceReference.is_resource_reference('foo'))
+ self.assertTrue(ResourceReference.is_resource_reference("foo.bar"))
+ self.assertTrue(ResourceReference.is_resource_reference("foo.bar.ponies"))
+ self.assertFalse(ResourceReference.is_resource_reference("foo"))
diff --git a/st2common/tests/unit/test_resource_registrar.py b/st2common/tests/unit/test_resource_registrar.py
index 9850785f21..2a1c61ad6a 100644
--- a/st2common/tests/unit/test_resource_registrar.py
+++ b/st2common/tests/unit/test_resource_registrar.py
@@ -30,23 +30,21 @@
from st2tests.fixturesloader import get_fixtures_base_path
-__all__ = [
- 'ResourceRegistrarTestCase'
-]
-
-PACK_PATH_1 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_1')
-PACK_PATH_6 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_6')
-PACK_PATH_7 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_7')
-PACK_PATH_8 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_8')
-PACK_PATH_9 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_9')
-PACK_PATH_10 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_10')
-PACK_PATH_12 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_12')
-PACK_PATH_13 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_13')
-PACK_PATH_14 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_14')
-PACK_PATH_17 = os.path.join(get_fixtures_base_path(), 'packs_invalid/dummy_pack_17')
-PACK_PATH_18 = os.path.join(get_fixtures_base_path(), 'packs_invalid/dummy_pack_18')
-PACK_PATH_20 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_20')
-PACK_PATH_21 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_21')
+__all__ = ["ResourceRegistrarTestCase"]
+
+PACK_PATH_1 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_1")
+PACK_PATH_6 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_6")
+PACK_PATH_7 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_7")
+PACK_PATH_8 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_8")
+PACK_PATH_9 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_9")
+PACK_PATH_10 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_10")
+PACK_PATH_12 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_12")
+PACK_PATH_13 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_13")
+PACK_PATH_14 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_14")
+PACK_PATH_17 = os.path.join(get_fixtures_base_path(), "packs_invalid/dummy_pack_17")
+PACK_PATH_18 = os.path.join(get_fixtures_base_path(), "packs_invalid/dummy_pack_18")
+PACK_PATH_20 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_20")
+PACK_PATH_21 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_21")
class ResourceRegistrarTestCase(CleanDbTestCase):
@@ -60,7 +58,7 @@ def test_register_packs(self):
registrar = ResourceRegistrar(use_pack_cache=False)
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_1': PACK_PATH_1}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_1": PACK_PATH_1}
packs_base_paths = content_utils.get_packs_base_paths()
registrar.register_packs(base_dirs=packs_base_paths)
@@ -74,20 +72,20 @@ def test_register_packs(self):
pack_db = pack_dbs[0]
config_schema_db = config_schema_dbs[0]
- self.assertEqual(pack_db.name, 'dummy_pack_1')
+ self.assertEqual(pack_db.name, "dummy_pack_1")
self.assertEqual(len(pack_db.contributors), 2)
- self.assertEqual(pack_db.contributors[0], 'John Doe1 ')
- self.assertEqual(pack_db.contributors[1], 'John Doe2 ')
- self.assertIn('api_key', config_schema_db.attributes)
- self.assertIn('api_secret', config_schema_db.attributes)
+ self.assertEqual(pack_db.contributors[0], "John Doe1 ")
+ self.assertEqual(pack_db.contributors[1], "John Doe2 ")
+ self.assertIn("api_key", config_schema_db.attributes)
+ self.assertIn("api_secret", config_schema_db.attributes)
# Verify pack_db.files is correct and doesn't contain excluded files (*.pyc, .git/*, etc.)
        # Note: We can't test that .git/* files are excluded since git doesn't allow you to
        # add a .git directory to an existing repo index :/
excluded_files = [
- '__init__.pyc',
- 'actions/dummy1.pyc',
- 'actions/dummy2.pyc',
+ "__init__.pyc",
+ "actions/dummy1.pyc",
+ "actions/dummy2.pyc",
]
for excluded_file in excluded_files:
@@ -100,14 +98,14 @@ def test_register_pack_arbitrary_properties_are_allowed(self):
registrar = ResourceRegistrar(use_pack_cache=False)
registrar._pack_loader.get_packs = mock.Mock()
registrar._pack_loader.get_packs.return_value = {
- 'dummy_pack_20': PACK_PATH_20,
+ "dummy_pack_20": PACK_PATH_20,
}
packs_base_paths = content_utils.get_packs_base_paths()
registrar.register_packs(base_dirs=packs_base_paths)
# Ref is provided
- pack_db = Pack.get_by_name('dummy_pack_20')
- self.assertEqual(pack_db.ref, 'dummy_pack_20_ref')
+ pack_db = Pack.get_by_name("dummy_pack_20")
+ self.assertEqual(pack_db.ref, "dummy_pack_20_ref")
self.assertEqual(len(pack_db.contributors), 0)
def test_register_pack_pack_ref(self):
@@ -119,53 +117,74 @@ def test_register_pack_pack_ref(self):
registrar = ResourceRegistrar(use_pack_cache=False)
registrar._pack_loader.get_packs = mock.Mock()
registrar._pack_loader.get_packs.return_value = {
- 'dummy_pack_1': PACK_PATH_1,
- 'dummy_pack_6': PACK_PATH_6
+ "dummy_pack_1": PACK_PATH_1,
+ "dummy_pack_6": PACK_PATH_6,
}
packs_base_paths = content_utils.get_packs_base_paths()
registrar.register_packs(base_dirs=packs_base_paths)
# Ref is provided
- pack_db = Pack.get_by_name('dummy_pack_6')
- self.assertEqual(pack_db.ref, 'dummy_pack_6_ref')
+ pack_db = Pack.get_by_name("dummy_pack_6")
+ self.assertEqual(pack_db.ref, "dummy_pack_6_ref")
self.assertEqual(len(pack_db.contributors), 0)
        # Ref is not provided, so the directory name should be used
- pack_db = Pack.get_by_name('dummy_pack_1')
- self.assertEqual(pack_db.ref, 'dummy_pack_1')
+ pack_db = Pack.get_by_name("dummy_pack_1")
+ self.assertEqual(pack_db.ref, "dummy_pack_1")
# "ref" is not provided, but "name" is
registrar._register_pack_db(pack_name=None, pack_dir=PACK_PATH_7)
- pack_db = Pack.get_by_name('dummy_pack_7_name')
- self.assertEqual(pack_db.ref, 'dummy_pack_7_name')
+ pack_db = Pack.get_by_name("dummy_pack_7_name")
+ self.assertEqual(pack_db.ref, "dummy_pack_7_name")
# "ref" is not provided and "name" contains invalid characters
- expected_msg = 'contains invalid characters'
- self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_pack_db,
- pack_name=None, pack_dir=PACK_PATH_8)
+ expected_msg = "contains invalid characters"
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar._register_pack_db,
+ pack_name=None,
+ pack_dir=PACK_PATH_8,
+ )
def test_register_pack_invalid_ref_name_friendly_error_message(self):
registrar = ResourceRegistrar(use_pack_cache=False)
# Invalid ref
- expected_msg = (r'Pack ref / name can only contain valid word characters .*?,'
- ' dashes are not allowed.')
- self.assertRaisesRegexp(ValidationError, expected_msg, registrar._register_pack_db,
- pack_name=None, pack_dir=PACK_PATH_13)
+ expected_msg = (
+ r"Pack ref / name can only contain valid word characters .*?,"
+ " dashes are not allowed."
+ )
+ self.assertRaisesRegexp(
+ ValidationError,
+ expected_msg,
+ registrar._register_pack_db,
+ pack_name=None,
+ pack_dir=PACK_PATH_13,
+ )
try:
registrar._register_pack_db(pack_name=None, pack_dir=PACK_PATH_13)
except ValidationError as e:
- self.assertIn("'invalid-has-dash' does not match '^[a-z0-9_]+$'", six.text_type(e))
+ self.assertIn(
+ "'invalid-has-dash' does not match '^[a-z0-9_]+$'", six.text_type(e)
+ )
else:
- self.fail('Exception not thrown')
+ self.fail("Exception not thrown")
# Pack ref not provided and name doesn't contain valid characters
- expected_msg = (r'Pack name "dummy pack 14" contains invalid characters and "ref" '
- 'attribute is not available. You either need to add')
- self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_pack_db,
- pack_name=None, pack_dir=PACK_PATH_14)
+ expected_msg = (
+ r'Pack name "dummy pack 14" contains invalid characters and "ref" '
+ "attribute is not available. You either need to add"
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar._register_pack_db,
+ pack_name=None,
+ pack_dir=PACK_PATH_14,
+ )
def test_register_pack_pack_stackstorm_version_and_future_parameters(self):
# Verify DB is empty
@@ -174,53 +193,74 @@ def test_register_pack_pack_stackstorm_version_and_future_parameters(self):
registrar = ResourceRegistrar(use_pack_cache=False)
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_9': PACK_PATH_9}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_9": PACK_PATH_9}
packs_base_paths = content_utils.get_packs_base_paths()
registrar.register_packs(base_dirs=packs_base_paths)
# Dependencies, stackstorm_version and future values
- pack_db = Pack.get_by_name('dummy_pack_9_deps')
- self.assertEqual(pack_db.dependencies, ['core=0.2.0'])
- self.assertEqual(pack_db.stackstorm_version, '>=1.6.0, <2.2.0')
- self.assertEqual(pack_db.system, {'centos': {'foo': '>= 1.0'}})
- self.assertEqual(pack_db.python_versions, ['2', '3'])
+ pack_db = Pack.get_by_name("dummy_pack_9_deps")
+ self.assertEqual(pack_db.dependencies, ["core=0.2.0"])
+ self.assertEqual(pack_db.stackstorm_version, ">=1.6.0, <2.2.0")
+ self.assertEqual(pack_db.system, {"centos": {"foo": ">= 1.0"}})
+ self.assertEqual(pack_db.python_versions, ["2", "3"])
        # Note: We only store parameters which are defined in the schema; all other
        # custom user-defined attributes are ignored
- self.assertTrue(not hasattr(pack_db, 'future'))
- self.assertTrue(not hasattr(pack_db, 'this'))
+ self.assertTrue(not hasattr(pack_db, "future"))
+ self.assertTrue(not hasattr(pack_db, "this"))
# Wrong characters in the required st2 version
expected_msg = "'wrongstackstormversion' does not match"
- self.assertRaisesRegexp(ValidationError, expected_msg, registrar._register_pack_db,
- pack_name=None, pack_dir=PACK_PATH_10)
+ self.assertRaisesRegexp(
+ ValidationError,
+ expected_msg,
+ registrar._register_pack_db,
+ pack_name=None,
+ pack_dir=PACK_PATH_10,
+ )
def test_register_pack_empty_and_invalid_config_schema(self):
registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True)
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_17': PACK_PATH_17}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_17": PACK_PATH_17}
packs_base_paths = content_utils.get_packs_base_paths()
- expected_msg = 'Config schema ".*?dummy_pack_17/config.schema.yaml" is empty and invalid.'
- self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs,
- base_dirs=packs_base_paths)
+ expected_msg = (
+ 'Config schema ".*?dummy_pack_17/config.schema.yaml" is empty and invalid.'
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar.register_packs,
+ base_dirs=packs_base_paths,
+ )
def test_register_pack_invalid_config_schema_invalid_attribute(self):
registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True)
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_18': PACK_PATH_18}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_18": PACK_PATH_18}
packs_base_paths = content_utils.get_packs_base_paths()
- expected_msg = r'Additional properties are not allowed \(\'invalid\' was unexpected\)'
- self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs,
- base_dirs=packs_base_paths)
+ expected_msg = (
+ r"Additional properties are not allowed \(\'invalid\' was unexpected\)"
+ )
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar.register_packs,
+ base_dirs=packs_base_paths,
+ )
def test_register_pack_invalid_python_versions_attribute(self):
registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True)
registrar._pack_loader.get_packs = mock.Mock()
- registrar._pack_loader.get_packs.return_value = {'dummy_pack_21': PACK_PATH_21}
+ registrar._pack_loader.get_packs.return_value = {"dummy_pack_21": PACK_PATH_21}
packs_base_paths = content_utils.get_packs_base_paths()
expected_msg = r"'4' is not one of \['2', '3'\]"
- self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs,
- base_dirs=packs_base_paths)
+ self.assertRaisesRegexp(
+ ValueError,
+ expected_msg,
+ registrar.register_packs,
+ base_dirs=packs_base_paths,
+ )
diff --git a/st2common/tests/unit/test_runners_base.py b/st2common/tests/unit/test_runners_base.py
index 34ede41adf..7490b40cd6 100644
--- a/st2common/tests/unit/test_runners_base.py
+++ b/st2common/tests/unit/test_runners_base.py
@@ -23,11 +23,12 @@
class RunnersLoaderUtilsTestCase(DbTestCase):
def test_get_runner_success(self):
- runner = get_runner('local-shell-cmd')
+ runner = get_runner("local-shell-cmd")
self.assertTrue(runner)
- self.assertEqual(runner.__class__.__name__, 'LocalShellCommandRunner')
+ self.assertEqual(runner.__class__.__name__, "LocalShellCommandRunner")
def test_get_runner_failure_not_found(self):
- expected_msg = 'Failed to find runner invalid-name-not-found.*'
- self.assertRaisesRegexp(ActionRunnerCreateError, expected_msg,
- get_runner, 'invalid-name-not-found')
+ expected_msg = "Failed to find runner invalid-name-not-found.*"
+ self.assertRaisesRegexp(
+ ActionRunnerCreateError, expected_msg, get_runner, "invalid-name-not-found"
+ )
diff --git a/st2common/tests/unit/test_runners_utils.py b/st2common/tests/unit/test_runners_utils.py
index dc98848223..bc6acfcf7e 100644
--- a/st2common/tests/unit/test_runners_utils.py
+++ b/st2common/tests/unit/test_runners_utils.py
@@ -24,16 +24,17 @@
from st2tests import config as tests_config
+
tests_config.parse_args()
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
TEST_FIXTURES = {
- 'liveactions': ['liveaction1.yaml'],
- 'actions': ['local.yaml'],
- 'executions': ['execution1.yaml'],
- 'runners': ['run-local.yaml']
+ "liveactions": ["liveaction1.yaml"],
+ "actions": ["local.yaml"],
+ "executions": ["execution1.yaml"],
+ "runners": ["run-local.yaml"],
}
@@ -48,15 +49,16 @@ def setUp(self):
loader = fixturesloader.FixturesLoader()
self.models = loader.save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_FIXTURES
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
)
- self.liveaction_db = self.models['liveactions']['liveaction1.yaml']
+ self.liveaction_db = self.models["liveactions"]["liveaction1.yaml"]
exe_svc.create_execution_object(self.liveaction_db)
self.action_db = action_db_utils.get_action_by_ref(self.liveaction_db.action)
- @mock.patch.object(action_db_utils, 'get_action_by_ref', mock.MagicMock(return_value=None))
+ @mock.patch.object(
+ action_db_utils, "get_action_by_ref", mock.MagicMock(return_value=None)
+ )
def test_invoke_post_run_action_provided(self):
utils.invoke_post_run(self.liveaction_db, action_db=self.action_db)
action_db_utils.get_action_by_ref.assert_not_called()
@@ -64,8 +66,12 @@ def test_invoke_post_run_action_provided(self):
def test_invoke_post_run_action_exists(self):
utils.invoke_post_run(self.liveaction_db)
- @mock.patch.object(action_db_utils, 'get_action_by_ref', mock.MagicMock(return_value=None))
- @mock.patch.object(action_db_utils, 'get_runnertype_by_name', mock.MagicMock(return_value=None))
+ @mock.patch.object(
+ action_db_utils, "get_action_by_ref", mock.MagicMock(return_value=None)
+ )
+ @mock.patch.object(
+ action_db_utils, "get_runnertype_by_name", mock.MagicMock(return_value=None)
+ )
def test_invoke_post_run_action_does_not_exist(self):
utils.invoke_post_run(self.liveaction_db)
action_db_utils.get_action_by_ref.assert_called_once()
diff --git a/st2common/tests/unit/test_sensor_type_utils.py b/st2common/tests/unit/test_sensor_type_utils.py
index 08269ebcf2..657054c453 100644
--- a/st2common/tests/unit/test_sensor_type_utils.py
+++ b/st2common/tests/unit/test_sensor_type_utils.py
@@ -22,59 +22,67 @@
class SensorTypeUtilsTestCase(unittest2.TestCase):
-
def test_to_sensor_db_model_no_trigger_types(self):
sensor_meta = {
- 'artifact_uri': 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py',
- 'class_name': 'JIRASensor',
- 'pack': 'jira'
+ "artifact_uri": "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py",
+ "class_name": "JIRASensor",
+ "pack": "jira",
}
sensor_api = SensorTypeAPI(**sensor_meta)
sensor_model = SensorTypeAPI.to_model(sensor_api)
- self.assertEqual(sensor_model.name, sensor_meta['class_name'])
- self.assertEqual(sensor_model.pack, sensor_meta['pack'])
- self.assertEqual(sensor_model.artifact_uri, sensor_meta['artifact_uri'])
+ self.assertEqual(sensor_model.name, sensor_meta["class_name"])
+ self.assertEqual(sensor_model.pack, sensor_meta["pack"])
+ self.assertEqual(sensor_model.artifact_uri, sensor_meta["artifact_uri"])
self.assertListEqual(sensor_model.trigger_types, [])
- @mock.patch.object(sensor_type_utils, 'create_trigger_types', mock.MagicMock(
- return_value=['mock.trigger_ref']))
+ @mock.patch.object(
+ sensor_type_utils,
+ "create_trigger_types",
+ mock.MagicMock(return_value=["mock.trigger_ref"]),
+ )
def test_to_sensor_db_model_with_trigger_types(self):
sensor_meta = {
- 'artifact_uri': 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py',
- 'class_name': 'JIRASensor',
- 'pack': 'jira',
- 'trigger_types': [{'pack': 'jira', 'name': 'issue_created', 'parameters': {}}]
+ "artifact_uri": "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py",
+ "class_name": "JIRASensor",
+ "pack": "jira",
+ "trigger_types": [
+ {"pack": "jira", "name": "issue_created", "parameters": {}}
+ ],
}
sensor_api = SensorTypeAPI(**sensor_meta)
sensor_model = SensorTypeAPI.to_model(sensor_api)
- self.assertListEqual(sensor_model.trigger_types, ['mock.trigger_ref'])
+ self.assertListEqual(sensor_model.trigger_types, ["mock.trigger_ref"])
def test_get_sensor_entry_point(self):
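+        # For system packs the entry point should be just the class name; for regular
+        # packs it should be the dotted module path relative to the pack directory.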
# System packs
- file_path = 'file:///data/st/st2reactor/st2reactor/' + \
- 'contrib/sensors/st2_generic_webhook_sensor.py'
- class_name = 'St2GenericWebhooksSensor'
+ file_path = (
+ "file:///data/st/st2reactor/st2reactor/"
+ + "contrib/sensors/st2_generic_webhook_sensor.py"
+ )
+ class_name = "St2GenericWebhooksSensor"
- sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'core'}
+ sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "core"}
sensor_api = SensorTypeAPI(**sensor)
entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api)
self.assertEqual(entry_point, class_name)
# Non system packs
- file_path = 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py'
- class_name = 'JIRASensor'
- sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'jira'}
+ file_path = "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py"
+ class_name = "JIRASensor"
+ sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "jira"}
sensor_api = SensorTypeAPI(**sensor)
entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api)
- self.assertEqual(entry_point, 'sensors.jira_sensor.JIRASensor')
+ self.assertEqual(entry_point, "sensors.jira_sensor.JIRASensor")
- file_path = 'file:///data/st2contrib/packs/docker/sensors/docker_container_sensor.py'
- class_name = 'DockerSensor'
- sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'docker'}
+ file_path = (
+ "file:///data/st2contrib/packs/docker/sensors/docker_container_sensor.py"
+ )
+ class_name = "DockerSensor"
+ sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "docker"}
sensor_api = SensorTypeAPI(**sensor)
entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api)
- self.assertEqual(entry_point, 'sensors.docker_container_sensor.DockerSensor')
+ self.assertEqual(entry_point, "sensors.docker_container_sensor.DockerSensor")
diff --git a/st2common/tests/unit/test_sensor_watcher.py b/st2common/tests/unit/test_sensor_watcher.py
index 65f61965df..2379f81562 100644
--- a/st2common/tests/unit/test_sensor_watcher.py
+++ b/st2common/tests/unit/test_sensor_watcher.py
@@ -22,39 +22,44 @@
from st2common.models.db.sensor import SensorTypeDB
from st2common.transport.publishers import PoolPublisher
-MOCK_SENSOR_DB = SensorTypeDB(name='foo', pack='test')
+MOCK_SENSOR_DB = SensorTypeDB(name="foo", pack="test")
class SensorWatcherTests(unittest2.TestCase):
-
- @mock.patch.object(Message, 'ack', mock.MagicMock())
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(Message, "ack", mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def test_assert_handlers_called(self):
handler_vars = {
- 'create_handler_called': False,
- 'update_handler_called': False,
- 'delete_handler_called': False
+ "create_handler_called": False,
+ "update_handler_called": False,
+ "delete_handler_called": False,
}
def create_handler(sensor_db):
- handler_vars['create_handler_called'] = True
+ handler_vars["create_handler_called"] = True
def update_handler(sensor_db):
- handler_vars['update_handler_called'] = True
+ handler_vars["update_handler_called"] = True
def delete_handler(sensor_db):
- handler_vars['delete_handler_called'] = True
+ handler_vars["delete_handler_called"] = True
sensor_watcher = SensorWatcher(create_handler, update_handler, delete_handler)
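+        # process_task should dispatch on the message's routing_key ("create",
+        # "update" or "delete"), which the three assertions below verify.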
- message = Message(None, delivery_info={'routing_key': 'create'})
+ message = Message(None, delivery_info={"routing_key": "create"})
sensor_watcher.process_task(MOCK_SENSOR_DB, message)
- self.assertTrue(handler_vars['create_handler_called'], 'create handler should be called.')
+ self.assertTrue(
+ handler_vars["create_handler_called"], "create handler should be called."
+ )
- message = Message(None, delivery_info={'routing_key': 'update'})
+ message = Message(None, delivery_info={"routing_key": "update"})
sensor_watcher.process_task(MOCK_SENSOR_DB, message)
- self.assertTrue(handler_vars['update_handler_called'], 'update handler should be called.')
+ self.assertTrue(
+ handler_vars["update_handler_called"], "update handler should be called."
+ )
- message = Message(None, delivery_info={'routing_key': 'delete'})
+ message = Message(None, delivery_info={"routing_key": "delete"})
sensor_watcher.process_task(MOCK_SENSOR_DB, message)
- self.assertTrue(handler_vars['delete_handler_called'], 'delete handler should be called.')
+ self.assertTrue(
+ handler_vars["delete_handler_called"], "delete handler should be called."
+ )
diff --git a/st2common/tests/unit/test_service_setup.py b/st2common/tests/unit/test_service_setup.py
index 4000f6ce81..b1358f295d 100644
--- a/st2common/tests/unit/test_service_setup.py
+++ b/st2common/tests/unit/test_service_setup.py
@@ -31,9 +31,7 @@
from st2tests.base import CleanFilesTestCase
from st2tests import config
-__all__ = [
- 'ServiceSetupTestCase'
-]
+__all__ = ["ServiceSetupTestCase"]
MOCK_LOGGING_CONFIG_INVALID_LOG_LEVEL = """
[loggers]
@@ -61,11 +59,11 @@
datefmt=
""".strip()
-MOCK_DEFAULT_CONFIG_FILE_PATH = '/etc/st2/st2.conf-test-patched'
+MOCK_DEFAULT_CONFIG_FILE_PATH = "/etc/st2/st2.conf-test-patched"
def mock_get_logging_config_path():
- return ''
+ return ""
class ServiceSetupTestCase(CleanFilesTestCase):
@@ -78,19 +76,24 @@ def test_no_logging_config_found(self):
else:
expected_msg = "No section: .*"
- self.assertRaisesRegexp(Exception, expected_msg,
- service_setup.setup, service='api',
- config=config,
- setup_db=False, register_mq_exchanges=False,
- register_signal_handlers=False,
- register_internal_trigger_types=False,
- run_migrations=False)
+ self.assertRaisesRegexp(
+ Exception,
+ expected_msg,
+ service_setup.setup,
+ service="api",
+ config=config,
+ setup_db=False,
+ register_mq_exchanges=False,
+ register_signal_handlers=False,
+ register_internal_trigger_types=False,
+ run_migrations=False,
+ )
def test_invalid_log_level_friendly_error_message(self):
_, mock_logging_config_path = tempfile.mkstemp()
self.to_delete_files.append(mock_logging_config_path)
- with open(mock_logging_config_path, 'w') as fp:
+ with open(mock_logging_config_path, "w") as fp:
fp.write(MOCK_LOGGING_CONFIG_INVALID_LOG_LEVEL)
def mock_get_logging_config_path():
@@ -99,21 +102,28 @@ def mock_get_logging_config_path():
config.get_logging_config_path = mock_get_logging_config_path
if six.PY3:
- expected_msg = 'ValueError: Unknown level: \'invalid_log_level\''
+ expected_msg = "ValueError: Unknown level: 'invalid_log_level'"
exc_type = ValueError
else:
- expected_msg = 'Invalid log level selected. Log level names need to be all uppercase'
+ expected_msg = (
+ "Invalid log level selected. Log level names need to be all uppercase"
+ )
exc_type = KeyError
- self.assertRaisesRegexp(exc_type, expected_msg,
- service_setup.setup, service='api',
- config=config,
- setup_db=False, register_mq_exchanges=False,
- register_signal_handlers=False,
- register_internal_trigger_types=False,
- run_migrations=False)
-
- @mock.patch('kombu.Queue.declare')
+ self.assertRaisesRegexp(
+ exc_type,
+ expected_msg,
+ service_setup.setup,
+ service="api",
+ config=config,
+ setup_db=False,
+ register_mq_exchanges=False,
+ register_signal_handlers=False,
+ register_internal_trigger_types=False,
+ run_migrations=False,
+ )
+
+ @mock.patch("kombu.Queue.declare")
def test_register_exchanges_predeclare_queues(self, mock_declare):
# Verify that queues are correctly pre-declared
self.assertEqual(mock_declare.call_count, 0)
@@ -121,34 +131,50 @@ def test_register_exchanges_predeclare_queues(self, mock_declare):
register_exchanges()
self.assertEqual(mock_declare.call_count, len(QUEUES))
- @mock.patch('st2common.constants.system.DEFAULT_CONFIG_FILE_PATH',
- MOCK_DEFAULT_CONFIG_FILE_PATH)
- @mock.patch('st2common.config.DEFAULT_CONFIG_FILE_PATH', MOCK_DEFAULT_CONFIG_FILE_PATH)
+ @mock.patch(
+ "st2common.constants.system.DEFAULT_CONFIG_FILE_PATH",
+ MOCK_DEFAULT_CONFIG_FILE_PATH,
+ )
+ @mock.patch(
+ "st2common.config.DEFAULT_CONFIG_FILE_PATH", MOCK_DEFAULT_CONFIG_FILE_PATH
+ )
def test_service_setup_default_st2_conf_config_is_used(self):
st2common_config.get_logging_config_path = mock_get_logging_config_path
cfg.CONF.reset()
# 1. DEFAULT_CONFIG_FILE_PATH config path should be used by default (/etc/st2/st2.conf)
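+        # The patched default path doesn't exist on disk, so setup() raising
+        # ConfigFilesNotFoundError proves the default location was consulted.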
- expected_msg = 'Failed to find some config files: %s' % (MOCK_DEFAULT_CONFIG_FILE_PATH)
- self.assertRaisesRegexp(ConfigFilesNotFoundError, expected_msg, service_setup.setup,
- service='api',
- config=st2common_config,
- config_args=['--debug'],
- setup_db=False, register_mq_exchanges=False,
- register_signal_handlers=False,
- register_internal_trigger_types=False,
- run_migrations=False)
+ expected_msg = "Failed to find some config files: %s" % (
+ MOCK_DEFAULT_CONFIG_FILE_PATH
+ )
+ self.assertRaisesRegexp(
+ ConfigFilesNotFoundError,
+ expected_msg,
+ service_setup.setup,
+ service="api",
+ config=st2common_config,
+ config_args=["--debug"],
+ setup_db=False,
+ register_mq_exchanges=False,
+ register_signal_handlers=False,
+ register_internal_trigger_types=False,
+ run_migrations=False,
+ )
cfg.CONF.reset()
# 2. --config-file should still override default config file path option
- config_file_path = '/etc/st2/config.override.test'
- expected_msg = 'Failed to find some config files: %s' % (config_file_path)
- self.assertRaisesRegexp(ConfigFilesNotFoundError, expected_msg, service_setup.setup,
- service='api',
- config=st2common_config,
- config_args=['--config-file', config_file_path],
- setup_db=False, register_mq_exchanges=False,
- register_signal_handlers=False,
- register_internal_trigger_types=False,
- run_migrations=False)
+ config_file_path = "/etc/st2/config.override.test"
+ expected_msg = "Failed to find some config files: %s" % (config_file_path)
+ self.assertRaisesRegexp(
+ ConfigFilesNotFoundError,
+ expected_msg,
+ service_setup.setup,
+ service="api",
+ config=st2common_config,
+ config_args=["--config-file", config_file_path],
+ setup_db=False,
+ register_mq_exchanges=False,
+ register_signal_handlers=False,
+ register_internal_trigger_types=False,
+ run_migrations=False,
+ )
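
The two setup tests above lean on a single unittest idiom: hand assertRaisesRegexp the callable plus its keyword arguments instead of calling it yourself, so the assertion both invokes the function and matches the regex against the raised message. (assertRaisesRegexp is the Python 2-era spelling; it survives in Python 3 as assertRaisesRegex.) A minimal standalone sketch of the idiom, with a hypothetical load_config standing in for service_setup.setup:

    import re
    import unittest


    class ConfigFilesNotFoundError(Exception):
        pass


    def load_config(path):
        # Hypothetical stand-in for service_setup.setup(); always fails here.
        raise ConfigFilesNotFoundError("Failed to find some config files: %s" % path)


    class LoadConfigTestCase(unittest.TestCase):
        def test_missing_config_file_message(self):
            # The assertion calls load_config("/etc/st2/st2.conf") itself and
            # matches the pattern against the raised exception's message.
            self.assertRaisesRegex(
                ConfigFilesNotFoundError,
                re.escape("Failed to find some config files: /etc/st2/st2.conf"),
                load_config,
                "/etc/st2/st2.conf",
            )


    if __name__ == "__main__":
        unittest.main()
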
diff --git a/st2common/tests/unit/test_shell_action_system_model.py b/st2common/tests/unit/test_shell_action_system_model.py
index 6fdc7d1716..76609ab953 100644
--- a/st2common/tests/unit/test_shell_action_system_model.py
+++ b/st2common/tests/unit/test_shell_action_system_model.py
@@ -32,90 +32,87 @@
from local_runner.local_shell_script_runner import LocalShellScriptRunner
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
-FIXTURES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../fixtures'))
+FIXTURES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../fixtures"))
LOGGED_USER_USERNAME = pwd.getpwuid(os.getuid())[0]
-__all__ = [
- 'ShellCommandActionTestCase',
- 'ShellScriptActionTestCase'
-]
+__all__ = ["ShellCommandActionTestCase", "ShellScriptActionTestCase"]
class ShellCommandActionTestCase(unittest2.TestCase):
def setUp(self):
self._base_kwargs = {
- 'name': 'test action',
- 'action_exec_id': '1',
- 'command': 'ls -la',
- 'env_vars': {},
- 'timeout': None
+ "name": "test action",
+ "action_exec_id": "1",
+ "command": "ls -la",
+ "env_vars": {},
+ "timeout": None,
}
def test_user_argument(self):
# User is the same as logged user, no sudo should be used
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = LOGGED_USER_USERNAME
+ kwargs["sudo"] = False
+ kwargs["user"] = LOGGED_USER_USERNAME
action = ShellCommandAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, 'ls -la')
+ self.assertEqual(command, "ls -la")
# User is different, sudo should be used
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = 'mauser'
+ kwargs["sudo"] = False
+ kwargs["user"] = "mauser"
action = ShellCommandAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, 'sudo -E -H -u mauser -- bash -c \'ls -la\'')
+ self.assertEqual(command, "sudo -E -H -u mauser -- bash -c 'ls -la'")
# sudo with password
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['sudo_password'] = 'sudopass'
- kwargs['user'] = 'mauser'
+ kwargs["sudo"] = False
+ kwargs["sudo_password"] = "sudopass"
+ kwargs["user"] = "mauser"
action = ShellCommandAction(**kwargs)
command = action.get_full_command_string()
- expected_command = 'sudo -S -E -H -u mauser -- bash -c \'ls -la\''
+ expected_command = "sudo -S -E -H -u mauser -- bash -c 'ls -la'"
self.assertEqual(command, expected_command)
# sudo is used, it doesn't matter what user is specified since the
# command should run as root
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = True
- kwargs['user'] = 'mauser'
+ kwargs["sudo"] = True
+ kwargs["user"] = "mauser"
action = ShellCommandAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, 'sudo -E -- bash -c \'ls -la\'')
+ self.assertEqual(command, "sudo -E -- bash -c 'ls -la'")
# sudo with passwd
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = True
- kwargs['user'] = 'mauser'
- kwargs['sudo_password'] = 'sudopass'
+ kwargs["sudo"] = True
+ kwargs["user"] = "mauser"
+ kwargs["sudo_password"] = "sudopass"
action = ShellCommandAction(**kwargs)
command = action.get_full_command_string()
- expected_command = 'sudo -S -E -- bash -c \'ls -la\''
+ expected_command = "sudo -S -E -- bash -c 'ls -la'"
self.assertEqual(command, expected_command)
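
The expected strings in ShellCommandActionTestCase all follow one composition rule: wrap the raw command in bash -c with shell quoting, then prepend sudo flags depending on whether a different user (-H -u) or a sudo password (-S, read from stdin) is involved. A rough sketch of that rule, not the actual ShellCommandAction implementation:

    from shlex import quote


    def build_command(cmd, user, logged_in_user, sudo=False, sudo_password=None):
        # -S makes sudo read the password from stdin; -E preserves the environment.
        s_flag = "-S " if sudo_password else ""
        if sudo:
            return "sudo %s-E -- bash -c %s" % (s_flag, quote(cmd))
        if user and user != logged_in_user:
            # -H sets HOME to the target user's home directory.
            return "sudo %s-E -H -u %s -- bash -c %s" % (s_flag, user, quote(cmd))
        return cmd


    print(build_command("ls -la", "mauser", "stanley"))
    # sudo -E -H -u mauser -- bash -c 'ls -la'
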
class ShellScriptActionTestCase(unittest2.TestCase):
def setUp(self):
self._base_kwargs = {
- 'name': 'test action',
- 'action_exec_id': '1',
- 'script_local_path_abs': '/tmp/foo.sh',
- 'named_args': {},
- 'positional_args': [],
- 'env_vars': {},
- 'timeout': None
+ "name": "test action",
+ "action_exec_id": "1",
+ "script_local_path_abs": "/tmp/foo.sh",
+ "named_args": {},
+ "positional_args": [],
+ "env_vars": {},
+ "timeout": None,
}
def _get_fixture(self, name):
- path = os.path.join(FIXTURES_DIR, 'local_runner', name)
+ path = os.path.join(FIXTURES_DIR, "local_runner", name)
- with open(path, 'r') as fp:
+ with open(path, "r") as fp:
content = fp.read().strip()
return content
@@ -123,371 +120,374 @@ def _get_fixture(self, name):
def test_user_argument(self):
# User is the same as logged user, no sudo should be used
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = LOGGED_USER_USERNAME
+ kwargs["sudo"] = False
+ kwargs["user"] = LOGGED_USER_USERNAME
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, '/tmp/foo.sh')
+ self.assertEqual(command, "/tmp/foo.sh")
# User is different, sudo should be used
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = 'mauser'
+ kwargs["sudo"] = False
+ kwargs["user"] = "mauser"
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, 'sudo -E -H -u mauser -- bash -c /tmp/foo.sh')
+ self.assertEqual(command, "sudo -E -H -u mauser -- bash -c /tmp/foo.sh")
# sudo with password
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = 'mauser'
- kwargs['sudo_password'] = 'sudopass'
+ kwargs["sudo"] = False
+ kwargs["user"] = "mauser"
+ kwargs["sudo_password"] = "sudopass"
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected_command = 'sudo -S -E -H -u mauser -- bash -c /tmp/foo.sh'
+ expected_command = "sudo -S -E -H -u mauser -- bash -c /tmp/foo.sh"
self.assertEqual(command, expected_command)
# complex sudo password which needs escaping
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = 'mauser'
- kwargs['sudo_password'] = '$udo p\'as"sss'
+ kwargs["sudo"] = False
+ kwargs["user"] = "mauser"
+ kwargs["sudo_password"] = "$udo p'as\"sss"
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected_command = ('sudo -S -E -H '
- '-u mauser -- bash -c /tmp/foo.sh')
+ expected_command = "sudo -S -E -H " "-u mauser -- bash -c /tmp/foo.sh"
self.assertEqual(command, expected_command)
command = action.get_sanitized_full_command_string()
- expected_command = ('echo -e \'%s\n\' | sudo -S -E -H '
- '-u mauser -- bash -c /tmp/foo.sh' % (MASKED_ATTRIBUTE_VALUE))
+ expected_command = (
+ "echo -e '%s\n' | sudo -S -E -H "
+ "-u mauser -- bash -c /tmp/foo.sh" % (MASKED_ATTRIBUTE_VALUE)
+ )
self.assertEqual(command, expected_command)
# sudo is used, it doesn't matter what user is specified since the
# command should run as root
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = True
- kwargs['user'] = 'mauser'
- kwargs['sudo_password'] = 'sudopass'
+ kwargs["sudo"] = True
+ kwargs["user"] = "mauser"
+ kwargs["sudo_password"] = "sudopass"
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected_command = 'sudo -S -E -- bash -c /tmp/foo.sh'
+ expected_command = "sudo -S -E -- bash -c /tmp/foo.sh"
self.assertEqual(command, expected_command)
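
The sanitized assertion above shows the companion behavior: the executable command never embeds the sudo password (it is fed to sudo -S on stdin), while the loggable variant shows the stdin pipe with a fixed mask in place of the secret. A simplified sketch of that split; the mask value here is an assumption standing in for MASKED_ATTRIBUTE_VALUE:

    MASKED = "********"  # assumed mask value for this sketch


    def full_command(script):
        # What actually runs; the secret itself never appears in the string.
        return "sudo -S -E -H -u mauser -- bash -c %s" % script


    def sanitized_command(script):
        # What is safe to log: the stdin pipe is visible, the password is not.
        return "echo -e '%s\n' | %s" % (MASKED, full_command(script))


    print(sanitized_command("/tmp/foo.sh"))
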
def test_command_construction_with_parameters(self):
# same user, named args, no positional args
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = LOGGED_USER_USERNAME
- kwargs['named_args'] = OrderedDict([
- ('key1', 'value1'),
- ('key2', 'value2')
- ])
- kwargs['positional_args'] = []
+ kwargs["sudo"] = False
+ kwargs["user"] = LOGGED_USER_USERNAME
+ kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")])
+ kwargs["positional_args"] = []
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, '/tmp/foo.sh key1=value1 key2=value2')
+ self.assertEqual(command, "/tmp/foo.sh key1=value1 key2=value2")
# same user, named args, no positional args, sudo password
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = True
- kwargs['sudo_password'] = 'sudopass'
- kwargs['user'] = LOGGED_USER_USERNAME
- kwargs['named_args'] = OrderedDict([
- ('key1', 'value1'),
- ('key2', 'value2')
- ])
- kwargs['positional_args'] = []
+ kwargs["sudo"] = True
+ kwargs["sudo_password"] = "sudopass"
+ kwargs["user"] = LOGGED_USER_USERNAME
+ kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")])
+ kwargs["positional_args"] = []
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected = ('sudo -S -E -- bash -c '
- '\'/tmp/foo.sh key1=value1 key2=value2\'')
+ expected = "sudo -S -E -- bash -c " "'/tmp/foo.sh key1=value1 key2=value2'"
self.assertEqual(command, expected)
# different user, named args, no positional args
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = 'mauser'
- kwargs['named_args'] = OrderedDict([
- ('key1', 'value1'),
- ('key2', 'value2')
- ])
- kwargs['positional_args'] = []
+ kwargs["sudo"] = False
+ kwargs["user"] = "mauser"
+ kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")])
+ kwargs["positional_args"] = []
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected = 'sudo -E -H -u mauser -- bash -c \'/tmp/foo.sh key1=value1 key2=value2\''
+ expected = (
+ "sudo -E -H -u mauser -- bash -c '/tmp/foo.sh key1=value1 key2=value2'"
+ )
self.assertEqual(command, expected)
# different user, named args, no positional args, sudo password
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['sudo_password'] = 'sudopass'
- kwargs['user'] = 'mauser'
- kwargs['named_args'] = OrderedDict([
- ('key1', 'value1'),
- ('key2', 'value2')
- ])
- kwargs['positional_args'] = []
+ kwargs["sudo"] = False
+ kwargs["sudo_password"] = "sudopass"
+ kwargs["user"] = "mauser"
+ kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")])
+ kwargs["positional_args"] = []
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected = ('sudo -S -E -H -u mauser -- bash -c '
- '\'/tmp/foo.sh key1=value1 key2=value2\'')
+ expected = (
+ "sudo -S -E -H -u mauser -- bash -c "
+ "'/tmp/foo.sh key1=value1 key2=value2'"
+ )
self.assertEqual(command, expected)
# same user, positional args, no named args
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = LOGGED_USER_USERNAME
- kwargs['named_args'] = {}
- kwargs['positional_args'] = ['ein', 'zwei', 'drei', 'mamma mia', 'foo\nbar']
+ kwargs["sudo"] = False
+ kwargs["user"] = LOGGED_USER_USERNAME
+ kwargs["named_args"] = {}
+ kwargs["positional_args"] = ["ein", "zwei", "drei", "mamma mia", "foo\nbar"]
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, '/tmp/foo.sh ein zwei drei \'mamma mia\' \'foo\nbar\'')
+ self.assertEqual(command, "/tmp/foo.sh ein zwei drei 'mamma mia' 'foo\nbar'")
# different user, named args, positional args
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = 'mauser'
- kwargs['named_args'] = {}
- kwargs['positional_args'] = ['ein', 'zwei', 'drei', 'mamma mia']
+ kwargs["sudo"] = False
+ kwargs["user"] = "mauser"
+ kwargs["named_args"] = {}
+ kwargs["positional_args"] = ["ein", "zwei", "drei", "mamma mia"]
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- ex = ('sudo -E -H -u mauser -- '
- 'bash -c \'/tmp/foo.sh ein zwei drei \'"\'"\'mamma mia\'"\'"\'\'')
+ ex = (
+ "sudo -E -H -u mauser -- "
+ "bash -c '/tmp/foo.sh ein zwei drei '\"'\"'mamma mia'\"'\"''"
+ )
self.assertEqual(command, ex)
# same user, positional and named args
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = LOGGED_USER_USERNAME
- kwargs['named_args'] = OrderedDict([
- ('key1', 'value1'),
- ('key2', 'value2'),
- ('key3', 'value 3')
- ])
-
- kwargs['positional_args'] = ['ein', 'zwei', 'drei']
+ kwargs["sudo"] = False
+ kwargs["user"] = LOGGED_USER_USERNAME
+ kwargs["named_args"] = OrderedDict(
+ [("key1", "value1"), ("key2", "value2"), ("key3", "value 3")]
+ )
+
+ kwargs["positional_args"] = ["ein", "zwei", "drei"]
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- exp = '/tmp/foo.sh key1=value1 key2=value2 key3=\'value 3\' ein zwei drei'
+ exp = "/tmp/foo.sh key1=value1 key2=value2 key3='value 3' ein zwei drei"
self.assertEqual(command, exp)
# different user, positional and named args
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = 'mauser'
- kwargs['named_args'] = OrderedDict([
- ('key1', 'value1'),
- ('key2', 'value2'),
- ('key3', 'value 3')
- ])
- kwargs['positional_args'] = ['ein', 'zwei', 'drei']
+ kwargs["sudo"] = False
+ kwargs["user"] = "mauser"
+ kwargs["named_args"] = OrderedDict(
+ [("key1", "value1"), ("key2", "value2"), ("key3", "value 3")]
+ )
+ kwargs["positional_args"] = ["ein", "zwei", "drei"]
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected = ('sudo -E -H -u mauser -- bash -c \'/tmp/foo.sh key1=value1 key2=value2 '
- 'key3=\'"\'"\'value 3\'"\'"\' ein zwei drei\'')
+ expected = (
+ "sudo -E -H -u mauser -- bash -c '/tmp/foo.sh key1=value1 key2=value2 "
+ "key3='\"'\"'value 3'\"'\"' ein zwei drei'"
+ )
self.assertEqual(command, expected)
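
Every expectation in this method reduces to one rendering rule: named arguments become key=value tokens, and any value or positional argument containing whitespace or shell metacharacters is single-quoted (which is why the sudo-wrapped variants grow the quote-close/quote-open escape sequences when the already-quoted string is wrapped a second time inside bash -c). A compact sketch of the inner rendering using shlex.quote, which only quotes when needed:

    from collections import OrderedDict
    from shlex import quote


    def render_script_arguments(script, named_args, positional_args):
        parts = [script]
        # key=value for named args; quote() leaves plain tokens untouched.
        parts.extend("%s=%s" % (k, quote(str(v))) for k, v in named_args.items())
        parts.extend(quote(str(a)) for a in positional_args)
        return " ".join(parts)


    named = OrderedDict([("key1", "value1"), ("key3", "value 3")])
    print(render_script_arguments("/tmp/foo.sh", named, ["ein", "mamma mia"]))
    # /tmp/foo.sh key1=value1 key3='value 3' ein 'mamma mia'
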
def test_named_parameter_escaping(self):
# no sudo
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = LOGGED_USER_USERNAME
- kwargs['named_args'] = OrderedDict([
- ('key1', 'value foo bar'),
- ('key2', 'value "bar" foo'),
- ('key3', 'date ; whoami'),
- ('key4', '"date ; whoami"'),
- ])
+ kwargs["sudo"] = False
+ kwargs["user"] = LOGGED_USER_USERNAME
+ kwargs["named_args"] = OrderedDict(
+ [
+ ("key1", "value foo bar"),
+ ("key2", 'value "bar" foo'),
+ ("key3", "date ; whoami"),
+ ("key4", '"date ; whoami"'),
+ ]
+ )
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected = self._get_fixture('escaping_test_command_1.txt')
+ expected = self._get_fixture("escaping_test_command_1.txt")
self.assertEqual(command, expected)
# sudo
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = True
- kwargs['user'] = LOGGED_USER_USERNAME
- kwargs['named_args'] = OrderedDict([
- ('key1', 'value foo bar'),
- ('key2', 'value "bar" foo'),
- ('key3', 'date ; whoami'),
- ('key4', '"date ; whoami"'),
- ])
+ kwargs["sudo"] = True
+ kwargs["user"] = LOGGED_USER_USERNAME
+ kwargs["named_args"] = OrderedDict(
+ [
+ ("key1", "value foo bar"),
+ ("key2", 'value "bar" foo'),
+ ("key3", "date ; whoami"),
+ ("key4", '"date ; whoami"'),
+ ]
+ )
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- expected = self._get_fixture('escaping_test_command_2.txt')
+ expected = self._get_fixture("escaping_test_command_2.txt")
self.assertEqual(command, expected)
def test_various_ascii_parameters(self):
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = LOGGED_USER_USERNAME
- kwargs['named_args'] = {'foo1': 'bar1', 'foo2': 'bar2'}
- kwargs['positional_args'] = []
+ kwargs["sudo"] = False
+ kwargs["user"] = LOGGED_USER_USERNAME
+ kwargs["named_args"] = {"foo1": "bar1", "foo2": "bar2"}
+ kwargs["positional_args"] = []
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, u"/tmp/foo.sh foo1=bar1 foo2=bar2")
+ self.assertEqual(command, "/tmp/foo.sh foo1=bar1 foo2=bar2")
def test_unicode_parameter_specifing(self):
kwargs = copy.deepcopy(self._base_kwargs)
- kwargs['sudo'] = False
- kwargs['user'] = LOGGED_USER_USERNAME
- kwargs['named_args'] = {u'foo': u'bar'}
- kwargs['positional_args'] = []
+ kwargs["sudo"] = False
+ kwargs["user"] = LOGGED_USER_USERNAME
+ kwargs["named_args"] = {"foo": "bar"}
+ kwargs["positional_args"] = []
action = ShellScriptAction(**kwargs)
command = action.get_full_command_string()
- self.assertEqual(command, u"/tmp/foo.sh 'foo'='bar'")
+ self.assertEqual(command, "/tmp/foo.sh 'foo'='bar'")
def test_command_construction_correct_default_parameter_values_are_used(self):
runner_parameters = {}
action_db_parameters = {
- 'project': {
- 'type': 'string',
- 'default': 'st2',
- 'position': 0,
- },
- 'version': {
- 'type': 'string',
- 'position': 1,
- 'required': True
+ "project": {
+ "type": "string",
+ "default": "st2",
+ "position": 0,
},
- 'fork': {
- 'type': 'string',
- 'position': 2,
- 'default': 'StackStorm',
+ "version": {"type": "string", "position": 1, "required": True},
+ "fork": {
+ "type": "string",
+ "position": 2,
+ "default": "StackStorm",
},
- 'branch': {
- 'type': 'string',
- 'position': 3,
- 'default': 'master',
+ "branch": {
+ "type": "string",
+ "position": 3,
+ "default": "master",
},
- 'update_changelog': {
- 'type': 'boolean',
- 'position': 4,
- 'default': False
+ "update_changelog": {"type": "boolean", "position": 4, "default": False},
+ "local_repo": {
+ "type": "string",
+ "position": 5,
},
- 'local_repo': {
- 'type': 'string',
- 'position': 5,
- }
}
context = {}
- action_db = ActionDB(pack='dummy', name='action')
+ action_db = ActionDB(pack="dummy", name="action")
- runner = LocalShellScriptRunner('id')
+ runner = LocalShellScriptRunner("id")
runner.runner_parameters = {}
runner.action = action_db
# 1. All default values used
live_action_db_parameters = {
- 'project': 'st2flow',
- 'version': '3.0.0',
- 'fork': 'StackStorm',
- 'local_repo': '/tmp/repo'
+ "project": "st2flow",
+ "version": "3.0.0",
+ "fork": "StackStorm",
+ "local_repo": "/tmp/repo",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
-
- self.assertDictEqual(action_params, {
- 'project': 'st2flow',
- 'version': '3.0.0',
- 'fork': 'StackStorm',
- 'branch': 'master', # default value used
- 'update_changelog': False, # default value used
- 'local_repo': '/tmp/repo'
- })
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
+
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2flow",
+ "version": "3.0.0",
+ "fork": "StackStorm",
+ "branch": "master", # default value used
+ "update_changelog": False, # default value used
+ "local_repo": "/tmp/repo",
+ },
+ )
action_db.parameters = action_db_parameters
positional_args, named_args = runner._get_script_args(action_params)
named_args = runner._transform_named_args(named_args)
- shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy',
- script_local_path_abs='/tmp/local.sh',
- named_args=named_args,
- positional_args=positional_args)
+ shell_script_action = ShellScriptAction(
+ name="dummy",
+ action_exec_id="dummy",
+ script_local_path_abs="/tmp/local.sh",
+ named_args=named_args,
+ positional_args=positional_args,
+ )
command_string = shell_script_action.get_full_command_string()
- expected = '/tmp/local.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo'
+ expected = "/tmp/local.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo"
self.assertEqual(command_string, expected)
# 2. Some default values used
live_action_db_parameters = {
- 'project': 'st2web',
- 'version': '3.1.0',
- 'fork': 'StackStorm1',
- 'update_changelog': True,
- 'local_repo': '/tmp/repob'
+ "project": "st2web",
+ "version": "3.1.0",
+ "fork": "StackStorm1",
+ "update_changelog": True,
+ "local_repo": "/tmp/repob",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
-
- self.assertDictEqual(action_params, {
- 'project': 'st2web',
- 'version': '3.1.0',
- 'fork': 'StackStorm1',
- 'branch': 'master', # default value used
- 'update_changelog': True, # default value used
- 'local_repo': '/tmp/repob'
- })
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
+
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2web",
+ "version": "3.1.0",
+ "fork": "StackStorm1",
+ "branch": "master", # default value used
+ "update_changelog": True, # default value used
+ "local_repo": "/tmp/repob",
+ },
+ )
action_db.parameters = action_db_parameters
positional_args, named_args = runner._get_script_args(action_params)
named_args = runner._transform_named_args(named_args)
- shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy',
- script_local_path_abs='/tmp/local.sh',
- named_args=named_args,
- positional_args=positional_args)
+ shell_script_action = ShellScriptAction(
+ name="dummy",
+ action_exec_id="dummy",
+ script_local_path_abs="/tmp/local.sh",
+ named_args=named_args,
+ positional_args=positional_args,
+ )
command_string = shell_script_action.get_full_command_string()
- expected = '/tmp/local.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob'
+ expected = "/tmp/local.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob"
self.assertEqual(command_string, expected)
# 3. None is specified for a boolean parameter, should use a default
live_action_db_parameters = {
- 'project': 'st2rbac',
- 'version': '3.2.0',
- 'fork': 'StackStorm2',
- 'update_changelog': None,
- 'local_repo': '/tmp/repoc'
+ "project": "st2rbac",
+ "version": "3.2.0",
+ "fork": "StackStorm2",
+ "update_changelog": None,
+ "local_repo": "/tmp/repoc",
}
- runner_params, action_params = param_utils.render_final_params(runner_parameters,
- action_db_parameters,
- live_action_db_parameters,
- context)
-
- self.assertDictEqual(action_params, {
- 'project': 'st2rbac',
- 'version': '3.2.0',
- 'fork': 'StackStorm2',
- 'branch': 'master', # default value used
- 'update_changelog': False, # default value used
- 'local_repo': '/tmp/repoc'
- })
+ runner_params, action_params = param_utils.render_final_params(
+ runner_parameters, action_db_parameters, live_action_db_parameters, context
+ )
+
+ self.assertDictEqual(
+ action_params,
+ {
+ "project": "st2rbac",
+ "version": "3.2.0",
+ "fork": "StackStorm2",
+ "branch": "master", # default value used
+ "update_changelog": False, # default value used
+ "local_repo": "/tmp/repoc",
+ },
+ )
action_db.parameters = action_db_parameters
positional_args, named_args = runner._get_script_args(action_params)
named_args = runner._transform_named_args(named_args)
- shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy',
- script_local_path_abs='/tmp/local.sh',
- named_args=named_args,
- positional_args=positional_args)
+ shell_script_action = ShellScriptAction(
+ name="dummy",
+ action_exec_id="dummy",
+ script_local_path_abs="/tmp/local.sh",
+ named_args=named_args,
+ positional_args=positional_args,
+ )
command_string = shell_script_action.get_full_command_string()
- expected = '/tmp/local.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc'
+ expected = "/tmp/local.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc"
self.assertEqual(command_string, expected)
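
All three scenarios pin down the same merging rule: a live-action parameter wins when it is present and not None; otherwise the default from the action's parameter schema is used, which is why update_changelog=None still renders as 0 (False) in the final command. A standalone approximation of that rule; the real render_final_params also resolves Jinja expressions, which this sketch ignores:

    def merge_with_defaults(schema, live_params):
        result = {}
        for name, spec in schema.items():
            value = live_params.get(name)
            if value is not None:
                result[name] = value            # explicit, non-None value wins
            elif "default" in spec:
                result[name] = spec["default"]  # missing or None falls back
        return result


    schema = {
        "branch": {"type": "string", "default": "master"},
        "update_changelog": {"type": "boolean", "default": False},
    }
    print(merge_with_defaults(schema, {"update_changelog": None}))
    # {'branch': 'master', 'update_changelog': False}
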
diff --git a/st2common/tests/unit/test_state_publisher.py b/st2common/tests/unit/test_state_publisher.py
index 1fa87b8487..99dbabda7f 100644
--- a/st2common/tests/unit/test_state_publisher.py
+++ b/st2common/tests/unit/test_state_publisher.py
@@ -27,7 +27,7 @@
from st2tests import DbTestCase
-FAKE_STATE_MGMT_XCHG = kombu.Exchange('st2.fake.state', type='topic')
+FAKE_STATE_MGMT_XCHG = kombu.Exchange("st2.fake.state", type="topic")
class FakeModelPublisher(publishers.StatePublisherMixin):
@@ -57,7 +57,7 @@ def _get_publisher(cls):
def publish_state(cls, model_object):
publisher = cls._get_publisher()
if publisher:
- publisher.publish_state(model_object, getattr(model_object, 'state', None))
+ publisher.publish_state(model_object, getattr(model_object, "state", None))
@classmethod
def _get_by_object(cls, object):
@@ -65,7 +65,6 @@ def _get_by_object(cls, object):
class StatePublisherTest(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(StatePublisherTest, cls).setUpClass()
@@ -75,13 +74,13 @@ def tearDown(self):
FakeModelDB.drop_collection()
super(StatePublisherTest, self).tearDown()
- @mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock())
def test_publish(self):
- instance = FakeModelDB(state='faked')
+ instance = FakeModelDB(state="faked")
self.access.publish_state(instance)
- publishers.PoolPublisher.publish.assert_called_with(instance,
- FAKE_STATE_MGMT_XCHG,
- instance.state)
+ publishers.PoolPublisher.publish.assert_called_with(
+ instance, FAKE_STATE_MGMT_XCHG, instance.state
+ )
def test_publish_unset(self):
instance = FakeModelDB()
@@ -92,5 +91,5 @@ def test_publish_none(self):
self.assertRaises(Exception, self.access.publish_state, instance)
def test_publish_empty_str(self):
- instance = FakeModelDB(state='')
+ instance = FakeModelDB(state="")
self.assertRaises(Exception, self.access.publish_state, instance)
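
The publish test relies on patching the publisher method with a MagicMock at class level and then asserting on the recorded call arguments. The same pattern, reduced to a self-contained example:

    from unittest import mock


    class Publisher:
        def publish(self, obj, exchange, state):
            raise RuntimeError("would talk to the message bus")


    @mock.patch.object(Publisher, "publish", mock.MagicMock())
    def run():
        p = Publisher()
        p.publish("instance", "st2.fake.state", "faked")
        # The MagicMock replaced the method and recorded the call, so the
        # test can assert on exactly what was published and where.
        Publisher.publish.assert_called_with("instance", "st2.fake.state", "faked")


    run()
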
diff --git a/st2common/tests/unit/test_stream_generator.py b/st2common/tests/unit/test_stream_generator.py
index 9c44db4657..a184220b80 100644
--- a/st2common/tests/unit/test_stream_generator.py
+++ b/st2common/tests/unit/test_stream_generator.py
@@ -20,7 +20,6 @@
class MockBody(object):
-
def __init__(self, id):
self.id = id
self.status = "succeeded"
@@ -32,8 +31,7 @@ def __init__(self, id):
EVENTS = [(INCLUDE, MockBody("notend")), (END_EVENT, MockBody(END_ID))]
-class MockQueue():
-
+class MockQueue:
def __init__(self):
self.items = EVENTS
@@ -47,7 +45,6 @@ def put(self, event):
class MockListener(listener.BaseListener):
-
def __init__(self, *args, **kwargs):
super(MockListener, self).__init__(*args, **kwargs)
@@ -56,19 +53,19 @@ def get_consumers(self, consumer, channel):
class TestStream(unittest2.TestCase):
-
- @mock.patch('st2common.stream.listener.BaseListener._get_action_ref_for_body')
- @mock.patch('eventlet.Queue')
- def test_generator(self, mock_queue,
- get_action_ref_for_body):
+ @mock.patch("st2common.stream.listener.BaseListener._get_action_ref_for_body")
+ @mock.patch("eventlet.Queue")
+ def test_generator(self, mock_queue, get_action_ref_for_body):
get_action_ref_for_body.return_value = None
mock_queue.return_value = MockQueue()
mock_consumer = MockListener(connection=None)
mock_consumer._stopped = False
- app_iter = mock_consumer.generator(events=INCLUDE,
+ app_iter = mock_consumer.generator(
+ events=INCLUDE,
end_event=END_EVENT,
end_statuses=["succeeded"],
- end_execution_id=END_ID)
- events = EVENTS.append('')
+ end_execution_id=END_ID,
+ )
+ events = EVENTS + [""]  # list.append() returns None, so concatenate instead
for index, val in enumerate(app_iter):
self.assertEquals(val, events[index])
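
The generator under test streams execution events until a designated end event with a matching execution id and terminal status arrives. A condensed sketch of that consume-until-end loop (not the BaseListener implementation):

    import queue

    END_EVENT = "execution__end"
    END_ID = "end"


    def generator(q, end_event, end_statuses, end_execution_id):
        while True:
            event_name, body = q.get(timeout=1)
            yield (event_name, body)
            # Stop streaming once the terminal execution reaches an end status.
            if (
                event_name == end_event
                and body["id"] == end_execution_id
                and body["status"] in end_statuses
            ):
                return


    q = queue.Queue()
    q.put(("execution__update", {"id": "notend", "status": "succeeded"}))
    q.put((END_EVENT, {"id": END_ID, "status": "succeeded"}))
    for event in generator(q, END_EVENT, ["succeeded"], END_ID):
        print(event)
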
diff --git a/st2common/tests/unit/test_system_info.py b/st2common/tests/unit/test_system_info.py
index e7ddb20bef..c840a7aa8b 100644
--- a/st2common/tests/unit/test_system_info.py
+++ b/st2common/tests/unit/test_system_info.py
@@ -23,8 +23,7 @@
class TestLogger(unittest.TestCase):
-
def test_process_info(self):
process_info = system_info.get_process_info()
- self.assertEqual(process_info['hostname'], socket.gethostname())
- self.assertEqual(process_info['pid'], os.getpid())
+ self.assertEqual(process_info["hostname"], socket.gethostname())
+ self.assertEqual(process_info["pid"], os.getpid())
diff --git a/st2common/tests/unit/test_tags.py b/st2common/tests/unit/test_tags.py
index 3ffc59b50a..6230cedea6 100644
--- a/st2common/tests/unit/test_tags.py
+++ b/st2common/tests/unit/test_tags.py
@@ -28,53 +28,69 @@ class TaggedModel(stormbase.StormFoundationDB, stormbase.TagsMixin):
class TestTags(DbTestCase):
-
def test_simple_count(self):
instance = TaggedModel()
- instance.tags = [stormbase.TagField(name='tag1', value='v1'),
- stormbase.TagField(name='tag2', value='v2')]
+ instance.tags = [
+ stormbase.TagField(name="tag1", value="v1"),
+ stormbase.TagField(name="tag2", value="v2"),
+ ]
saved = instance.save()
retrieved = TaggedModel.objects(id=instance.id).first()
- self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.')
+ self.assertEqual(
+ len(saved.tags), len(retrieved.tags), "Failed to retrieve tags."
+ )
def test_simple_value(self):
instance = TaggedModel()
- instance.tags = [stormbase.TagField(name='tag1', value='v1')]
+ instance.tags = [stormbase.TagField(name="tag1", value="v1")]
saved = instance.save()
retrieved = TaggedModel.objects(id=instance.id).first()
- self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.')
+ self.assertEqual(
+ len(saved.tags), len(retrieved.tags), "Failed to retrieve tags."
+ )
saved_tag = saved.tags[0]
retrieved_tag = retrieved.tags[0]
- self.assertEqual(saved_tag.name, retrieved_tag.name, 'Failed to retrieve tag.')
- self.assertEqual(saved_tag.value, retrieved_tag.value, 'Failed to retrieve tag.')
+ self.assertEqual(saved_tag.name, retrieved_tag.name, "Failed to retrieve tag.")
+ self.assertEqual(
+ saved_tag.value, retrieved_tag.value, "Failed to retrieve tag."
+ )
def test_tag_max_size_restriction(self):
instance = TaggedModel()
- instance.tags = [stormbase.TagField(name=self._gen_random_string(),
- value=self._gen_random_string())]
+ instance.tags = [
+ stormbase.TagField(
+ name=self._gen_random_string(), value=self._gen_random_string()
+ )
+ ]
saved = instance.save()
retrieved = TaggedModel.objects(id=instance.id).first()
- self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.')
+ self.assertEqual(
+ len(saved.tags), len(retrieved.tags), "Failed to retrieve tags."
+ )
def test_name_exceeds_max_size(self):
instance = TaggedModel()
- instance.tags = [stormbase.TagField(name=self._gen_random_string(1025),
- value='v1')]
+ instance.tags = [
+ stormbase.TagField(name=self._gen_random_string(1025), value="v1")
+ ]
try:
instance.save()
- self.assertTrue(False, 'Expected save to fail')
+ self.assertTrue(False, "Expected save to fail")
except ValidationError:
pass
def test_value_exceeds_max_size(self):
instance = TaggedModel()
- instance.tags = [stormbase.TagField(name='n1',
- value=self._gen_random_string(1025))]
+ instance.tags = [
+ stormbase.TagField(name="n1", value=self._gen_random_string(1025))
+ ]
try:
instance.save()
- self.assertTrue(False, 'Expected save to fail')
+ self.assertTrue(False, "Expected save to fail")
except ValidationError:
pass
- def _gen_random_string(self, size=1024, chars=string.ascii_lowercase + string.digits):
- return ''.join([random.choice(chars) for _ in range(size)])
+ def _gen_random_string(
+ self, size=1024, chars=string.ascii_lowercase + string.digits
+ ):
+ return "".join([random.choice(chars) for _ in range(size)])
diff --git a/st2common/tests/unit/test_time_jinja_filters.py b/st2common/tests/unit/test_time_jinja_filters.py
index c61473cdfd..5a343a5c29 100644
--- a/st2common/tests/unit/test_time_jinja_filters.py
+++ b/st2common/tests/unit/test_time_jinja_filters.py
@@ -20,14 +20,16 @@
class TestTimeJinjaFilters(TestCase):
-
def test_to_human_time_from_seconds(self):
- self.assertEqual('0s', time.to_human_time_from_seconds(seconds=0))
- self.assertEqual('0.1\u03BCs', time.to_human_time_from_seconds(seconds=0.1))
- self.assertEqual('56s', time.to_human_time_from_seconds(seconds=56))
- self.assertEqual('56s', time.to_human_time_from_seconds(seconds=56.2))
- self.assertEqual('7m36s', time.to_human_time_from_seconds(seconds=456))
- self.assertEqual('1h16m0s', time.to_human_time_from_seconds(seconds=4560))
- self.assertEqual('1y12d16h36m37s', time.to_human_time_from_seconds(seconds=45678997))
- self.assertRaises(AssertionError, time.to_human_time_from_seconds,
- seconds='stuff')
+ self.assertEqual("0s", time.to_human_time_from_seconds(seconds=0))
+ self.assertEqual("0.1\u03BCs", time.to_human_time_from_seconds(seconds=0.1))
+ self.assertEqual("56s", time.to_human_time_from_seconds(seconds=56))
+ self.assertEqual("56s", time.to_human_time_from_seconds(seconds=56.2))
+ self.assertEqual("7m36s", time.to_human_time_from_seconds(seconds=456))
+ self.assertEqual("1h16m0s", time.to_human_time_from_seconds(seconds=4560))
+ self.assertEqual(
+ "1y12d16h36m37s", time.to_human_time_from_seconds(seconds=45678997)
+ )
+ self.assertRaises(
+ AssertionError, time.to_human_time_from_seconds, seconds="stuff"
+ )
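
The integer expectations encode a plain divmod cascade from seconds up through minutes and hours (the real filter continues through days and years and special-cases sub-second values as microseconds). A sketch of the cascade that reproduces the simpler assertions:

    def to_human_time(seconds):
        # seconds -> minutes -> hours; each divmod peels off one unit.
        if seconds < 60:
            return "%ds" % seconds
        minutes, secs = divmod(int(seconds), 60)
        if minutes < 60:
            return "%dm%ds" % (minutes, secs)
        hours, minutes = divmod(minutes, 60)
        return "%dh%dm%ds" % (hours, minutes, secs)


    assert to_human_time(56) == "56s"
    assert to_human_time(456) == "7m36s"
    assert to_human_time(4560) == "1h16m0s"
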
diff --git a/st2common/tests/unit/test_transport.py b/st2common/tests/unit/test_transport.py
index 9e4d4789b2..75e35ae2c9 100644
--- a/st2common/tests/unit/test_transport.py
+++ b/st2common/tests/unit/test_transport.py
@@ -19,9 +19,7 @@
from st2common.transport.utils import _get_ssl_kwargs
-__all__ = [
- 'TransportUtilsTestCase'
-]
+__all__ = ["TransportUtilsTestCase"]
class TransportUtilsTestCase(unittest2.TestCase):
@@ -32,49 +30,39 @@ def test_get_ssl_kwargs(self):
# 2. ssl kwarg provided
ssl_kwargs = _get_ssl_kwargs(ssl=True)
- self.assertEqual(ssl_kwargs, {
- 'ssl': True
- })
+ self.assertEqual(ssl_kwargs, {"ssl": True})
# 3. ssl_keyfile provided
- ssl_kwargs = _get_ssl_kwargs(ssl_keyfile='/tmp/keyfile')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'keyfile': '/tmp/keyfile'
- })
+ ssl_kwargs = _get_ssl_kwargs(ssl_keyfile="/tmp/keyfile")
+ self.assertEqual(ssl_kwargs, {"ssl": True, "keyfile": "/tmp/keyfile"})
# 4. ssl_certfile provided
- ssl_kwargs = _get_ssl_kwargs(ssl_certfile='/tmp/certfile')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'certfile': '/tmp/certfile'
- })
+ ssl_kwargs = _get_ssl_kwargs(ssl_certfile="/tmp/certfile")
+ self.assertEqual(ssl_kwargs, {"ssl": True, "certfile": "/tmp/certfile"})
# 5. ssl_ca_certs provided
- ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ca_certs': '/tmp/ca_certs'
- })
+ ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs")
+ self.assertEqual(ssl_kwargs, {"ssl": True, "ca_certs": "/tmp/ca_certs"})
# 6. ssl_ca_certs and ssl_cert_reqs combinations
- ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='none')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ca_certs': '/tmp/ca_certs',
- 'cert_reqs': ssl.CERT_NONE
- })
+ ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="none")
+ self.assertEqual(
+ ssl_kwargs,
+ {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_NONE},
+ )
- ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='optional')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ca_certs': '/tmp/ca_certs',
- 'cert_reqs': ssl.CERT_OPTIONAL
- })
+ ssl_kwargs = _get_ssl_kwargs(
+ ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="optional"
+ )
+ self.assertEqual(
+ ssl_kwargs,
+ {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_OPTIONAL},
+ )
- ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='required')
- self.assertEqual(ssl_kwargs, {
- 'ssl': True,
- 'ca_certs': '/tmp/ca_certs',
- 'cert_reqs': ssl.CERT_REQUIRED
- })
+ ssl_kwargs = _get_ssl_kwargs(
+ ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="required"
+ )
+ self.assertEqual(
+ ssl_kwargs,
+ {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_REQUIRED},
+ )
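
Taken together, the cases fix the contract of _get_ssl_kwargs: any ssl_* option implies ssl=True, only the options actually provided appear in the result, and the cert_reqs strings map onto the ssl module's constants. An approximation of that contract (not the st2common.transport.utils source):

    import ssl

    CERT_REQS_MAP = {
        "none": ssl.CERT_NONE,
        "optional": ssl.CERT_OPTIONAL,
        "required": ssl.CERT_REQUIRED,
    }


    def get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None,
                       ssl_ca_certs=None, ssl_cert_reqs=None):
        kwargs = {}
        if ssl:
            kwargs["ssl"] = True
        if ssl_keyfile:
            kwargs.update(ssl=True, keyfile=ssl_keyfile)
        if ssl_certfile:
            kwargs.update(ssl=True, certfile=ssl_certfile)
        if ssl_ca_certs:
            kwargs.update(ssl=True, ca_certs=ssl_ca_certs)
            if ssl_cert_reqs:
                kwargs["cert_reqs"] = CERT_REQS_MAP[ssl_cert_reqs]
        return kwargs


    print(get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="required"))
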
diff --git a/st2common/tests/unit/test_trigger_services.py b/st2common/tests/unit/test_trigger_services.py
index 6f66a5f55b..b843526bc9 100644
--- a/st2common/tests/unit/test_trigger_services.py
+++ b/st2common/tests/unit/test_trigger_services.py
@@ -18,124 +18,147 @@
from st2common.models.api.rule import RuleAPI
from st2common.models.system.common import ResourceReference
from st2common.models.db.trigger import TriggerDB
-from st2common.persistence.trigger import (Trigger, TriggerType)
+from st2common.persistence.trigger import Trigger, TriggerType
import st2common.services.triggers as trigger_service
from st2tests.base import CleanDbTestCase
from st2tests.fixturesloader import FixturesLoader
-MOCK_TRIGGER = TriggerDB(pack='dummy_pack_1', name='trigger-test.name', parameters={},
- type='dummy_pack_1.trigger-type-test.name')
+MOCK_TRIGGER = TriggerDB(
+ pack="dummy_pack_1",
+ name="trigger-test.name",
+ parameters={},
+ type="dummy_pack_1.trigger-type-test.name",
+)
class TriggerServiceTests(CleanDbTestCase):
-
def test_create_trigger_db_from_rule(self):
- test_fixtures = {
- 'rules': ['cron_timer_rule_1.yaml', 'cron_timer_rule_3.yaml']
- }
+ test_fixtures = {"rules": ["cron_timer_rule_1.yaml", "cron_timer_rule_3.yaml"]}
loader = FixturesLoader()
- fixtures = loader.load_fixtures(fixtures_pack='generic', fixtures_dict=test_fixtures)
- rules = fixtures['rules']
+ fixtures = loader.load_fixtures(
+ fixtures_pack="generic", fixtures_dict=test_fixtures
+ )
+ rules = fixtures["rules"]
trigger_db_ret_1 = trigger_service.create_trigger_db_from_rule(
- RuleAPI(**rules['cron_timer_rule_1.yaml']))
+ RuleAPI(**rules["cron_timer_rule_1.yaml"])
+ )
self.assertIsNotNone(trigger_db_ret_1)
trigger_db = Trigger.get_by_id(trigger_db_ret_1.id)
- self.assertDictEqual(trigger_db.parameters,
- rules['cron_timer_rule_1.yaml']['trigger']['parameters'])
+ self.assertDictEqual(
+ trigger_db.parameters,
+ rules["cron_timer_rule_1.yaml"]["trigger"]["parameters"],
+ )
trigger_db_ret_2 = trigger_service.create_trigger_db_from_rule(
- RuleAPI(**rules['cron_timer_rule_3.yaml']))
+ RuleAPI(**rules["cron_timer_rule_3.yaml"])
+ )
self.assertIsNotNone(trigger_db_ret_2)
self.assertTrue(trigger_db_ret_2.id != trigger_db_ret_1.id)
def test_create_trigger_db_from_rule_duplicate(self):
- test_fixtures = {
- 'rules': ['cron_timer_rule_1.yaml', 'cron_timer_rule_2.yaml']
- }
+ test_fixtures = {"rules": ["cron_timer_rule_1.yaml", "cron_timer_rule_2.yaml"]}
loader = FixturesLoader()
- fixtures = loader.load_fixtures(fixtures_pack='generic', fixtures_dict=test_fixtures)
- rules = fixtures['rules']
+ fixtures = loader.load_fixtures(
+ fixtures_pack="generic", fixtures_dict=test_fixtures
+ )
+ rules = fixtures["rules"]
trigger_db_ret_1 = trigger_service.create_trigger_db_from_rule(
- RuleAPI(**rules['cron_timer_rule_1.yaml']))
+ RuleAPI(**rules["cron_timer_rule_1.yaml"])
+ )
self.assertIsNotNone(trigger_db_ret_1)
trigger_db_ret_2 = trigger_service.create_trigger_db_from_rule(
- RuleAPI(**rules['cron_timer_rule_2.yaml']))
+ RuleAPI(**rules["cron_timer_rule_2.yaml"])
+ )
self.assertIsNotNone(trigger_db_ret_2)
- self.assertEqual(trigger_db_ret_1, trigger_db_ret_2, 'Should reuse same trigger.')
+ self.assertEqual(
+ trigger_db_ret_1, trigger_db_ret_2, "Should reuse same trigger."
+ )
trigger_db = Trigger.get_by_id(trigger_db_ret_1.id)
- self.assertDictEqual(trigger_db.parameters,
- rules['cron_timer_rule_1.yaml']['trigger']['parameters'])
+ self.assertDictEqual(
+ trigger_db.parameters,
+ rules["cron_timer_rule_1.yaml"]["trigger"]["parameters"],
+ )
def test_create_or_update_trigger_db_simple_triggers(self):
- test_fixtures = {
- 'triggertypes': ['triggertype1.yaml']
- }
+ test_fixtures = {"triggertypes": ["triggertype1.yaml"]}
loader = FixturesLoader()
- fixtures = loader.save_fixtures_to_db(fixtures_pack='generic', fixtures_dict=test_fixtures)
- triggertypes = fixtures['triggertypes']
+ fixtures = loader.save_fixtures_to_db(
+ fixtures_pack="generic", fixtures_dict=test_fixtures
+ )
+ triggertypes = fixtures["triggertypes"]
trigger_type_ref = ResourceReference.to_string_reference(
- name=triggertypes['triggertype1.yaml']['name'],
- pack=triggertypes['triggertype1.yaml']['pack'])
+ name=triggertypes["triggertype1.yaml"]["name"],
+ pack=triggertypes["triggertype1.yaml"]["pack"],
+ )
trigger = {
- 'name': triggertypes['triggertype1.yaml']['name'],
- 'pack': triggertypes['triggertype1.yaml']['pack'],
- 'type': trigger_type_ref
+ "name": triggertypes["triggertype1.yaml"]["name"],
+ "pack": triggertypes["triggertype1.yaml"]["pack"],
+ "type": trigger_type_ref,
}
trigger_service.create_or_update_trigger_db(trigger)
triggers = Trigger.get_all()
- self.assertTrue(len(triggers) == 1, 'Only one trigger should be created.')
- self.assertTrue(triggers[0]['name'] == triggertypes['triggertype1.yaml']['name'])
+ self.assertTrue(len(triggers) == 1, "Only one trigger should be created.")
+ self.assertTrue(
+ triggers[0]["name"] == triggertypes["triggertype1.yaml"]["name"]
+ )
# Try adding duplicate
trigger_service.create_or_update_trigger_db(trigger)
triggers = Trigger.get_all()
- self.assertTrue(len(triggers) == 1, 'Only one trigger should be present.')
- self.assertTrue(triggers[0]['name'] == triggertypes['triggertype1.yaml']['name'])
+ self.assertTrue(len(triggers) == 1, "Only one trigger should be present.")
+ self.assertTrue(
+ triggers[0]["name"] == triggertypes["triggertype1.yaml"]["name"]
+ )
def test_exception_thrown_when_rule_creation_no_trigger_yes_triggertype(self):
- test_fixtures = {
- 'triggertypes': ['triggertype1.yaml']
- }
+ test_fixtures = {"triggertypes": ["triggertype1.yaml"]}
loader = FixturesLoader()
- fixtures = loader.save_fixtures_to_db(fixtures_pack='generic', fixtures_dict=test_fixtures)
- triggertypes = fixtures['triggertypes']
+ fixtures = loader.save_fixtures_to_db(
+ fixtures_pack="generic", fixtures_dict=test_fixtures
+ )
+ triggertypes = fixtures["triggertypes"]
trigger_type_ref = ResourceReference.to_string_reference(
- name=triggertypes['triggertype1.yaml']['name'],
- pack=triggertypes['triggertype1.yaml']['pack'])
+ name=triggertypes["triggertype1.yaml"]["name"],
+ pack=triggertypes["triggertype1.yaml"]["pack"],
+ )
rule = {
- 'name': 'fancyrule',
- 'trigger': {
- 'type': trigger_type_ref
- },
- 'criteria': {
-
- },
- 'action': {
- 'ref': 'core.local',
- 'parameters': {
- 'cmd': 'date'
- }
- }
+ "name": "fancyrule",
+ "trigger": {"type": trigger_type_ref},
+ "criteria": {},
+ "action": {"ref": "core.local", "parameters": {"cmd": "date"}},
}
rule_api = RuleAPI(**rule)
- self.assertRaises(TriggerDoesNotExistException,
- trigger_service.create_trigger_db_from_rule, rule_api)
+ self.assertRaises(
+ TriggerDoesNotExistException,
+ trigger_service.create_trigger_db_from_rule,
+ rule_api,
+ )
def test_get_trigger_db_given_type_and_params(self):
# Add dummy triggers
- trigger_1 = TriggerDB(pack='testpack', name='testtrigger1', type='testpack.testtrigger1')
+ trigger_1 = TriggerDB(
+ pack="testpack", name="testtrigger1", type="testpack.testtrigger1"
+ )
- trigger_2 = TriggerDB(pack='testpack', name='testtrigger2', type='testpack.testtrigger2')
+ trigger_2 = TriggerDB(
+ pack="testpack", name="testtrigger2", type="testpack.testtrigger2"
+ )
- trigger_3 = TriggerDB(pack='testpack', name='testtrigger3', type='testpack.testtrigger3')
+ trigger_3 = TriggerDB(
+ pack="testpack", name="testtrigger3", type="testpack.testtrigger3"
+ )
- trigger_4 = TriggerDB(pack='testpack', name='testtrigger4', type='testpack.testtrigger4',
- parameters={'ponies': 'unicorn'})
+ trigger_4 = TriggerDB(
+ pack="testpack",
+ name="testtrigger4",
+ type="testpack.testtrigger4",
+ parameters={"ponies": "unicorn"},
+ )
Trigger.add_or_update(trigger_1)
Trigger.add_or_update(trigger_2)
@@ -143,64 +166,73 @@ def test_get_trigger_db_given_type_and_params(self):
Trigger.add_or_update(trigger_4)
# Trigger with no parameters, parameters={} in db
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type,
- parameters={})
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_1.type, parameters={}
+ )
self.assertEqual(trigger_db, trigger_1)
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type,
- parameters=None)
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_1.type, parameters=None
+ )
self.assertEqual(trigger_db, trigger_1)
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type,
- parameters={'fo': 'bar'})
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_1.type, parameters={"fo": "bar"}
+ )
self.assertEqual(trigger_db, None)
# Trigger with no parameters, no parameters attribute in the db
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type,
- parameters={})
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_2.type, parameters={}
+ )
self.assertEqual(trigger_db, trigger_2)
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type,
- parameters=None)
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_2.type, parameters=None
+ )
self.assertEqual(trigger_db, trigger_2)
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type,
- parameters={'fo': 'bar'})
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_2.type, parameters={"fo": "bar"}
+ )
self.assertEqual(trigger_db, None)
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_3.type,
- parameters={})
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_3.type, parameters={}
+ )
self.assertEqual(trigger_db, trigger_3)
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_3.type,
- parameters=None)
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_3.type, parameters=None
+ )
self.assertEqual(trigger_db, trigger_3)
# Trigger with parameters
trigger_db = trigger_service.get_trigger_db_given_type_and_params(
- type=trigger_4.type,
- parameters=trigger_4.parameters)
+ type=trigger_4.type, parameters=trigger_4.parameters
+ )
self.assertEqual(trigger_db, trigger_4)
- trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_4.type,
- parameters=None)
+ trigger_db = trigger_service.get_trigger_db_given_type_and_params(
+ type=trigger_4.type, parameters=None
+ )
self.assertEqual(trigger_db, None)
def test_add_trigger_type_no_params(self):
# Trigger type with no params should create a trigger with same name as trigger type.
trig_type = {
- 'name': 'myawesometriggertype',
- 'pack': 'dummy_pack_1',
- 'description': 'Words cannot describe how awesome I am.',
- 'parameters_schema': {},
- 'payload_schema': {}
+ "name": "myawesometriggertype",
+ "pack": "dummy_pack_1",
+ "description": "Words cannot describe how awesome I am.",
+ "parameters_schema": {},
+ "payload_schema": {},
}
trigtype_dbs = trigger_service.add_trigger_models(trigger_types=[trig_type])
trigger_type, trigger = trigtype_dbs[0]
trigtype_db = TriggerType.get_by_id(trigger_type.id)
- self.assertEqual(trigtype_db.pack, 'dummy_pack_1')
- self.assertEqual(trigtype_db.name, trig_type.get('name'))
+ self.assertEqual(trigtype_db.pack, "dummy_pack_1")
+ self.assertEqual(trigtype_db.name, trig_type.get("name"))
self.assertIsNotNone(trigger)
self.assertEqual(trigger.name, trigtype_db.name)
@@ -210,35 +242,34 @@ def test_add_trigger_type_no_params(self):
self.assertTrue(len(triggers) == 1)
def test_add_trigger_type_with_params(self):
- MOCK_TRIGGER.type = 'system.test'
+ MOCK_TRIGGER.type = "system.test"
# Trigger type with params should not create a trigger.
PARAMETERS_SCHEMA = {
"type": "object",
- "properties": {
- "url": {"type": "string"}
- },
- "required": ['url'],
- "additionalProperties": False
+ "properties": {"url": {"type": "string"}},
+ "required": ["url"],
+ "additionalProperties": False,
}
trig_type = {
- 'name': 'myawesometriggertype2',
- 'pack': 'my_pack_1',
- 'description': 'Words cannot describe how awesome I am.',
- 'parameters_schema': PARAMETERS_SCHEMA,
- 'payload_schema': {}
+ "name": "myawesometriggertype2",
+ "pack": "my_pack_1",
+ "description": "Words cannot describe how awesome I am.",
+ "parameters_schema": PARAMETERS_SCHEMA,
+ "payload_schema": {},
}
trigtype_dbs = trigger_service.add_trigger_models(trigger_types=[trig_type])
trigger_type, trigger = trigtype_dbs[0]
trigtype_db = TriggerType.get_by_id(trigger_type.id)
- self.assertEqual(trigtype_db.pack, 'my_pack_1')
- self.assertEqual(trigtype_db.name, trig_type.get('name'))
+ self.assertEqual(trigtype_db.pack, "my_pack_1")
+ self.assertEqual(trigtype_db.name, trig_type.get("name"))
self.assertEqual(trigger, None)
def test_add_trigger_type(self):
"""
This sensor has misconfigured trigger type. We shouldn't explode.
"""
+
class FailTestSensor(object):
started = False
@@ -252,12 +283,12 @@ def stop(self):
pass
def get_trigger_types(self):
- return [
- {'description': 'Ain\'t got no name'}
- ]
+ return [{"description": "Ain't got no name"}]
try:
trigger_service.add_trigger_models(FailTestSensor().get_trigger_types())
- self.assertTrue(False, 'Trigger type doesn\'t have \'name\' field. Should have thrown.')
+ self.assertTrue(
+ False, "Trigger type doesn't have 'name' field. Should have thrown."
+ )
except Exception:
self.assertTrue(True)
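
The get_trigger_db_given_type_and_params cases earlier in this file all hinge on one equivalence: parameters=None and parameters={} both mean "no parameters", and anything else must match exactly. A toy lookup capturing that rule:

    TRIGGERS = [
        {"type": "testpack.testtrigger1", "parameters": {}},
        {"type": "testpack.testtrigger4", "parameters": {"ponies": "unicorn"}},
    ]


    def get_trigger_given_type_and_params(type=None, parameters=None):
        wanted = parameters or {}  # None and {} are interchangeable
        for trigger in TRIGGERS:
            if trigger["type"] == type and (trigger.get("parameters") or {}) == wanted:
                return trigger
        return None


    assert get_trigger_given_type_and_params("testpack.testtrigger1", None)
    assert get_trigger_given_type_and_params(
        "testpack.testtrigger4", {"ponies": "unicorn"}
    )
    assert get_trigger_given_type_and_params("testpack.testtrigger4", None) is None
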
diff --git a/st2common/tests/unit/test_triggers_registrar.py b/st2common/tests/unit/test_triggers_registrar.py
index 53595d2867..5ceda4f851 100644
--- a/st2common/tests/unit/test_triggers_registrar.py
+++ b/st2common/tests/unit/test_triggers_registrar.py
@@ -22,9 +22,7 @@
from st2tests.base import CleanDbTestCase
from st2tests.fixturesloader import get_fixtures_packs_base_path
-__all__ = [
- 'TriggersRegistrarTestCase'
-]
+__all__ = ["TriggersRegistrarTestCase"]
class TriggersRegistrarTestCase(CleanDbTestCase):
@@ -44,7 +42,7 @@ def test_register_all_triggers(self):
def test_register_triggers_from_pack(self):
base_path = get_fixtures_packs_base_path()
- pack_dir = os.path.join(base_path, 'dummy_pack_1')
+ pack_dir = os.path.join(base_path, "dummy_pack_1")
trigger_type_dbs = TriggerType.get_all()
self.assertEqual(len(trigger_type_dbs), 0)
@@ -58,12 +56,12 @@ def test_register_triggers_from_pack(self):
self.assertEqual(len(trigger_type_dbs), 2)
self.assertEqual(len(trigger_dbs), 2)
- self.assertEqual(trigger_type_dbs[0].name, 'event_handler')
- self.assertEqual(trigger_type_dbs[0].pack, 'dummy_pack_1')
- self.assertEqual(trigger_dbs[0].name, 'event_handler')
- self.assertEqual(trigger_dbs[0].pack, 'dummy_pack_1')
- self.assertEqual(trigger_dbs[0].type, 'dummy_pack_1.event_handler')
+ self.assertEqual(trigger_type_dbs[0].name, "event_handler")
+ self.assertEqual(trigger_type_dbs[0].pack, "dummy_pack_1")
+ self.assertEqual(trigger_dbs[0].name, "event_handler")
+ self.assertEqual(trigger_dbs[0].pack, "dummy_pack_1")
+ self.assertEqual(trigger_dbs[0].type, "dummy_pack_1.event_handler")
- self.assertEqual(trigger_type_dbs[1].name, 'head_sha_monitor')
- self.assertEqual(trigger_type_dbs[1].pack, 'dummy_pack_1')
- self.assertEqual(trigger_type_dbs[1].payload_schema['type'], 'object')
+ self.assertEqual(trigger_type_dbs[1].name, "head_sha_monitor")
+ self.assertEqual(trigger_type_dbs[1].pack, "dummy_pack_1")
+ self.assertEqual(trigger_type_dbs[1].payload_schema["type"], "object")
diff --git a/st2common/tests/unit/test_unit_testing_mocks.py b/st2common/tests/unit/test_unit_testing_mocks.py
index ce63dd7834..742ca85da1 100644
--- a/st2common/tests/unit/test_unit_testing_mocks.py
+++ b/st2common/tests/unit/test_unit_testing_mocks.py
@@ -23,9 +23,9 @@
from st2tests.mocks.action import MockActionService
__all__ = [
- 'BaseSensorTestCaseTestCase',
- 'MockSensorServiceTestCase',
- 'MockActionServiceTestCase'
+ "BaseSensorTestCaseTestCase",
+ "MockSensorServiceTestCase",
+ "MockActionServiceTestCase",
]
@@ -37,36 +37,38 @@ class BaseMockResourceServiceTestCase(object):
class TestCase(unittest2.TestCase):
def test_get_user_info(self):
result = self.mock_service.get_user_info()
- self.assertEqual(result['username'], 'admin')
- self.assertEqual(result['rbac']['roles'], ['admin'])
+ self.assertEqual(result["username"], "admin")
+ self.assertEqual(result["rbac"]["roles"], ["admin"])
def test_list_set_get_delete_values(self):
# list_values, set_value
result = self.mock_service.list_values()
self.assertSequenceEqual(result, [])
- self.mock_service.set_value(name='t1.local', value='test1', local=True)
- self.mock_service.set_value(name='t1.global', value='test1', local=False)
+ self.mock_service.set_value(name="t1.local", value="test1", local=True)
+ self.mock_service.set_value(name="t1.global", value="test1", local=False)
result = self.mock_service.list_values(local=True)
self.assertEqual(len(result), 1)
- self.assertEqual(result[0].name, 'dummy.test:t1.local')
+ self.assertEqual(result[0].name, "dummy.test:t1.local")
result = self.mock_service.list_values(local=False)
- self.assertEqual(result[0].name, 'dummy.test:t1.local')
- self.assertEqual(result[1].name, 't1.global')
+ self.assertEqual(result[0].name, "dummy.test:t1.local")
+ self.assertEqual(result[1].name, "t1.global")
self.assertEqual(len(result), 2)
# get_value
- self.assertEqual(self.mock_service.get_value('inexistent'), None)
- self.assertEqual(self.mock_service.get_value(name='t1.local', local=True), 'test1')
+ self.assertEqual(self.mock_service.get_value("inexistent"), None)
+ self.assertEqual(
+ self.mock_service.get_value(name="t1.local", local=True), "test1"
+ )
# delete_value
self.assertEqual(len(self.mock_service.list_values(local=True)), 1)
- self.assertEqual(self.mock_service.delete_value('inexistent'), False)
+ self.assertEqual(self.mock_service.delete_value("inexistent"), False)
self.assertEqual(len(self.mock_service.list_values(local=True)), 1)
- self.assertEqual(self.mock_service.delete_value('t1.local'), True)
+ self.assertEqual(self.mock_service.delete_value("t1.local"), True)
self.assertEqual(len(self.mock_service.list_values(local=True)), 0)
@@ -77,47 +79,50 @@ def test_dispatch_and_assertTriggerDispatched(self):
sensor_service = self.sensor_service
expected_msg = 'Trigger "nope" hasn\'t been dispatched'
- self.assertRaisesRegexp(AssertionError, expected_msg,
- self.assertTriggerDispatched, trigger='nope')
+ self.assertRaisesRegexp(
+ AssertionError, expected_msg, self.assertTriggerDispatched, trigger="nope"
+ )
- sensor_service.dispatch(trigger='test1', payload={'a': 'b'})
- result = self.assertTriggerDispatched(trigger='test1')
+ sensor_service.dispatch(trigger="test1", payload={"a": "b"})
+ result = self.assertTriggerDispatched(trigger="test1")
self.assertTrue(result)
- result = self.assertTriggerDispatched(trigger='test1', payload={'a': 'b'})
+ result = self.assertTriggerDispatched(trigger="test1", payload={"a": "b"})
self.assertTrue(result)
expected_msg = 'Trigger "test1" hasn\'t been dispatched'
- self.assertRaisesRegexp(AssertionError, expected_msg,
- self.assertTriggerDispatched,
- trigger='test1',
- payload={'a': 'c'})
+ self.assertRaisesRegexp(
+ AssertionError,
+ expected_msg,
+ self.assertTriggerDispatched,
+ trigger="test1",
+ payload={"a": "c"},
+ )
class MockSensorServiceTestCase(BaseMockResourceServiceTestCase.TestCase):
-
def setUp(self):
- mock_sensor_wrapper = MockSensorWrapper(pack='dummy', class_name='test')
+ mock_sensor_wrapper = MockSensorWrapper(pack="dummy", class_name="test")
self.mock_service = MockSensorService(sensor_wrapper=mock_sensor_wrapper)
def test_get_logger(self):
sensor_service = self.mock_service
- logger = sensor_service.get_logger('test')
- logger.info('test info')
- logger.debug('test debug')
+ logger = sensor_service.get_logger("test")
+ logger.info("test info")
+ logger.debug("test debug")
self.assertEqual(len(logger.method_calls), 2)
method_name, method_args, method_kwargs = tuple(logger.method_calls[0])
- self.assertEqual(method_name, 'info')
- self.assertEqual(method_args, ('test info',))
+ self.assertEqual(method_name, "info")
+ self.assertEqual(method_args, ("test info",))
self.assertEqual(method_kwargs, {})
method_name, method_args, method_kwargs = tuple(logger.method_calls[1])
- self.assertEqual(method_name, 'debug')
- self.assertEqual(method_args, ('test debug',))
+ self.assertEqual(method_name, "debug")
+ self.assertEqual(method_args, ("test debug",))
self.assertEqual(method_kwargs, {})
class MockActionServiceTestCase(BaseMockResourceServiceTestCase.TestCase):
def setUp(self):
- mock_action_wrapper = MockActionWrapper(pack='dummy', class_name='test')
+ mock_action_wrapper = MockActionWrapper(pack="dummy", class_name="test")
self.mock_service = MockActionService(action_wrapper=mock_action_wrapper)
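
The datastore assertions above rely on the mock service namespacing "local" keys with a pack.class prefix (dummy.test:t1.local) while global keys keep their bare names. A toy version of that scoping rule:

    class MockDatastoreService:
        def __init__(self, pack, class_name):
            self._prefix = "%s.%s" % (pack, class_name)
            self._values = {}

        def set_value(self, name, value, local=True):
            # Local keys are namespaced per sensor/action instance.
            key = "%s:%s" % (self._prefix, name) if local else name
            self._values[key] = value

        def list_values(self, local=True):
            if local:
                prefix = self._prefix + ":"
                return sorted(k for k in self._values if k.startswith(prefix))
            return sorted(self._values)


    svc = MockDatastoreService(pack="dummy", class_name="test")
    svc.set_value("t1.local", "test1", local=True)
    svc.set_value("t1.global", "test1", local=False)
    print(svc.list_values(local=True))   # ['dummy.test:t1.local']
    print(svc.list_values(local=False))  # ['dummy.test:t1.local', 't1.global']
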
diff --git a/st2common/tests/unit/test_util_actionalias_helpstrings.py b/st2common/tests/unit/test_util_actionalias_helpstrings.py
index a7726dd177..e543bd471a 100644
--- a/st2common/tests/unit/test_util_actionalias_helpstrings.py
+++ b/st2common/tests/unit/test_util_actionalias_helpstrings.py
@@ -25,62 +25,101 @@
ALIASES = [
- MemoryActionAliasDB(name="kyle_reese", ref="terminator.1",
- pack="the80s", enabled=True,
- formats=["Come with me if you want to live"]
+ MemoryActionAliasDB(
+ name="kyle_reese",
+ ref="terminator.1",
+ pack="the80s",
+ enabled=True,
+ formats=["Come with me if you want to live"],
),
- MemoryActionAliasDB(name="terminator", ref="terminator.2",
- pack="the80s", enabled=True,
- formats=["I need your {{item}}, your {{item2}}"
- " and your {{vehicle}}"]
+ MemoryActionAliasDB(
+ name="terminator",
+ ref="terminator.2",
+ pack="the80s",
+ enabled=True,
+ formats=["I need your {{item}}, your {{item2}}" " and your {{vehicle}}"],
),
- MemoryActionAliasDB(name="johnny_five_alive", ref="short_circuit.3",
- pack="the80s", enabled=True,
- formats=[{'display': 'Number 5 is {{status}}',
- 'representation': ['Number 5 is {{status=alive}}']},
- 'Hey, laser lips, your mama was a snow blower.']
+ MemoryActionAliasDB(
+ name="johnny_five_alive",
+ ref="short_circuit.3",
+ pack="the80s",
+ enabled=True,
+ formats=[
+ {
+ "display": "Number 5 is {{status}}",
+ "representation": ["Number 5 is {{status=alive}}"],
+ },
+ "Hey, laser lips, your mama was a snow blower.",
+ ],
),
- MemoryActionAliasDB(name="i_feel_alive", ref="short_circuit.4",
- pack="the80s", enabled=True,
- formats=["How do I feel? I feel... {{status}}!"]
+ MemoryActionAliasDB(
+ name="i_feel_alive",
+ ref="short_circuit.4",
+ pack="the80s",
+ enabled=True,
+ formats=["How do I feel? I feel... {{status}}!"],
),
- MemoryActionAliasDB(name='andy', ref='the_goonies.1',
- pack="the80s", enabled=True,
- formats=[{'display': 'Watch this.'}]
+ MemoryActionAliasDB(
+ name="andy",
+ ref="the_goonies.1",
+ pack="the80s",
+ enabled=True,
+ formats=[{"display": "Watch this."}],
),
- MemoryActionAliasDB(name='andy', ref='the_goonies.5',
- pack="the80s", enabled=True,
- formats=[{'display': "He's just like his {{relation}}."}]
+ MemoryActionAliasDB(
+ name="andy",
+ ref="the_goonies.5",
+ pack="the80s",
+ enabled=True,
+ formats=[{"display": "He's just like his {{relation}}."}],
),
- MemoryActionAliasDB(name='data', ref='the_goonies.6',
- pack="the80s", enabled=True,
- formats=[{'representation': "That's okay daddy. You can't hug a {{object}}."}]
+ MemoryActionAliasDB(
+ name="data",
+ ref="the_goonies.6",
+ pack="the80s",
+ enabled=True,
+ formats=[{"representation": "That's okay daddy. You can't hug a {{object}}."}],
),
- MemoryActionAliasDB(name='mr_wang', ref='the_goonies.7',
- pack="the80s", enabled=True,
- formats=[{'representation': 'You are my greatest invention.'}]
+ MemoryActionAliasDB(
+ name="mr_wang",
+ ref="the_goonies.7",
+ pack="the80s",
+ enabled=True,
+ formats=[{"representation": "You are my greatest invention."}],
),
- MemoryActionAliasDB(name="Ferris", ref="ferris_buellers_day_off.8",
- pack="the80s", enabled=True,
- formats=["Life moves pretty fast.",
- "If you don't stop and look around once in a while, you could miss it."]
+ MemoryActionAliasDB(
+ name="Ferris",
+ ref="ferris_buellers_day_off.8",
+ pack="the80s",
+ enabled=True,
+ formats=[
+ "Life moves pretty fast.",
+ "If you don't stop and look around once in a while, you could miss it.",
+ ],
),
- MemoryActionAliasDB(name="economics.teacher", ref="ferris_buellers_day_off.10",
- pack="the80s", enabled=False,
- formats=["Bueller?... Bueller?... Bueller? "]
+ MemoryActionAliasDB(
+ name="economics.teacher",
+ ref="ferris_buellers_day_off.10",
+ pack="the80s",
+ enabled=False,
+ formats=["Bueller?... Bueller?... Bueller? "],
+ ),
+ MemoryActionAliasDB(
+ name="spengler",
+ ref="ghostbusters.10",
+ pack="the80s",
+ enabled=True,
+ formats=["{{choice}} cross the {{target}}"],
),
- MemoryActionAliasDB(name="spengler", ref="ghostbusters.10",
- pack="the80s", enabled=True,
- formats=["{{choice}} cross the {{target}}"]
- )
]
-@mock.patch.object(MemoryActionAliasDB, 'get_uid')
+@mock.patch.object(MemoryActionAliasDB, "get_uid")
class ActionAliasTestCase(unittest2.TestCase):
- '''
+ """
Test scenarios must consist of 80s movie quotes.
- '''
+ """
+
def check_data_structure(self, result):
tmp = list(result.keys())
tmp.sort()
@@ -93,7 +132,9 @@ def test_filtering_no_arg(self, mock):
result = generate_helpstring_result(ALIASES)
self.check_data_structure(result)
self.check_available_count(result, 10)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 10)
self.assertEqual(the80s[0].get("display"), "Come with me if you want to live")
@@ -115,7 +156,9 @@ def test_filtering_match(self, mock):
result = generate_helpstring_result(ALIASES, "you")
self.check_data_structure(result)
self.check_available_count(result, 4)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 4)
self.assertEqual(the80s[0].get("display"), "Come with me if you want to live")
@@ -123,12 +166,16 @@ def test_pack_empty_string(self, mock):
result = generate_helpstring_result(ALIASES, "", "")
self.check_data_structure(result)
self.check_available_count(result, 10)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 10)
self.assertEqual(the80s[0].get("display"), "Come with me if you want to live")
def test_pack_no_match(self, mock):
- result = generate_helpstring_result(ALIASES, "", "you_will_not_find_this_string")
+ result = generate_helpstring_result(
+ ALIASES, "", "you_will_not_find_this_string"
+ )
self.check_data_structure(result)
self.check_available_count(result, 0)
self.assertEqual(result.get("helpstrings"), [])
@@ -137,7 +184,9 @@ def test_pack_match(self, mock):
result = generate_helpstring_result(ALIASES, "", "the80s")
self.check_data_structure(result)
self.check_available_count(result, 10)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 10)
self.assertEqual(the80s[0].get("display"), "Come with me if you want to live")
@@ -153,7 +202,9 @@ def test_limit_neg_out_of_bounds(self, mock):
result = generate_helpstring_result(ALIASES, "", "the80s", -3)
self.check_data_structure(result)
self.check_available_count(result, 10)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 10)
self.assertEqual(the80s[0].get("display"), "Come with me if you want to live")
@@ -161,7 +212,9 @@ def test_limit_pos_out_of_bounds(self, mock):
result = generate_helpstring_result(ALIASES, "", "the80s", 30)
self.check_data_structure(result)
self.check_available_count(result, 10)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 10)
self.assertEqual(the80s[0].get("display"), "Come with me if you want to live")
@@ -169,7 +222,9 @@ def test_limit_in_bounds(self, mock):
result = generate_helpstring_result(ALIASES, "", "the80s", 3)
self.check_data_structure(result)
self.check_available_count(result, 10)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 3)
self.assertEqual(the80s[0].get("display"), "Come with me if you want to live")
@@ -185,7 +240,9 @@ def test_offset_negative_out_of_bounds(self, mock):
result = generate_helpstring_result(ALIASES, "", "the80s", 0, -1)
self.check_data_structure(result)
self.check_available_count(result, 10)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 10)
self.assertEqual(the80s[0].get("display"), "Come with me if you want to live")
@@ -199,6 +256,8 @@ def test_offset_in_bounds(self, mock):
result = generate_helpstring_result(ALIASES, "", "the80s", 0, 6)
self.check_data_structure(result)
self.check_available_count(result, 10)
- the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"]
+ the80s = [
+ line for line in result.get("helpstrings") if line["pack"] == "the80s"
+ ]
self.assertEqual(len(the80s), 4)
self.assertEqual(the80s[0].get("display"), "He's just like his {{relation}}.")
diff --git a/st2common/tests/unit/test_util_actionalias_matching.py b/st2common/tests/unit/test_util_actionalias_matching.py
index c22ccab3e6..082fa40b98 100644
--- a/st2common/tests/unit/test_util_actionalias_matching.py
+++ b/st2common/tests/unit/test_util_actionalias_matching.py
@@ -24,89 +24,130 @@
MemoryActionAliasDB = ActionAliasDB
-@mock.patch.object(MemoryActionAliasDB, 'get_uid')
+@mock.patch.object(MemoryActionAliasDB, "get_uid")
class ActionAliasTestCase(unittest2.TestCase):
- '''
+ """
Test scenarios must consist of 80s movie quotes.
- '''
+ """
+
def test_list_format_strings_from_aliases(self, mock):
ALIASES = [
- MemoryActionAliasDB(name="kyle_reese", ref="terminator.1",
- formats=["Come with me if you want to live"]),
- MemoryActionAliasDB(name="terminator", ref="terminator.2",
- formats=["I need your {{item}}, your {{item2}}"
- " and your {{vehicle}}"])
+ MemoryActionAliasDB(
+ name="kyle_reese",
+ ref="terminator.1",
+ formats=["Come with me if you want to live"],
+ ),
+ MemoryActionAliasDB(
+ name="terminator",
+ ref="terminator.2",
+ formats=[
+ "I need your {{item}}, your {{item2}}" " and your {{vehicle}}"
+ ],
+ ),
]
result = matching.list_format_strings_from_aliases(ALIASES)
self.assertEqual(len(result), 2)
- self.assertEqual(result[0]['display'], "Come with me if you want to live")
- self.assertEqual(result[1]['display'],
- "I need your {{item}}, your {{item2}} and"
- " your {{vehicle}}")
+ self.assertEqual(result[0]["display"], "Come with me if you want to live")
+ self.assertEqual(
+ result[1]["display"],
+ "I need your {{item}}, your {{item2}} and" " your {{vehicle}}",
+ )
def test_list_format_strings_from_aliases_with_display(self, mock):
ALIASES = [
- MemoryActionAliasDB(name="johnny_five_alive", ref="short_circuit.1", formats=[
- {'display': 'Number 5 is {{status}}',
- 'representation': ['Number 5 is {{status=alive}}']},
- 'Hey, laser lips, your mama was a snow blower.']),
- MemoryActionAliasDB(name="i_feel_alive", ref="short_circuit.2",
- formats=["How do I feel? I feel... {{status}}!"])
+ MemoryActionAliasDB(
+ name="johnny_five_alive",
+ ref="short_circuit.1",
+ formats=[
+ {
+ "display": "Number 5 is {{status}}",
+ "representation": ["Number 5 is {{status=alive}}"],
+ },
+ "Hey, laser lips, your mama was a snow blower.",
+ ],
+ ),
+ MemoryActionAliasDB(
+ name="i_feel_alive",
+ ref="short_circuit.2",
+ formats=["How do I feel? I feel... {{status}}!"],
+ ),
]
result = matching.list_format_strings_from_aliases(ALIASES)
self.assertEqual(len(result), 3)
- self.assertEqual(result[0]['display'], "Number 5 is {{status}}")
- self.assertEqual(result[0]['representation'], "Number 5 is {{status=alive}}")
- self.assertEqual(result[1]['display'], "Hey, laser lips, your mama was a snow blower.")
- self.assertEqual(result[1]['representation'],
- "Hey, laser lips, your mama was a snow blower.")
- self.assertEqual(result[2]['display'], "How do I feel? I feel... {{status}}!")
- self.assertEqual(result[2]['representation'], "How do I feel? I feel... {{status}}!")
+ self.assertEqual(result[0]["display"], "Number 5 is {{status}}")
+ self.assertEqual(result[0]["representation"], "Number 5 is {{status=alive}}")
+ self.assertEqual(
+ result[1]["display"], "Hey, laser lips, your mama was a snow blower."
+ )
+ self.assertEqual(
+ result[1]["representation"], "Hey, laser lips, your mama was a snow blower."
+ )
+ self.assertEqual(result[2]["display"], "How do I feel? I feel... {{status}}!")
+ self.assertEqual(
+ result[2]["representation"], "How do I feel? I feel... {{status}}!"
+ )
def test_list_format_strings_from_aliases_with_display_only(self, mock):
ALIASES = [
- MemoryActionAliasDB(name='andy',
- ref='the_goonies.1', formats=[{'display': 'Watch this.'}]),
- MemoryActionAliasDB(name='andy', ref='the_goonies.2',
- formats=[{'display': "He's just like his {{relation}}."}])
+ MemoryActionAliasDB(
+ name="andy", ref="the_goonies.1", formats=[{"display": "Watch this."}]
+ ),
+ MemoryActionAliasDB(
+ name="andy",
+ ref="the_goonies.2",
+ formats=[{"display": "He's just like his {{relation}}."}],
+ ),
]
result = matching.list_format_strings_from_aliases(ALIASES)
self.assertEqual(len(result), 2)
- self.assertEqual(result[0]['display'], 'Watch this.')
- self.assertEqual(result[0]['representation'], '')
- self.assertEqual(result[1]['display'], "He's just like his {{relation}}.")
- self.assertEqual(result[1]['representation'], '')
+ self.assertEqual(result[0]["display"], "Watch this.")
+ self.assertEqual(result[0]["representation"], "")
+ self.assertEqual(result[1]["display"], "He's just like his {{relation}}.")
+ self.assertEqual(result[1]["representation"], "")
def test_list_format_strings_from_aliases_with_representation_only(self, mock):
ALIASES = [
- MemoryActionAliasDB(name='data', ref='the_goonies.1', formats=[
- {'representation': "That's okay daddy. You can't hug a {{object}}."}]),
- MemoryActionAliasDB(name='mr_wang', ref='the_goonies.2', formats=[
- {'representation': 'You are my greatest invention.'}])
+ MemoryActionAliasDB(
+ name="data",
+ ref="the_goonies.1",
+ formats=[
+ {"representation": "That's okay daddy. You can't hug a {{object}}."}
+ ],
+ ),
+ MemoryActionAliasDB(
+ name="mr_wang",
+ ref="the_goonies.2",
+ formats=[{"representation": "You are my greatest invention."}],
+ ),
]
result = matching.list_format_strings_from_aliases(ALIASES)
self.assertEqual(len(result), 2)
- self.assertEqual(result[0]['display'], None)
- self.assertEqual(result[0]['representation'],
- "That's okay daddy. You can't hug a {{object}}.")
- self.assertEqual(result[1]['display'], None)
- self.assertEqual(result[1]['representation'], 'You are my greatest invention.')
+ self.assertEqual(result[0]["display"], None)
+ self.assertEqual(
+ result[0]["representation"],
+ "That's okay daddy. You can't hug a {{object}}.",
+ )
+ self.assertEqual(result[1]["display"], None)
+ self.assertEqual(result[1]["representation"], "You are my greatest invention.")
def test_normalise_alias_format_string(self, mock):
result = matching.normalise_alias_format_string(
- 'Quite an experience to live in fear, isn\'t it?')
+ "Quite an experience to live in fear, isn't it?"
+ )
self.assertEqual([result[0]], result[1])
self.assertEqual(result[0], "Quite an experience to live in fear, isn't it?")
def test_normalise_alias_format_string_error(self, mock):
alias_list = ["Quite an experience to live in fear, isn't it?"]
- expected_msg = ("alias_format '%s' is neither a dictionary or string type."
- % repr(alias_list))
+ expected_msg = (
+ "alias_format '%s' is neither a dictionary or string type."
+ % repr(alias_list)
+ )
with self.assertRaises(TypeError) as cm:
matching.normalise_alias_format_string(alias_list)
@@ -115,13 +156,16 @@ def test_normalise_alias_format_string_error(self, mock):
def test_matching(self, mock):
ALIASES = [
- MemoryActionAliasDB(name="spengler", ref="ghostbusters.1",
- formats=["{{choice}} cross the {{target}}"]),
+ MemoryActionAliasDB(
+ name="spengler",
+ ref="ghostbusters.1",
+ formats=["{{choice}} cross the {{target}}"],
+ ),
]
COMMAND = "Don't cross the streams"
match = matching.match_command_to_alias(COMMAND, ALIASES)
self.assertEqual(len(match), 1)
- self.assertEqual(match[0]['alias'].ref, "ghostbusters.1")
- self.assertEqual(match[0]['representation'], "{{choice}} cross the {{target}}")
+ self.assertEqual(match[0]["alias"].ref, "ghostbusters.1")
+ self.assertEqual(match[0]["representation"], "{{choice}} cross the {{target}}")
# We need some more complex scenarios in here.
diff --git a/st2common/tests/unit/test_util_api.py b/st2common/tests/unit/test_util_api.py
index bc0e385df1..2333939b13 100644
--- a/st2common/tests/unit/test_util_api.py
+++ b/st2common/tests/unit/test_util_api.py
@@ -23,24 +23,25 @@
from st2common.util.api import get_full_public_api_url
from st2tests.config import parse_args
from six.moves import zip
+
parse_args()
class APIUtilsTestCase(unittest2.TestCase):
def test_get_base_public_api_url(self):
values = [
- 'http://foo.bar.com',
- 'http://foo.bar.com/',
- 'http://foo.bar.com:8080',
- 'http://foo.bar.com:8080/',
- 'http://localhost:8080/',
+ "http://foo.bar.com",
+ "http://foo.bar.com/",
+ "http://foo.bar.com:8080",
+ "http://foo.bar.com:8080/",
+ "http://localhost:8080/",
]
expected = [
- 'http://foo.bar.com',
- 'http://foo.bar.com',
- 'http://foo.bar.com:8080',
- 'http://foo.bar.com:8080',
- 'http://localhost:8080',
+ "http://foo.bar.com",
+ "http://foo.bar.com",
+ "http://foo.bar.com:8080",
+ "http://foo.bar.com:8080",
+ "http://localhost:8080",
]
for mock_value, expected_result in zip(values, expected):
@@ -50,18 +51,18 @@ def test_get_base_public_api_url(self):
def test_get_full_public_api_url(self):
values = [
- 'http://foo.bar.com',
- 'http://foo.bar.com/',
- 'http://foo.bar.com:8080',
- 'http://foo.bar.com:8080/',
- 'http://localhost:8080/',
+ "http://foo.bar.com",
+ "http://foo.bar.com/",
+ "http://foo.bar.com:8080",
+ "http://foo.bar.com:8080/",
+ "http://localhost:8080/",
]
expected = [
- 'http://foo.bar.com/' + DEFAULT_API_VERSION,
- 'http://foo.bar.com/' + DEFAULT_API_VERSION,
- 'http://foo.bar.com:8080/' + DEFAULT_API_VERSION,
- 'http://foo.bar.com:8080/' + DEFAULT_API_VERSION,
- 'http://localhost:8080/' + DEFAULT_API_VERSION,
+ "http://foo.bar.com/" + DEFAULT_API_VERSION,
+ "http://foo.bar.com/" + DEFAULT_API_VERSION,
+ "http://foo.bar.com:8080/" + DEFAULT_API_VERSION,
+ "http://foo.bar.com:8080/" + DEFAULT_API_VERSION,
+ "http://localhost:8080/" + DEFAULT_API_VERSION,
]
for mock_value, expected_result in zip(values, expected):
diff --git a/st2common/tests/unit/test_util_compat.py b/st2common/tests/unit/test_util_compat.py
index 0e1ac9efe7..74face7ea6 100644
--- a/st2common/tests/unit/test_util_compat.py
+++ b/st2common/tests/unit/test_util_compat.py
@@ -19,18 +19,16 @@
from st2common.util.compat import to_ascii
-__all__ = [
- 'CompatUtilsTestCase'
-]
+__all__ = ["CompatUtilsTestCase"]
class CompatUtilsTestCase(unittest2.TestCase):
def test_to_ascii(self):
expected_values = [
- ('already ascii', 'already ascii'),
- (u'foo', 'foo'),
- ('٩(̾●̮̮̃̾•̃̾)۶', '()'),
- ('\xd9\xa9', '')
+ ("already ascii", "already ascii"),
+ ("foo", "foo"),
+ ("٩(̾●̮̮̃̾•̃̾)۶", "()"),
+ ("\xd9\xa9", ""),
]
for input_value, expected_value in expected_values:
diff --git a/st2common/tests/unit/test_util_db.py b/st2common/tests/unit/test_util_db.py
index dd230e6ae1..f94a2fe39a 100644
--- a/st2common/tests/unit/test_util_db.py
+++ b/st2common/tests/unit/test_util_db.py
@@ -22,88 +22,73 @@
class DatabaseUtilTestCase(unittest2.TestCase):
-
def test_noop_mongodb_to_python_types(self):
- data = [
- 123,
- 999.99,
- True,
- [10, 20, 30],
- {'a': 1, 'b': 2},
- None
- ]
+ data = [123, 999.99, True, [10, 20, 30], {"a": 1, "b": 2}, None]
for item in data:
self.assertEqual(db_util.mongodb_to_python_types(item), item)
def test_mongodb_basedict_to_dict(self):
- data = {'a': 1, 'b': 2}
+ data = {"a": 1, "b": 2}
- obj = mongoengine.base.datastructures.BaseDict(data, None, 'foobar')
+ obj = mongoengine.base.datastructures.BaseDict(data, None, "foobar")
self.assertDictEqual(db_util.mongodb_to_python_types(obj), data)
def test_mongodb_baselist_to_list(self):
data = [2, 4, 6]
- obj = mongoengine.base.datastructures.BaseList(data, None, 'foobar')
+ obj = mongoengine.base.datastructures.BaseList(data, None, "foobar")
self.assertListEqual(db_util.mongodb_to_python_types(obj), data)
def test_nested_mongdb_to_python_types(self):
data = {
- 'a': mongoengine.base.datastructures.BaseList([1, 2, 3], None, 'a'),
- 'b': mongoengine.base.datastructures.BaseDict({'a': 1, 'b': 2}, None, 'b'),
- 'c': {
- 'd': mongoengine.base.datastructures.BaseList([4, 5, 6], None, 'd'),
- 'e': mongoengine.base.datastructures.BaseDict({'c': 3, 'd': 4}, None, 'e')
+ "a": mongoengine.base.datastructures.BaseList([1, 2, 3], None, "a"),
+ "b": mongoengine.base.datastructures.BaseDict({"a": 1, "b": 2}, None, "b"),
+ "c": {
+ "d": mongoengine.base.datastructures.BaseList([4, 5, 6], None, "d"),
+ "e": mongoengine.base.datastructures.BaseDict(
+ {"c": 3, "d": 4}, None, "e"
+ ),
},
- 'f': mongoengine.base.datastructures.BaseList(
+ "f": mongoengine.base.datastructures.BaseList(
[
- mongoengine.base.datastructures.BaseDict({'e': 5}, None, 'f1'),
- mongoengine.base.datastructures.BaseDict({'f': 6}, None, 'f2')
+ mongoengine.base.datastructures.BaseDict({"e": 5}, None, "f1"),
+ mongoengine.base.datastructures.BaseDict({"f": 6}, None, "f2"),
],
None,
- 'f'
+ "f",
),
- 'g': mongoengine.base.datastructures.BaseDict(
+ "g": mongoengine.base.datastructures.BaseDict(
{
- 'h': mongoengine.base.datastructures.BaseList(
+ "h": mongoengine.base.datastructures.BaseList(
[
- mongoengine.base.datastructures.BaseDict({'g': 7}, None, 'h1'),
- mongoengine.base.datastructures.BaseDict({'h': 8}, None, 'h2')
+ mongoengine.base.datastructures.BaseDict(
+ {"g": 7}, None, "h1"
+ ),
+ mongoengine.base.datastructures.BaseDict(
+ {"h": 8}, None, "h2"
+ ),
],
None,
- 'h'
+ "h",
+ ),
+ "i": mongoengine.base.datastructures.BaseDict(
+ {"j": 9, "k": 10}, None, "i"
),
- 'i': mongoengine.base.datastructures.BaseDict({'j': 9, 'k': 10}, None, 'i')
},
None,
- 'g'
+ "g",
),
}
expected = {
- 'a': [1, 2, 3],
- 'b': {'a': 1, 'b': 2},
- 'c': {
- 'd': [4, 5, 6],
- 'e': {'c': 3, 'd': 4}
- },
- 'f': [
- {'e': 5},
- {'f': 6}
- ],
- 'g': {
- 'h': [
- {'g': 7},
- {'h': 8}
- ],
- 'i': {
- 'j': 9,
- 'k': 10
- }
- }
+ "a": [1, 2, 3],
+ "b": {"a": 1, "b": 2},
+ "c": {"d": [4, 5, 6], "e": {"c": 3, "d": 4}},
+ "f": [{"e": 5}, {"f": 6}],
+ "g": {"h": [{"g": 7}, {"h": 8}], "i": {"j": 9, "k": 10}},
}
self.assertDictEqual(db_util.mongodb_to_python_types(data), expected)
diff --git a/st2common/tests/unit/test_util_file_system.py b/st2common/tests/unit/test_util_file_system.py
index ea46a0b943..a1af0c957a 100644
--- a/st2common/tests/unit/test_util_file_system.py
+++ b/st2common/tests/unit/test_util_file_system.py
@@ -22,30 +22,32 @@
from st2common.util.file_system import get_file_list
CURRENT_DIR = os.path.dirname(__file__)
-ST2TESTS_DIR = os.path.join(CURRENT_DIR, '../../../st2tests/st2tests')
+ST2TESTS_DIR = os.path.join(CURRENT_DIR, "../../../st2tests/st2tests")
class FileSystemUtilsTestCase(unittest2.TestCase):
def test_get_file_list(self):
# Standard exclude pattern
- directory = os.path.join(ST2TESTS_DIR, 'policies')
+ directory = os.path.join(ST2TESTS_DIR, "policies")
expected = [
- 'mock_exception.py',
- 'concurrency.py',
- '__init__.py',
- 'meta/mock_exception.yaml',
- 'meta/concurrency.yaml',
- 'meta/__init__.py'
+ "mock_exception.py",
+ "concurrency.py",
+ "__init__.py",
+ "meta/mock_exception.yaml",
+ "meta/concurrency.yaml",
+ "meta/__init__.py",
]
- result = get_file_list(directory=directory, exclude_patterns=['*.pyc'])
+ result = get_file_list(directory=directory, exclude_patterns=["*.pyc"])
self.assertItemsEqual(expected, result)
# Custom exclude pattern
expected = [
- 'mock_exception.py',
- 'concurrency.py',
- '__init__.py',
- 'meta/__init__.py'
+ "mock_exception.py",
+ "concurrency.py",
+ "__init__.py",
+ "meta/__init__.py",
]
- result = get_file_list(directory=directory, exclude_patterns=['*.pyc', '*.yaml'])
+ result = get_file_list(
+ directory=directory, exclude_patterns=["*.pyc", "*.yaml"]
+ )
self.assertItemsEqual(expected, result)
diff --git a/st2common/tests/unit/test_util_http.py b/st2common/tests/unit/test_util_http.py
index 2bfbc22f04..a97aa8c7f1 100644
--- a/st2common/tests/unit/test_util_http.py
+++ b/st2common/tests/unit/test_util_http.py
@@ -19,24 +19,22 @@
from st2common.util.http import parse_content_type_header
from six.moves import zip
-__all__ = [
- 'HTTPUtilTestCase'
-]
+__all__ = ["HTTPUtilTestCase"]
class HTTPUtilTestCase(unittest2.TestCase):
def test_parse_content_type_header(self):
values = [
- 'application/json',
- 'foo/bar',
- 'application/json; charset=utf-8',
- 'application/json; charset=utf-8; foo=bar',
+ "application/json",
+ "foo/bar",
+ "application/json; charset=utf-8",
+ "application/json; charset=utf-8; foo=bar",
]
expected_results = [
- ('application/json', {}),
- ('foo/bar', {}),
- ('application/json', {'charset': 'utf-8'}),
- ('application/json', {'charset': 'utf-8', 'foo': 'bar'})
+ ("application/json", {}),
+ ("foo/bar", {}),
+ ("application/json", {"charset": "utf-8"}),
+ ("application/json", {"charset": "utf-8", "foo": "bar"}),
]
for value, expected_result in zip(values, expected_results):
diff --git a/st2common/tests/unit/test_util_jinja.py b/st2common/tests/unit/test_util_jinja.py
index 1b56adc0e9..127570f54b 100644
--- a/st2common/tests/unit/test_util_jinja.py
+++ b/st2common/tests/unit/test_util_jinja.py
@@ -21,97 +21,95 @@
class JinjaUtilsRenderTestCase(unittest2.TestCase):
-
def test_render_values(self):
actual = jinja_utils.render_values(
- mapping={'k1': '{{a}}', 'k2': '{{b}}'},
- context={'a': 'v1', 'b': 'v2'})
- expected = {'k2': 'v2', 'k1': 'v1'}
+ mapping={"k1": "{{a}}", "k2": "{{b}}"}, context={"a": "v1", "b": "v2"}
+ )
+ expected = {"k2": "v2", "k1": "v1"}
self.assertEqual(actual, expected)
def test_render_values_skip_missing(self):
actual = jinja_utils.render_values(
- mapping={'k1': '{{a}}', 'k2': '{{b}}', 'k3': '{{c}}'},
- context={'a': 'v1', 'b': 'v2'},
- allow_undefined=True)
- expected = {'k2': 'v2', 'k1': 'v1', 'k3': ''}
+ mapping={"k1": "{{a}}", "k2": "{{b}}", "k3": "{{c}}"},
+ context={"a": "v1", "b": "v2"},
+ allow_undefined=True,
+ )
+ expected = {"k2": "v2", "k1": "v1", "k3": ""}
self.assertEqual(actual, expected)
def test_render_values_ascii_and_unicode_values(self):
- mapping = {
- u'k_ascii': '{{a}}',
- u'k_unicode': '{{b}}',
- u'k_ascii_unicode': '{{c}}'}
+ mapping = {"k_ascii": "{{a}}", "k_unicode": "{{b}}", "k_ascii_unicode": "{{c}}"}
context = {
- 'a': u'some ascii value',
- 'b': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž',
- 'c': u'some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ '
+ "a": "some ascii value",
+ "b": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž",
+ "c": "some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ",
}
expected = {
- 'k_ascii': u'some ascii value',
- 'k_unicode': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž',
- 'k_ascii_unicode': u'some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ '
+ "k_ascii": "some ascii value",
+ "k_unicode": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž",
+ "k_ascii_unicode": "some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ",
}
actual = jinja_utils.render_values(
- mapping=mapping,
- context=context,
- allow_undefined=True)
+ mapping=mapping, context=context, allow_undefined=True
+ )
self.assertEqual(actual, expected)
def test_convert_str_to_raw(self):
- jinja_expr = '{{foobar}}'
- expected_raw_block = '{% raw %}{{foobar}}{% endraw %}'
- self.assertEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr))
+ jinja_expr = "{{foobar}}"
+ expected_raw_block = "{% raw %}{{foobar}}{% endraw %}"
+ self.assertEqual(
+ expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)
+ )
- jinja_block_expr = '{% for item in items %}foobar{% end for %}'
- expected_raw_block = '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}'
+ jinja_block_expr = "{% for item in items %}foobar{% end for %}"
+ expected_raw_block = (
+ "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}"
+ )
self.assertEqual(
- expected_raw_block,
- jinja_utils.convert_jinja_to_raw_block(jinja_block_expr)
+ expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_block_expr)
)
def test_convert_list_to_raw(self):
jinja_expr = [
- 'foobar',
- '{{foo}}',
- '{{bar}}',
- '{% for item in items %}foobar{% end for %}',
- {'foobar': '{{foobar}}'}
+ "foobar",
+ "{{foo}}",
+ "{{bar}}",
+ "{% for item in items %}foobar{% end for %}",
+ {"foobar": "{{foobar}}"},
]
expected_raw_block = [
- 'foobar',
- '{% raw %}{{foo}}{% endraw %}',
- '{% raw %}{{bar}}{% endraw %}',
- '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}',
- {'foobar': '{% raw %}{{foobar}}{% endraw %}'}
+ "foobar",
+ "{% raw %}{{foo}}{% endraw %}",
+ "{% raw %}{{bar}}{% endraw %}",
+ "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}",
+ {"foobar": "{% raw %}{{foobar}}{% endraw %}"},
]
- self.assertListEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr))
+ self.assertListEqual(
+ expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)
+ )
def test_convert_dict_to_raw(self):
jinja_expr = {
- 'var1': 'foobar',
- 'var2': ['{{foo}}', '{{bar}}'],
- 'var3': {'foobar': '{{foobar}}'},
- 'var4': {'foobar': '{% for item in items %}foobar{% end for %}'}
+ "var1": "foobar",
+ "var2": ["{{foo}}", "{{bar}}"],
+ "var3": {"foobar": "{{foobar}}"},
+ "var4": {"foobar": "{% for item in items %}foobar{% end for %}"},
}
expected_raw_block = {
- 'var1': 'foobar',
- 'var2': [
- '{% raw %}{{foo}}{% endraw %}',
- '{% raw %}{{bar}}{% endraw %}'
- ],
- 'var3': {
- 'foobar': '{% raw %}{{foobar}}{% endraw %}'
+ "var1": "foobar",
+ "var2": ["{% raw %}{{foo}}{% endraw %}", "{% raw %}{{bar}}{% endraw %}"],
+ "var3": {"foobar": "{% raw %}{{foobar}}{% endraw %}"},
+ "var4": {
+ "foobar": "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}"
},
- 'var4': {
- 'foobar': '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}'
- }
}
- self.assertDictEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr))
+ self.assertDictEqual(
+ expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)
+ )
diff --git a/st2common/tests/unit/test_util_keyvalue.py b/st2common/tests/unit/test_util_keyvalue.py
index 07f061e60a..5a8c15a3f7 100644
--- a/st2common/tests/unit/test_util_keyvalue.py
+++ b/st2common/tests/unit/test_util_keyvalue.py
@@ -18,14 +18,19 @@
import unittest2
from st2common.util import keyvalue as kv_utl
-from st2common.constants.keyvalue import (FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, USER_SCOPE,
- ALL_SCOPE, DATASTORE_PARENT_SCOPE,
- DATASTORE_SCOPE_SEPARATOR)
+from st2common.constants.keyvalue import (
+ FULL_SYSTEM_SCOPE,
+ FULL_USER_SCOPE,
+ USER_SCOPE,
+ ALL_SCOPE,
+ DATASTORE_PARENT_SCOPE,
+ DATASTORE_SCOPE_SEPARATOR,
+)
from st2common.exceptions.rbac import AccessDeniedError
from st2common.models.db import auth as auth_db
-USER = 'stanley'
+USER = "stanley"
class TestKeyValueUtil(unittest2.TestCase):
@@ -38,48 +43,26 @@ def test_validate_scope(self):
kv_utl._validate_scope(scope)
def test_validate_scope_with_invalid_scope(self):
- scope = 'INVALID_SCOPE'
+ scope = "INVALID_SCOPE"
self.assertRaises(ValueError, kv_utl._validate_scope, scope)
def test_validate_decrypt_query_parameter(self):
test_params = [
- [
- False,
- USER_SCOPE,
- False,
- {}
- ],
- [
- True,
- USER_SCOPE,
- False,
- {}
- ],
- [
- True,
- FULL_SYSTEM_SCOPE,
- True,
- {}
- ],
+ [False, USER_SCOPE, False, {}],
+ [True, USER_SCOPE, False, {}],
+ [True, FULL_SYSTEM_SCOPE, True, {}],
]
for params in test_params:
kv_utl._validate_decrypt_query_parameter(*params)
def test_validate_decrypt_query_parameter_access_denied(self):
- test_params = [
- [
- True,
- FULL_SYSTEM_SCOPE,
- False,
- {}
- ]
- ]
+ test_params = [[True, FULL_SYSTEM_SCOPE, False, {}]]
for params in test_params:
assert_params = [
AccessDeniedError,
- kv_utl._validate_decrypt_query_parameter
+ kv_utl._validate_decrypt_query_parameter,
]
assert_params.extend(params)
@@ -88,81 +71,58 @@ def test_validate_decrypt_query_parameter_access_denied(self):
def test_get_datastore_full_scope(self):
self.assertEqual(
kv_utl.get_datastore_full_scope(USER_SCOPE),
- DATASTORE_SCOPE_SEPARATOR.join([DATASTORE_PARENT_SCOPE, USER_SCOPE])
+ DATASTORE_SCOPE_SEPARATOR.join([DATASTORE_PARENT_SCOPE, USER_SCOPE]),
)
def test_get_datastore_full_scope_all_scope(self):
- self.assertEqual(
- kv_utl.get_datastore_full_scope(ALL_SCOPE),
- ALL_SCOPE
- )
+ self.assertEqual(kv_utl.get_datastore_full_scope(ALL_SCOPE), ALL_SCOPE)
def test_get_datastore_full_scope_datastore_parent_scope(self):
self.assertEqual(
kv_utl.get_datastore_full_scope(DATASTORE_PARENT_SCOPE),
- DATASTORE_PARENT_SCOPE
+ DATASTORE_PARENT_SCOPE,
)
def test_derive_scope_and_key(self):
- key = 'test'
+ key = "test"
scope = USER_SCOPE
result = kv_utl._derive_scope_and_key(key, scope)
- self.assertEqual(
- (FULL_USER_SCOPE, 'user:%s' % key),
- result
- )
+ self.assertEqual((FULL_USER_SCOPE, "user:%s" % key), result)
def test_derive_scope_and_key_without_scope(self):
- key = 'test'
+ key = "test"
scope = None
result = kv_utl._derive_scope_and_key(key, scope)
- self.assertEqual(
- (FULL_USER_SCOPE, 'None:%s' % key),
- result
- )
+ self.assertEqual((FULL_USER_SCOPE, "None:%s" % key), result)
def test_derive_scope_and_key_system_key(self):
- key = 'system.test'
+ key = "system.test"
scope = None
result = kv_utl._derive_scope_and_key(key, scope)
- self.assertEqual(
- (FULL_SYSTEM_SCOPE, key.split('.')[1]),
- result
- )
+ self.assertEqual((FULL_SYSTEM_SCOPE, key.split(".")[1]), result)
- @mock.patch('st2common.util.keyvalue.KeyValuePair')
- @mock.patch('st2common.util.keyvalue.deserialize_key_value')
+ @mock.patch("st2common.util.keyvalue.KeyValuePair")
+ @mock.patch("st2common.util.keyvalue.deserialize_key_value")
def test_get_key(self, deseralize_key_value, KeyValuePair):
- key, value = ('Lindsay', 'Lohan')
+ key, value = ("Lindsay", "Lohan")
decrypt = False
KeyValuePair.get_by_scope_and_name().value = value
deseralize_key_value.return_value = value
- result = kv_utl.get_key(key=key, user_db=auth_db.UserDB(name=USER), decrypt=decrypt)
+ result = kv_utl.get_key(
+ key=key, user_db=auth_db.UserDB(name=USER), decrypt=decrypt
+ )
self.assertEqual(result, value)
KeyValuePair.get_by_scope_and_name.assert_called_with(
- FULL_USER_SCOPE,
- 'stanley:%s' % key
- )
- deseralize_key_value.assert_called_once_with(
- value,
- decrypt
+ FULL_USER_SCOPE, "stanley:%s" % key
)
+ deseralize_key_value.assert_called_once_with(value, decrypt)
def test_get_key_invalid_input(self):
- self.assertRaises(
- TypeError,
- kv_utl.get_key,
- key=1
- )
- self.assertRaises(
- TypeError,
- kv_utl.get_key,
- key='test',
- decrypt='yep'
- )
+ self.assertRaises(TypeError, kv_utl.get_key, key=1)
+ self.assertRaises(TypeError, kv_utl.get_key, key="test", decrypt="yep")
diff --git a/st2common/tests/unit/test_util_output_schema.py b/st2common/tests/unit/test_util_output_schema.py
index d3ef387a26..af9570d4fa 100644
--- a/st2common/tests/unit/test_util_output_schema.py
+++ b/st2common/tests/unit/test_util_output_schema.py
@@ -19,58 +19,46 @@
from st2common.constants.action import (
LIVEACTION_STATUS_SUCCEEDED,
- LIVEACTION_STATUS_FAILED
+ LIVEACTION_STATUS_FAILED,
)
ACTION_RESULT = {
- 'output': {
- 'output_1': 'Bobby',
- 'output_2': 5,
- 'deep_output': {
- 'deep_item_1': 'Jindal',
+ "output": {
+ "output_1": "Bobby",
+ "output_2": 5,
+ "deep_output": {
+ "deep_item_1": "Jindal",
},
}
}
RUNNER_SCHEMA = {
- 'output': {
- 'type': 'object'
- },
- 'error': {
- 'type': 'array'
- },
+ "output": {"type": "object"},
+ "error": {"type": "array"},
}
ACTION_SCHEMA = {
- 'output_1': {
- 'type': 'string'
- },
- 'output_2': {
- 'type': 'integer'
- },
- 'deep_output': {
- 'type': 'object',
- 'parameters': {
- 'deep_item_1': {
- 'type': 'string',
+ "output_1": {"type": "string"},
+ "output_2": {"type": "integer"},
+ "deep_output": {
+ "type": "object",
+ "parameters": {
+ "deep_item_1": {
+ "type": "string",
},
},
},
}
RUNNER_SCHEMA_FAIL = {
- 'not_a_key_you_have': {
- 'type': 'string'
- },
+ "not_a_key_you_have": {"type": "string"},
}
ACTION_SCHEMA_FAIL = {
- 'not_a_key_you_have': {
- 'type': 'string'
- },
+ "not_a_key_you_have": {"type": "string"},
}
-OUTPUT_KEY = 'output'
+OUTPUT_KEY = "output"
class OutputSchemaTestCase(unittest2.TestCase):
@@ -96,7 +84,7 @@ def test_invalid_runner_schema(self):
)
expected_result = {
- 'error': (
+ "error": (
"Additional properties are not allowed ('output' was unexpected)"
"\n\nFailed validating 'additionalProperties' in schema:\n {'addi"
"tionalProperties': False,\n 'properties': {'not_a_key_you_have': "
@@ -104,7 +92,7 @@ def test_invalid_runner_schema(self):
"output': {'deep_output': {'deep_item_1': 'Jindal'},\n "
"'output_1': 'Bobby',\n 'output_2': 5}}"
),
- 'message': 'Error validating output. See error output for more details.'
+ "message": "Error validating output. See error output for more details.",
}
self.assertEqual(result, expected_result)
@@ -120,12 +108,12 @@ def test_invalid_action_schema(self):
)
expected_result = {
- 'error': "Additional properties are not allowed",
- 'message': u'Error validating output. See error output for more details.'
+ "error": "Additional properties are not allowed",
+ "message": "Error validating output. See error output for more details.",
}
# To avoid random failures (especially in python3) this assert can't be
# exact since the parameters can be ordered differently per execution.
- self.assertIn(expected_result['error'], result['error'])
- self.assertEqual(result['message'], expected_result['message'])
+ self.assertIn(expected_result["error"], result["error"])
+ self.assertEqual(result["message"], expected_result["message"])
self.assertEqual(status, LIVEACTION_STATUS_FAILED)
diff --git a/st2common/tests/unit/test_util_pack.py b/st2common/tests/unit/test_util_pack.py
index 20522b8b18..0b476b7336 100644
--- a/st2common/tests/unit/test_util_pack.py
+++ b/st2common/tests/unit/test_util_pack.py
@@ -22,59 +22,47 @@
class PackUtilsTestCase(unittest2.TestCase):
-
def test_get_pack_common_libs_path_for_pack_db(self):
pack_model_args = {
- 'name': 'Yolo CI',
- 'ref': 'yolo_ci',
- 'description': 'YOLO CI pack',
- 'version': '0.1.0',
- 'author': 'Volkswagen',
- 'path': '/opt/stackstorm/packs/yolo_ci/'
+ "name": "Yolo CI",
+ "ref": "yolo_ci",
+ "description": "YOLO CI pack",
+ "version": "0.1.0",
+ "author": "Volkswagen",
+ "path": "/opt/stackstorm/packs/yolo_ci/",
}
pack_db = PackDB(**pack_model_args)
lib_path = get_pack_common_libs_path_for_pack_db(pack_db)
- self.assertEqual('/opt/stackstorm/packs/yolo_ci/lib', lib_path)
+ self.assertEqual("/opt/stackstorm/packs/yolo_ci/lib", lib_path)
def test_get_pack_common_libs_path_for_pack_db_no_path_in_pack_db(self):
pack_model_args = {
- 'name': 'Yolo CI',
- 'ref': 'yolo_ci',
- 'description': 'YOLO CI pack',
- 'version': '0.1.0',
- 'author': 'Volkswagen'
+ "name": "Yolo CI",
+ "ref": "yolo_ci",
+ "description": "YOLO CI pack",
+ "version": "0.1.0",
+ "author": "Volkswagen",
}
pack_db = PackDB(**pack_model_args)
lib_path = get_pack_common_libs_path_for_pack_db(pack_db)
self.assertEqual(None, lib_path)
def test_get_pack_warnings_python2_only(self):
- pack_metadata = {
- 'python_versions': ['2'],
- 'name': 'Pack2'
- }
+ pack_metadata = {"python_versions": ["2"], "name": "Pack2"}
warning = get_pack_warnings(pack_metadata)
self.assertTrue("DEPRECATION WARNING" in warning)
def test_get_pack_warnings_python3_only(self):
- pack_metadata = {
- 'python_versions': ['3'],
- 'name': 'Pack3'
- }
+ pack_metadata = {"python_versions": ["3"], "name": "Pack3"}
warning = get_pack_warnings(pack_metadata)
self.assertEqual(None, warning)
def test_get_pack_warnings_python2_and_3(self):
- pack_metadata = {
- 'python_versions': ['2', '3'],
- 'name': 'Pack23'
- }
+ pack_metadata = {"python_versions": ["2", "3"], "name": "Pack23"}
warning = get_pack_warnings(pack_metadata)
self.assertEqual(None, warning)
def test_get_pack_warnings_no_python(self):
- pack_metadata = {
- 'name': 'PackNone'
- }
+ pack_metadata = {"name": "PackNone"}
warning = get_pack_warnings(pack_metadata)
self.assertEqual(None, warning)
diff --git a/st2common/tests/unit/test_util_payload.py b/st2common/tests/unit/test_util_payload.py
index 207d4c1766..2621e3de91 100644
--- a/st2common/tests/unit/test_util_payload.py
+++ b/st2common/tests/unit/test_util_payload.py
@@ -19,27 +19,31 @@
from st2common.util.payload import PayloadLookup
-__all__ = [
- 'PayloadLookupTestCase'
-]
+__all__ = ["PayloadLookupTestCase"]
class PayloadLookupTestCase(unittest2.TestCase):
@classmethod
def setUpClass(cls):
- cls.payload = PayloadLookup({
- 'pikachu': "Has no ears",
- 'charmander': "Plays with fire",
- })
+ cls.payload = PayloadLookup(
+ {
+ "pikachu": "Has no ears",
+ "charmander": "Plays with fire",
+ }
+ )
super(PayloadLookupTestCase, cls).setUpClass()
def test_get_key(self):
- self.assertEqual(self.payload.get_value('trigger.pikachu'), ["Has no ears"])
- self.assertEqual(self.payload.get_value('trigger.charmander'), ["Plays with fire"])
+ self.assertEqual(self.payload.get_value("trigger.pikachu"), ["Has no ears"])
+ self.assertEqual(
+ self.payload.get_value("trigger.charmander"), ["Plays with fire"]
+ )
def test_explicitly_get_multiple_keys(self):
- self.assertEqual(self.payload.get_value('trigger.pikachu[*]'), ["Has no ears"])
- self.assertEqual(self.payload.get_value('trigger.charmander[*]'), ["Plays with fire"])
+ self.assertEqual(self.payload.get_value("trigger.pikachu[*]"), ["Has no ears"])
+ self.assertEqual(
+ self.payload.get_value("trigger.charmander[*]"), ["Plays with fire"]
+ )
def test_get_nonexistent_key(self):
- self.assertIsNone(self.payload.get_value('trigger.squirtle'))
+ self.assertIsNone(self.payload.get_value("trigger.squirtle"))
diff --git a/st2common/tests/unit/test_util_sandboxing.py b/st2common/tests/unit/test_util_sandboxing.py
index 5f387e0067..3926c9f74c 100644
--- a/st2common/tests/unit/test_util_sandboxing.py
+++ b/st2common/tests/unit/test_util_sandboxing.py
@@ -32,9 +32,7 @@
import st2tests.config as tests_config
-__all__ = [
- 'SandboxingUtilsTestCase'
-]
+__all__ = ["SandboxingUtilsTestCase"]
class SandboxingUtilsTestCase(unittest.TestCase):
@@ -69,8 +67,10 @@ def assertEndsWith(self, string, ending_substr, msg=None):
def test_get_sandbox_python_binary_path(self):
# Non-system content pack, should use pack specific virtualenv binary
- result = get_sandbox_python_binary_path(pack='mapack')
- expected = os.path.join(cfg.CONF.system.base_path, 'virtualenvs/mapack/bin/python')
+ result = get_sandbox_python_binary_path(pack="mapack")
+ expected = os.path.join(
+ cfg.CONF.system.base_path, "virtualenvs/mapack/bin/python"
+ )
self.assertEqual(result, expected)
# System content pack, should use current process (system) python binary
@@ -78,159 +78,190 @@ def test_get_sandbox_python_binary_path(self):
self.assertEqual(result, sys.executable)
def test_get_sandbox_path(self):
- virtualenv_path = '/home/venv/test'
+ virtualenv_path = "/home/venv/test"
# Mock the current PATH value
- with mock.patch.dict(os.environ, {'PATH': '/home/path1:/home/path2:/home/path3:'}):
+ with mock.patch.dict(
+ os.environ, {"PATH": "/home/path1:/home/path2:/home/path3:"}
+ ):
result = get_sandbox_path(virtualenv_path=virtualenv_path)
- self.assertEqual(result, f'{virtualenv_path}/bin/:/home/path1:/home/path2:/home/path3')
+ self.assertEqual(
+ result, f"{virtualenv_path}/bin/:/home/path1:/home/path2:/home/path3"
+ )
- @mock.patch('st2common.util.sandboxing.get_python_lib')
+ @mock.patch("st2common.util.sandboxing.get_python_lib")
def test_get_sandbox_python_path(self, mock_get_python_lib):
# No inheritance
- python_path = get_sandbox_python_path(inherit_from_parent=False,
- inherit_parent_virtualenv=False)
- self.assertEqual(python_path, ':')
+ python_path = get_sandbox_python_path(
+ inherit_from_parent=False, inherit_parent_virtualenv=False
+ )
+ self.assertEqual(python_path, ":")
# Inherit python path from current process
# Mock the current process python path
- with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}):
- python_path = get_sandbox_python_path(inherit_from_parent=True,
- inherit_parent_virtualenv=False)
+ with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}):
+ python_path = get_sandbox_python_path(
+ inherit_from_parent=True, inherit_parent_virtualenv=False
+ )
- self.assertEqual(python_path, ':/data/test1:/data/test2')
+ self.assertEqual(python_path, ":/data/test1:/data/test2")
# Inherit from current process and from virtualenv (not running inside virtualenv)
clear_virtualenv_prefix()
- with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}):
- python_path = get_sandbox_python_path(inherit_from_parent=True,
- inherit_parent_virtualenv=False)
+ with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}):
+ python_path = get_sandbox_python_path(
+ inherit_from_parent=True, inherit_parent_virtualenv=False
+ )
- self.assertEqual(python_path, ':/data/test1:/data/test2')
+ self.assertEqual(python_path, ":/data/test1:/data/test2")
# Inherit from current process and from virtualenv (running inside virtualenv)
- sys.real_prefix = '/usr'
- mock_get_python_lib.return_value = f'{sys.prefix}/virtualenvtest'
+ sys.real_prefix = "/usr"
+ mock_get_python_lib.return_value = f"{sys.prefix}/virtualenvtest"
- with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}):
- python_path = get_sandbox_python_path(inherit_from_parent=True,
- inherit_parent_virtualenv=True)
+ with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}):
+ python_path = get_sandbox_python_path(
+ inherit_from_parent=True, inherit_parent_virtualenv=True
+ )
- self.assertEqual(python_path, f':/data/test1:/data/test2:{sys.prefix}/virtualenvtest')
+ self.assertEqual(
+ python_path, f":/data/test1:/data/test2:{sys.prefix}/virtualenvtest"
+ )
- @mock.patch('os.path.isdir', mock.Mock(return_value=True))
- @mock.patch('os.listdir', mock.Mock(return_value=['python3.6']))
- @mock.patch('st2common.util.sandboxing.get_python_lib')
- def test_get_sandbox_python_path_for_python_action_no_inheritance(self,
- mock_get_python_lib):
+ @mock.patch("os.path.isdir", mock.Mock(return_value=True))
+ @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"]))
+ @mock.patch("st2common.util.sandboxing.get_python_lib")
+ def test_get_sandbox_python_path_for_python_action_no_inheritance(
+ self, mock_get_python_lib
+ ):
# No inheritance
- python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack',
- inherit_from_parent=False,
- inherit_parent_virtualenv=False)
+ python_path = get_sandbox_python_path_for_python_action(
+ pack="dummy_pack",
+ inherit_from_parent=False,
+ inherit_parent_virtualenv=False,
+ )
- actual_path = python_path.strip(':').split(':')
+ actual_path = python_path.strip(":").split(":")
self.assertEqual(len(actual_path), 3)
# First entry should be lib/python3 dir from venv
- self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6')
+ self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6")
# Second entry should be python3 site-packages dir from venv
- self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages')
+ self.assertEndsWith(
+ actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages"
+ )
# Third entry should be actions/lib dir from pack root directory
- self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib')
+ self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib")
- @mock.patch('os.path.isdir', mock.Mock(return_value=True))
- @mock.patch('os.listdir', mock.Mock(return_value=['python3.6']))
- @mock.patch('st2common.util.sandboxing.get_python_lib')
- def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_only(self,
- mock_get_python_lib):
+ @mock.patch("os.path.isdir", mock.Mock(return_value=True))
+ @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"]))
+ @mock.patch("st2common.util.sandboxing.get_python_lib")
+ def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_only(
+ self, mock_get_python_lib
+ ):
# Inherit python path from current process
# Mock the current process python path
- with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}):
- python_path = get_sandbox_python_path(inherit_from_parent=True,
- inherit_parent_virtualenv=False)
+ with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}):
+ python_path = get_sandbox_python_path(
+ inherit_from_parent=True, inherit_parent_virtualenv=False
+ )
- self.assertEqual(python_path, ':/data/test1:/data/test2')
+ self.assertEqual(python_path, ":/data/test1:/data/test2")
- python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack',
- inherit_from_parent=True,
- inherit_parent_virtualenv=False)
+ python_path = get_sandbox_python_path_for_python_action(
+ pack="dummy_pack",
+ inherit_from_parent=True,
+ inherit_parent_virtualenv=False,
+ )
- actual_path = python_path.strip(':').split(':')
+ actual_path = python_path.strip(":").split(":")
self.assertEqual(len(actual_path), 6)
# First entry should be lib/python3 dir from venv
- self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6')
+ self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6")
# Second entry should be python3 site-packages dir from venv
- self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages')
+ self.assertEndsWith(
+ actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages"
+ )
# Third entry should be actions/lib dir from pack root directory
- self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib')
+ self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib")
# And the rest of the paths from get_sandbox_python_path
- self.assertEqual(actual_path[3], '')
- self.assertEqual(actual_path[4], '/data/test1')
- self.assertEqual(actual_path[5], '/data/test2')
+ self.assertEqual(actual_path[3], "")
+ self.assertEqual(actual_path[4], "/data/test1")
+ self.assertEqual(actual_path[5], "/data/test2")
- @mock.patch('os.path.isdir', mock.Mock(return_value=True))
- @mock.patch('os.listdir', mock.Mock(return_value=['python3.6']))
- @mock.patch('st2common.util.sandboxing.get_python_lib')
- def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_and_venv(self,
- mock_get_python_lib):
+ @mock.patch("os.path.isdir", mock.Mock(return_value=True))
+ @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"]))
+ @mock.patch("st2common.util.sandboxing.get_python_lib")
+ def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_and_venv(
+ self, mock_get_python_lib
+ ):
# Inherit from current process and from virtualenv (not running inside virtualenv)
clear_virtualenv_prefix()
# Inherit python path from current process
# Mock the current process python path
- with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}):
- python_path = get_sandbox_python_path(inherit_from_parent=True,
- inherit_parent_virtualenv=False)
+ with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}):
+ python_path = get_sandbox_python_path(
+ inherit_from_parent=True, inherit_parent_virtualenv=False
+ )
- self.assertEqual(python_path, ':/data/test1:/data/test2')
+ self.assertEqual(python_path, ":/data/test1:/data/test2")
- python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack',
- inherit_from_parent=True,
- inherit_parent_virtualenv=True)
+ python_path = get_sandbox_python_path_for_python_action(
+ pack="dummy_pack",
+ inherit_from_parent=True,
+ inherit_parent_virtualenv=True,
+ )
- actual_path = python_path.strip(':').split(':')
+ actual_path = python_path.strip(":").split(":")
self.assertEqual(len(actual_path), 6)
# First entry should be lib/python3 dir from venv
- self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6')
+ self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6")
# Second entry should be python3 site-packages dir from venv
- self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages')
+ self.assertEndsWith(
+ actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages"
+ )
# Third entry should be actions/lib dir from pack root directory
- self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib')
+ self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib")
# And the rest of the paths from get_sandbox_python_path
- self.assertEqual(actual_path[3], '')
- self.assertEqual(actual_path[4], '/data/test1')
- self.assertEqual(actual_path[5], '/data/test2')
+ self.assertEqual(actual_path[3], "")
+ self.assertEqual(actual_path[4], "/data/test1")
+ self.assertEqual(actual_path[5], "/data/test2")
# Inherit from current process and from virtualenv (running inside virtualenv)
- sys.real_prefix = '/usr'
- mock_get_python_lib.return_value = f'{sys.prefix}/virtualenvtest'
+ sys.real_prefix = "/usr"
+ mock_get_python_lib.return_value = f"{sys.prefix}/virtualenvtest"
# Inherit python path from current process
# Mock the current process python path
- with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}):
- python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack',
- inherit_from_parent=True,
- inherit_parent_virtualenv=True)
-
- actual_path = python_path.strip(':').split(':')
+ with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}):
+ python_path = get_sandbox_python_path_for_python_action(
+ pack="dummy_pack",
+ inherit_from_parent=True,
+ inherit_parent_virtualenv=True,
+ )
+
+ actual_path = python_path.strip(":").split(":")
self.assertEqual(len(actual_path), 7)
# First entry should be lib/python3 dir from venv
- self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6')
+ self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6")
# Second entry should be python3 site-packages dir from venv
- self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages')
+ self.assertEndsWith(
+ actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages"
+ )
# Third entry should be actions/lib dir from pack root directory
- self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib')
+ self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib")
# The paths from get_sandbox_python_path
- self.assertEqual(actual_path[3], '')
- self.assertEqual(actual_path[4], '/data/test1')
- self.assertEqual(actual_path[5], '/data/test2')
+ self.assertEqual(actual_path[3], "")
+ self.assertEqual(actual_path[4], "/data/test1")
+ self.assertEqual(actual_path[5], "/data/test2")
# And the parent virtualenv
- self.assertEqual(actual_path[6], f'{sys.prefix}/virtualenvtest')
+ self.assertEqual(actual_path[6], f"{sys.prefix}/virtualenvtest")
diff --git a/st2common/tests/unit/test_util_secrets.py b/st2common/tests/unit/test_util_secrets.py
index f49f8f76a9..8c77c34f49 100644
--- a/st2common/tests/unit/test_util_secrets.py
+++ b/st2common/tests/unit/test_util_secrets.py
@@ -22,38 +22,30 @@
################################################################################
TEST_FLAT_SCHEMA = {
- 'arg_required_no_default': {
- 'description': 'Foo',
- 'required': True,
- 'type': 'string',
- 'secret': False
+ "arg_required_no_default": {
+ "description": "Foo",
+ "required": True,
+ "type": "string",
+ "secret": False,
},
- 'arg_optional_no_type_secret': {
- 'description': 'Bar',
- 'secret': True
- },
- 'arg_optional_type_array': {
- 'description': 'Who''s the fairest?',
- 'type': 'array'
- },
- 'arg_optional_type_object': {
- 'description': 'Who''s the fairest of them?',
- 'type': 'object'
+ "arg_optional_no_type_secret": {"description": "Bar", "secret": True},
+ "arg_optional_type_array": {"description": "Who" "s the fairest?", "type": "array"},
+ "arg_optional_type_object": {
+ "description": "Who" "s the fairest of them?",
+ "type": "object",
},
}
-TEST_FLAT_SECRET_PARAMS = {
- 'arg_optional_no_type_secret': None
-}
+TEST_FLAT_SECRET_PARAMS = {"arg_optional_no_type_secret": None}
################################################################################
TEST_NO_SECRETS_SCHEMA = {
- 'arg_required_no_default': {
- 'description': 'Foo',
- 'required': True,
- 'type': 'string',
- 'secret': False
+ "arg_required_no_default": {
+ "description": "Foo",
+ "required": True,
+ "type": "string",
+ "secret": False,
}
}
@@ -62,497 +54,397 @@
################################################################################
TEST_NESTED_OBJECTS_SCHEMA = {
- 'arg_string': {
- 'description': 'Junk',
- 'type': 'string',
+ "arg_string": {
+ "description": "Junk",
+ "type": "string",
},
- 'arg_optional_object': {
- 'description': 'Mirror',
- 'type': 'object',
- 'properties': {
- 'arg_nested_object': {
- 'description': 'Mirror mirror',
- 'type': 'object',
- 'properties': {
- 'arg_double_nested_secret': {
- 'description': 'Deep, deep down',
- 'type': 'string',
- 'secret': True
+ "arg_optional_object": {
+ "description": "Mirror",
+ "type": "object",
+ "properties": {
+ "arg_nested_object": {
+ "description": "Mirror mirror",
+ "type": "object",
+ "properties": {
+ "arg_double_nested_secret": {
+ "description": "Deep, deep down",
+ "type": "string",
+ "secret": True,
}
- }
+ },
},
- 'arg_nested_secret': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
- }
- }
+ "arg_nested_secret": {
+ "description": "Deep down",
+ "type": "string",
+ "secret": True,
+ },
+ },
+ },
}
TEST_NESTED_OBJECTS_SECRET_PARAMS = {
- 'arg_optional_object': {
- 'arg_nested_secret': 'string',
- 'arg_nested_object': {
- 'arg_double_nested_secret': 'string',
- }
+ "arg_optional_object": {
+ "arg_nested_secret": "string",
+ "arg_nested_object": {
+ "arg_double_nested_secret": "string",
+ },
}
}
################################################################################
TEST_ARRAY_SCHEMA = {
- 'arg_optional_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'description': 'down',
- 'type': 'string',
- 'secret': True
- }
+ "arg_optional_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {"description": "down", "type": "string", "secret": True},
}
}
-TEST_ARRAY_SECRET_PARAMS = {
- 'arg_optional_array': [
- 'string'
- ]
-}
+TEST_ARRAY_SECRET_PARAMS = {"arg_optional_array": ["string"]}
################################################################################
TEST_ROOT_ARRAY_SCHEMA = {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'description': 'down',
- 'type': 'object',
- 'properties': {
- 'secret_field_in_object': {
- 'type': 'string',
- 'secret': True
- }
- }
- }
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "description": "down",
+ "type": "object",
+ "properties": {"secret_field_in_object": {"type": "string", "secret": True}},
+ },
}
-TEST_ROOT_ARRAY_SECRET_PARAMS = [
- {
- 'secret_field_in_object': 'string'
- }
-]
+TEST_ROOT_ARRAY_SECRET_PARAMS = [{"secret_field_in_object": "string"}]
################################################################################
TEST_ROOT_OBJECT_SCHEMA = {
- 'description': 'root',
- 'type': 'object',
- 'properties': {
- 'arg_level_one': {
- 'description': 'down',
- 'type': 'object',
- 'properties': {
- 'secret_field_in_object': {
- 'type': 'string',
- 'secret': True
- }
- }
+ "description": "root",
+ "type": "object",
+ "properties": {
+ "arg_level_one": {
+ "description": "down",
+ "type": "object",
+ "properties": {
+ "secret_field_in_object": {"type": "string", "secret": True}
+ },
}
- }
+ },
}
-TEST_ROOT_OBJECT_SECRET_PARAMS = {
- 'arg_level_one':
- {
- 'secret_field_in_object': 'string'
- }
-}
+TEST_ROOT_OBJECT_SECRET_PARAMS = {"arg_level_one": {"secret_field_in_object": "string"}}
################################################################################
TEST_NESTED_ARRAYS_SCHEMA = {
- 'arg_optional_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
+ "arg_optional_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {"description": "Deep down", "type": "string", "secret": True},
},
- 'arg_optional_double_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
- }
+ "arg_optional_double_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {"description": "Deep down", "type": "string", "secret": True},
+ },
},
- 'arg_optional_tripple_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
- }
- }
+ "arg_optional_tripple_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {"description": "Deep down", "type": "string", "secret": True},
+ },
+ },
+ },
+ "arg_optional_quad_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {
+ "description": "Deep down",
+ "type": "string",
+ "secret": True,
+ },
+ },
+ },
+ },
},
- 'arg_optional_quad_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
- }
- }
- }
- }
}
TEST_NESTED_ARRAYS_SECRET_PARAMS = {
- 'arg_optional_array': [
- 'string'
- ],
- 'arg_optional_double_array': [
- [
- 'string'
- ]
- ],
- 'arg_optional_tripple_array': [
- [
- [
- 'string'
- ]
- ]
- ],
- 'arg_optional_quad_array': [
- [
- [
- [
- 'string'
- ]
- ]
- ]
- ]
+ "arg_optional_array": ["string"],
+ "arg_optional_double_array": [["string"]],
+ "arg_optional_tripple_array": [[["string"]]],
+ "arg_optional_quad_array": [[[["string"]]]],
}
################################################################################
TEST_NESTED_OBJECT_WITH_ARRAY_SCHEMA = {
- 'arg_optional_object_with_array': {
- 'description': 'Mirror',
- 'type': 'object',
- 'properties': {
- 'arg_nested_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
+ "arg_optional_object_with_array": {
+ "description": "Mirror",
+ "type": "object",
+ "properties": {
+ "arg_nested_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {"description": "Deep down", "type": "string", "secret": True},
}
- }
+ },
}
}
TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS = {
- 'arg_optional_object_with_array': {
- 'arg_nested_array': [
- 'string'
- ]
- }
+ "arg_optional_object_with_array": {"arg_nested_array": ["string"]}
}
################################################################################
TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA = {
- 'arg_optional_object_with_double_array': {
- 'description': 'Mirror',
- 'type': 'object',
- 'properties': {
- 'arg_double_nested_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
- }
+ "arg_optional_object_with_double_array": {
+ "description": "Mirror",
+ "type": "object",
+ "properties": {
+ "arg_double_nested_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "description": "Deep down",
+ "type": "string",
+ "secret": True,
+ },
+ },
}
- }
+ },
}
}
TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS = {
- 'arg_optional_object_with_double_array': {
- 'arg_double_nested_array': [
- [
- 'string'
- ]
- ]
- }
+ "arg_optional_object_with_double_array": {"arg_double_nested_array": [["string"]]}
}
################################################################################
TEST_NESTED_ARRAY_WITH_OBJECT_SCHEMA = {
- 'arg_optional_array_with_object': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'description': 'Mirror',
- 'type': 'object',
- 'properties': {
- 'arg_nested_secret': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
+ "arg_optional_array_with_object": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "description": "Mirror",
+ "type": "object",
+ "properties": {
+ "arg_nested_secret": {
+ "description": "Deep down",
+ "type": "string",
+ "secret": True,
}
- }
- }
+ },
+ },
}
}
TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS = {
- 'arg_optional_array_with_object': [
- {
- 'arg_nested_secret': 'string'
- }
- ]
+ "arg_optional_array_with_object": [{"arg_nested_secret": "string"}]
}
################################################################################
TEST_SECRET_ARRAY_SCHEMA = {
- 'arg_secret_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'secret': True,
+ "arg_secret_array": {
+ "description": "Mirror",
+ "type": "array",
+ "secret": True,
}
}
-TEST_SECRET_ARRAY_SECRET_PARAMS = {
- 'arg_secret_array': 'array'
-}
+TEST_SECRET_ARRAY_SECRET_PARAMS = {"arg_secret_array": "array"}
################################################################################
TEST_SECRET_OBJECT_SCHEMA = {
- 'arg_secret_object': {
- 'type': 'object',
- 'secret': True,
+ "arg_secret_object": {
+ "type": "object",
+ "secret": True,
}
}
-TEST_SECRET_OBJECT_SECRET_PARAMS = {
- 'arg_secret_object': 'object'
-}
+TEST_SECRET_OBJECT_SECRET_PARAMS = {"arg_secret_object": "object"}
################################################################################
TEST_SECRET_ROOT_ARRAY_SCHEMA = {
- 'description': 'secret array',
- 'type': 'array',
- 'secret': True,
- 'items': {
- 'description': 'down',
- 'type': 'object',
- 'properties': {
- 'secret_field_in_object': {
- 'type': 'string',
- 'secret': True
- }
- }
- }
+ "description": "secret array",
+ "type": "array",
+ "secret": True,
+ "items": {
+ "description": "down",
+ "type": "object",
+ "properties": {"secret_field_in_object": {"type": "string", "secret": True}},
+ },
}
-TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS = 'array'
+TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS = "array"
################################################################################
TEST_SECRET_ROOT_OBJECT_SCHEMA = {
- 'description': 'secret object',
- 'type': 'object',
- 'secret': True,
- 'proeprteis': {
- 'arg_level_one': {
- 'description': 'down',
- 'type': 'object',
- 'properties': {
- 'secret_field_in_object': {
- 'type': 'string',
- 'secret': True
- }
- }
+ "description": "secret object",
+ "type": "object",
+ "secret": True,
+ "proeprteis": {
+ "arg_level_one": {
+ "description": "down",
+ "type": "object",
+ "properties": {
+ "secret_field_in_object": {"type": "string", "secret": True}
+ },
}
- }
+ },
}
-TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS = 'object'
+TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS = "object"
################################################################################
TEST_SECRET_NESTED_OBJECTS_SCHEMA = {
- 'arg_object': {
- 'description': 'Mirror',
- 'type': 'object',
- 'properties': {
- 'arg_nested_object': {
- 'description': 'Mirror mirror',
- 'type': 'object',
- 'secret': True,
- 'properties': {
- 'arg_double_nested_secret': {
- 'description': 'Deep, deep down',
- 'type': 'string',
- 'secret': True
+ "arg_object": {
+ "description": "Mirror",
+ "type": "object",
+ "properties": {
+ "arg_nested_object": {
+ "description": "Mirror mirror",
+ "type": "object",
+ "secret": True,
+ "properties": {
+ "arg_double_nested_secret": {
+ "description": "Deep, deep down",
+ "type": "string",
+ "secret": True,
}
- }
+ },
},
- 'arg_nested_secret': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
- }
+ "arg_nested_secret": {
+ "description": "Deep down",
+ "type": "string",
+ "secret": True,
+ },
+ },
},
- 'arg_secret_object': {
- 'description': 'Mirror',
- 'type': 'object',
- 'secret': True,
- 'properties': {
- 'arg_nested_object': {
- 'description': 'Mirror mirror',
- 'type': 'object',
- 'secret': True,
- 'properties': {
- 'arg_double_nested_secret': {
- 'description': 'Deep, deep down',
- 'type': 'string',
- 'secret': True
+ "arg_secret_object": {
+ "description": "Mirror",
+ "type": "object",
+ "secret": True,
+ "properties": {
+ "arg_nested_object": {
+ "description": "Mirror mirror",
+ "type": "object",
+ "secret": True,
+ "properties": {
+ "arg_double_nested_secret": {
+ "description": "Deep, deep down",
+ "type": "string",
+ "secret": True,
}
- }
+ },
},
- 'arg_nested_secret': {
- 'description': 'Deep down',
- 'type': 'string',
- 'secret': True
- }
- }
- }
+ "arg_nested_secret": {
+ "description": "Deep down",
+ "type": "string",
+ "secret": True,
+ },
+ },
+ },
}
TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS = {
- 'arg_object': {
- 'arg_nested_secret': 'string',
- 'arg_nested_object': 'object'
- },
- 'arg_secret_object': 'object'
+ "arg_object": {"arg_nested_secret": "string", "arg_nested_object": "object"},
+ "arg_secret_object": "object",
}
################################################################################
TEST_SECRET_NESTED_ARRAYS_SCHEMA = {
- 'arg_optional_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'secret': True,
- 'items': {
- 'description': 'Deep down',
- 'type': 'string'
- }
+ "arg_optional_array": {
+ "description": "Mirror",
+ "type": "array",
+ "secret": True,
+ "items": {"description": "Deep down", "type": "string"},
},
- 'arg_optional_double_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'secret': True,
- 'items': {
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- }
- }
+ "arg_optional_double_array": {
+ "description": "Mirror",
+ "type": "array",
+ "secret": True,
+ "items": {
+ "type": "array",
+ "items": {
+ "description": "Deep down",
+ "type": "string",
+ },
+ },
},
- 'arg_optional_tripple_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'secret': True,
- 'items': {
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- }
- }
- }
+ "arg_optional_tripple_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "type": "array",
+ "secret": True,
+ "items": {
+ "type": "array",
+ "items": {
+ "description": "Deep down",
+ "type": "string",
+ },
+ },
+ },
+ },
+ "arg_optional_quad_array": {
+ "description": "Mirror",
+ "type": "array",
+ "items": {
+ "type": "array",
+ "items": {
+ "type": "array",
+ "secret": True,
+ "items": {
+ "type": "array",
+ "items": {
+ "description": "Deep down",
+ "type": "string",
+ },
+ },
+ },
+ },
},
- 'arg_optional_quad_array': {
- 'description': 'Mirror',
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'items': {
- 'type': 'array',
- 'secret': True,
- 'items': {
- 'type': 'array',
- 'items': {
- 'description': 'Deep down',
- 'type': 'string',
- }
- }
- }
- }
- }
}
TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS = {
- 'arg_optional_array': 'array',
- 'arg_optional_double_array': 'array',
- 'arg_optional_tripple_array': [
- 'array'
- ],
- 'arg_optional_quad_array': [
- [
- 'array'
- ]
- ]
+ "arg_optional_array": "array",
+ "arg_optional_double_array": "array",
+ "arg_optional_tripple_array": ["array"],
+ "arg_optional_quad_array": [["array"]],
}
################################################################################
class SecretUtilsTestCase(unittest2.TestCase):
-
def test_get_secret_parameters_flat(self):
result = secrets.get_secret_parameters(TEST_FLAT_SCHEMA)
self.assertEqual(TEST_FLAT_SECRET_PARAMS, result)
@@ -586,7 +478,9 @@ def test_get_secret_parameters_nested_object_with_array(self):
self.assertEqual(TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS, result)
def test_get_secret_parameters_nested_object_with_double_array(self):
- result = secrets.get_secret_parameters(TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA)
+ result = secrets.get_secret_parameters(
+ TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA
+ )
self.assertEqual(TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS, result)
def test_get_secret_parameters_nested_array_with_object(self):
@@ -621,178 +515,128 @@ def test_get_secret_parameters_secret_nested_objects(self):
def test_mask_secret_parameters_flat(self):
parameters = {
- 'arg_required_no_default': 'test',
- 'arg_optional_no_type_secret': None
+ "arg_required_no_default": "test",
+ "arg_optional_no_type_secret": None,
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_FLAT_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(parameters, TEST_FLAT_SECRET_PARAMS)
expected = {
- 'arg_required_no_default': 'test',
- 'arg_optional_no_type_secret': MASKED_ATTRIBUTE_VALUE
+ "arg_required_no_default": "test",
+ "arg_optional_no_type_secret": MASKED_ATTRIBUTE_VALUE,
}
self.assertEqual(expected, result)
def test_mask_secret_parameters_no_secrets(self):
- parameters = {'arg_required_no_default': 'junk'}
- result = secrets.mask_secret_parameters(parameters,
- TEST_NO_SECRETS_SECRET_PARAMS)
- expected = {
- 'arg_required_no_default': 'junk'
- }
+ parameters = {"arg_required_no_default": "junk"}
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_NO_SECRETS_SECRET_PARAMS
+ )
+ expected = {"arg_required_no_default": "junk"}
self.assertEqual(expected, result)
def test_mask_secret_parameters_nested_objects(self):
parameters = {
- 'arg_optional_object': {
- 'arg_nested_secret': 'nested Secret',
- 'arg_nested_object': {
- 'arg_double_nested_secret': 'double nested $ecret',
- }
+ "arg_optional_object": {
+ "arg_nested_secret": "nested Secret",
+ "arg_nested_object": {
+ "arg_double_nested_secret": "double nested $ecret",
+ },
}
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_NESTED_OBJECTS_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_NESTED_OBJECTS_SECRET_PARAMS
+ )
expected = {
- 'arg_optional_object': {
- 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE,
- 'arg_nested_object': {
- 'arg_double_nested_secret': MASKED_ATTRIBUTE_VALUE,
- }
+ "arg_optional_object": {
+ "arg_nested_secret": MASKED_ATTRIBUTE_VALUE,
+ "arg_nested_object": {
+ "arg_double_nested_secret": MASKED_ATTRIBUTE_VALUE,
+ },
}
}
self.assertEqual(expected, result)
def test_mask_secret_parameters_array(self):
parameters = {
- 'arg_optional_array': [
- '$ecret $tring 1',
- '$ecret $tring 2',
- '$ecret $tring 3'
+ "arg_optional_array": [
+ "$ecret $tring 1",
+ "$ecret $tring 2",
+ "$ecret $tring 3",
]
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_ARRAY_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(parameters, TEST_ARRAY_SECRET_PARAMS)
expected = {
- 'arg_optional_array': [
+ "arg_optional_array": [
+ MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
- MASKED_ATTRIBUTE_VALUE
]
}
self.assertEqual(expected, result)
def test_mask_secret_parameters_root_array(self):
parameters = [
- {
- 'secret_field_in_object': 'Secret $tr!ng'
- },
- {
- 'secret_field_in_object': 'Secret $tr!ng 2'
- },
- {
- 'secret_field_in_object': 'Secret $tr!ng 3'
- },
- {
- 'secret_field_in_object': 'Secret $tr!ng 4'
- }
+ {"secret_field_in_object": "Secret $tr!ng"},
+ {"secret_field_in_object": "Secret $tr!ng 2"},
+ {"secret_field_in_object": "Secret $tr!ng 3"},
+ {"secret_field_in_object": "Secret $tr!ng 4"},
]
- result = secrets.mask_secret_parameters(parameters, TEST_ROOT_ARRAY_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_ROOT_ARRAY_SECRET_PARAMS
+ )
expected = [
- {
- 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE
- },
- {
- 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE
- },
- {
- 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE
- },
- {
- 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE
- }
+ {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE},
+ {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE},
+ {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE},
+ {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE},
]
self.assertEqual(expected, result)
def test_mask_secret_parameters_root_object(self):
- parameters = {
- 'arg_level_one':
- {
- 'secret_field_in_object': 'Secret $tr!ng'
- }
- }
+ parameters = {"arg_level_one": {"secret_field_in_object": "Secret $tr!ng"}}
- result = secrets.mask_secret_parameters(parameters, TEST_ROOT_OBJECT_SECRET_PARAMS)
- expected = {
- 'arg_level_one':
- {
- 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE
- }
- }
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_ROOT_OBJECT_SECRET_PARAMS
+ )
+ expected = {"arg_level_one": {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}}
self.assertEqual(expected, result)
def test_mask_secret_parameters_nested_arrays(self):
parameters = {
- 'arg_optional_array': [
- 'secret 1',
- 'secret 2',
- 'secret 3',
+ "arg_optional_array": [
+ "secret 1",
+ "secret 2",
+ "secret 3",
],
- 'arg_optional_double_array': [
+ "arg_optional_double_array": [
[
- 'secret 4',
- 'secret 5',
- 'secret 6',
+ "secret 4",
+ "secret 5",
+ "secret 6",
],
[
- 'secret 7',
- 'secret 8',
- 'secret 9',
- ]
- ],
- 'arg_optional_tripple_array': [
- [
- [
- 'secret 10',
- 'secret 11'
- ],
- [
- 'secret 12',
- 'secret 13',
- 'secret 14'
- ]
+ "secret 7",
+ "secret 8",
+ "secret 9",
],
- [
- [
- 'secret 15',
- 'secret 16'
- ]
- ]
],
- 'arg_optional_quad_array': [
- [
- [
- [
- 'secret 17',
- 'secret 18'
- ],
- [
- 'secret 19'
- ]
- ]
- ]
- ]
+ "arg_optional_tripple_array": [
+ [["secret 10", "secret 11"], ["secret 12", "secret 13", "secret 14"]],
+ [["secret 15", "secret 16"]],
+ ],
+ "arg_optional_quad_array": [[[["secret 17", "secret 18"], ["secret 19"]]]],
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_NESTED_ARRAYS_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_NESTED_ARRAYS_SECRET_PARAMS
+ )
expected = {
- 'arg_optional_array': [
+ "arg_optional_array": [
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
],
- 'arg_optional_double_array': [
+ "arg_optional_double_array": [
[
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
@@ -802,58 +646,46 @@ def test_mask_secret_parameters_nested_arrays(self):
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
- ]
+ ],
],
- 'arg_optional_tripple_array': [
+ "arg_optional_tripple_array": [
[
+ [MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE],
[
MASKED_ATTRIBUTE_VALUE,
- MASKED_ATTRIBUTE_VALUE
- ],
- [
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
- MASKED_ATTRIBUTE_VALUE
- ]
+ ],
],
- [
- [
- MASKED_ATTRIBUTE_VALUE,
- MASKED_ATTRIBUTE_VALUE
- ]
- ]
+ [[MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE]],
],
- 'arg_optional_quad_array': [
+ "arg_optional_quad_array": [
[
[
- [
- MASKED_ATTRIBUTE_VALUE,
- MASKED_ATTRIBUTE_VALUE
- ],
- [
- MASKED_ATTRIBUTE_VALUE
- ]
+ [MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE],
+ [MASKED_ATTRIBUTE_VALUE],
]
]
- ]
+ ],
}
self.assertEqual(expected, result)
def test_mask_secret_parameters_nested_object_with_array(self):
parameters = {
- 'arg_optional_object_with_array': {
- 'arg_nested_array': [
- 'secret array value 1',
- 'secret array value 2',
- 'secret array value 3',
+ "arg_optional_object_with_array": {
+ "arg_nested_array": [
+ "secret array value 1",
+ "secret array value 2",
+ "secret array value 3",
]
}
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS
+ )
expected = {
- 'arg_optional_object_with_array': {
- 'arg_nested_array': [
+ "arg_optional_object_with_array": {
+ "arg_nested_array": [
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
@@ -864,36 +696,33 @@ def test_mask_secret_parameters_nested_object_with_array(self):
def test_mask_secret_parameters_nested_object_with_double_array(self):
parameters = {
- 'arg_optional_object_with_double_array': {
- 'arg_double_nested_array': [
+ "arg_optional_object_with_double_array": {
+ "arg_double_nested_array": [
+ ["secret 1", "secret 2", "secret 3"],
[
- 'secret 1',
- 'secret 2',
- 'secret 3'
+ "secret 4",
+ "secret 5",
+ "secret 6",
],
[
- 'secret 4',
- 'secret 5',
- 'secret 6',
+ "secret 7",
+ "secret 8",
+ "secret 9",
+ "secret 10",
],
- [
- 'secret 7',
- 'secret 8',
- 'secret 9',
- 'secret 10',
- ]
]
}
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS
+ )
expected = {
- 'arg_optional_object_with_double_array': {
- 'arg_double_nested_array': [
+ "arg_optional_object_with_double_array": {
+ "arg_double_nested_array": [
[
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
- MASKED_ATTRIBUTE_VALUE
+ MASKED_ATTRIBUTE_VALUE,
],
[
MASKED_ATTRIBUTE_VALUE,
@@ -905,7 +734,7 @@ def test_mask_secret_parameters_nested_object_with_double_array(self):
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
- ]
+ ],
]
}
}
@@ -913,187 +742,132 @@ def test_mask_secret_parameters_nested_object_with_double_array(self):
def test_mask_secret_parameters_nested_array_with_object(self):
parameters = {
- 'arg_optional_array_with_object': [
- {
- 'arg_nested_secret': 'secret 1'
- },
- {
- 'arg_nested_secret': 'secret 2'
- },
- {
- 'arg_nested_secret': 'secret 3'
- }
+ "arg_optional_array_with_object": [
+ {"arg_nested_secret": "secret 1"},
+ {"arg_nested_secret": "secret 2"},
+ {"arg_nested_secret": "secret 3"},
]
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS
+ )
expected = {
- 'arg_optional_array_with_object': [
- {
- 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE
- },
- {
- 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE
- },
- {
- 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE
- }
+ "arg_optional_array_with_object": [
+ {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE},
+ {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE},
+ {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE},
]
}
self.assertEqual(expected, result)
def test_mask_secret_parameters_secret_array(self):
- parameters = {
- 'arg_secret_array': [
- "abc",
- 123,
- True
- ]
- }
- result = secrets.mask_secret_parameters(parameters,
- TEST_SECRET_ARRAY_SECRET_PARAMS)
- expected = {
- 'arg_secret_array': MASKED_ATTRIBUTE_VALUE
- }
+ parameters = {"arg_secret_array": ["abc", 123, True]}
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_SECRET_ARRAY_SECRET_PARAMS
+ )
+ expected = {"arg_secret_array": MASKED_ATTRIBUTE_VALUE}
self.assertEqual(expected, result)
def test_mask_secret_parameters_secret_object(self):
parameters = {
- 'arg_secret_object':
- {
+ "arg_secret_object": {
"abc": 123,
"key": "value",
"bool": True,
"array": ["x", "y", "z"],
- "obj":
- {
- "x": "deep"
- }
+ "obj": {"x": "deep"},
}
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_SECRET_OBJECT_SECRET_PARAMS)
- expected = {
- 'arg_secret_object': MASKED_ATTRIBUTE_VALUE
- }
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_SECRET_OBJECT_SECRET_PARAMS
+ )
+ expected = {"arg_secret_object": MASKED_ATTRIBUTE_VALUE}
self.assertEqual(expected, result)
def test_mask_secret_parameters_secret_root_array(self):
- parameters = [
- "abc",
- 123,
- True
- ]
- result = secrets.mask_secret_parameters(parameters,
- TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS)
+ parameters = ["abc", 123, True]
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS
+ )
expected = MASKED_ATTRIBUTE_VALUE
self.assertEqual(expected, result)
def test_mask_secret_parameters_secret_root_object(self):
- parameters = {
- 'arg_level_one':
- {
- 'secret_field_in_object': 'Secret $tr!ng'
- }
- }
- result = secrets.mask_secret_parameters(parameters,
- TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS)
+ parameters = {"arg_level_one": {"secret_field_in_object": "Secret $tr!ng"}}
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS
+ )
expected = MASKED_ATTRIBUTE_VALUE
self.assertEqual(expected, result)
def test_mask_secret_parameters_secret_nested_arrays(self):
parameters = {
- 'arg_optional_array': [
- 'secret 1',
- 'secret 2',
- 'secret 3',
+ "arg_optional_array": [
+ "secret 1",
+ "secret 2",
+ "secret 3",
],
- 'arg_optional_double_array': [
+ "arg_optional_double_array": [
[
- 'secret 4',
- 'secret 5',
- 'secret 6',
+ "secret 4",
+ "secret 5",
+ "secret 6",
],
[
- 'secret 7',
- 'secret 8',
- 'secret 9',
- ]
- ],
- 'arg_optional_tripple_array': [
- [
- [
- 'secret 10',
- 'secret 11'
- ],
- [
- 'secret 12',
- 'secret 13',
- 'secret 14'
- ]
+ "secret 7",
+ "secret 8",
+ "secret 9",
],
- [
- [
- 'secret 15',
- 'secret 16'
- ]
- ]
],
- 'arg_optional_quad_array': [
- [
- [
- [
- 'secret 17',
- 'secret 18'
- ],
- [
- 'secret 19'
- ]
- ]
- ]
- ]
+ "arg_optional_tripple_array": [
+ [["secret 10", "secret 11"], ["secret 12", "secret 13", "secret 14"]],
+ [["secret 15", "secret 16"]],
+ ],
+ "arg_optional_quad_array": [[[["secret 17", "secret 18"], ["secret 19"]]]],
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS
+ )
expected = {
- 'arg_optional_array': MASKED_ATTRIBUTE_VALUE,
- 'arg_optional_double_array': MASKED_ATTRIBUTE_VALUE,
- 'arg_optional_tripple_array': [
+ "arg_optional_array": MASKED_ATTRIBUTE_VALUE,
+ "arg_optional_double_array": MASKED_ATTRIBUTE_VALUE,
+ "arg_optional_tripple_array": [
MASKED_ATTRIBUTE_VALUE,
MASKED_ATTRIBUTE_VALUE,
],
- 'arg_optional_quad_array': [
+ "arg_optional_quad_array": [
[
MASKED_ATTRIBUTE_VALUE,
]
- ]
+ ],
}
self.assertEqual(expected, result)
def test_mask_secret_parameters_secret_nested_objects(self):
parameters = {
- 'arg_object': {
- 'arg_nested_secret': 'nested Secret',
- 'arg_nested_object': {
- 'arg_double_nested_secret': 'double nested $ecret',
- }
+ "arg_object": {
+ "arg_nested_secret": "nested Secret",
+ "arg_nested_object": {
+ "arg_double_nested_secret": "double nested $ecret",
+ },
+ },
+ "arg_secret_object": {
+ "arg_nested_secret": "secret data",
+ "arg_nested_object": {
+ "arg_double_nested_secret": "double nested $ecret",
+ },
},
- 'arg_secret_object': {
- 'arg_nested_secret': 'secret data',
- 'arg_nested_object': {
- 'arg_double_nested_secret': 'double nested $ecret',
- }
- }
}
- result = secrets.mask_secret_parameters(parameters,
- TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS)
+ result = secrets.mask_secret_parameters(
+ parameters, TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS
+ )
expected = {
- 'arg_object': {
- 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE,
- 'arg_nested_object': MASKED_ATTRIBUTE_VALUE,
+ "arg_object": {
+ "arg_nested_secret": MASKED_ATTRIBUTE_VALUE,
+ "arg_nested_object": MASKED_ATTRIBUTE_VALUE,
},
- 'arg_secret_object': MASKED_ATTRIBUTE_VALUE,
+ "arg_secret_object": MASKED_ATTRIBUTE_VALUE,
}
self.assertEqual(expected, result)
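
Taken together, these fixtures spell out the contract under test: get_secret_parameters reduces a schema to just its secret leaves (parameter name mapped to the declared type, or None when no type is given), and mask_secret_parameters then walks the parameters against that reduced structure. A rough sketch of the masking recursion that is consistent with every expected value above (an assumption-laden paraphrase, not the actual st2common implementation):

    MASKED = "********"  # stands in for MASKED_ATTRIBUTE_VALUE

    def mask(value, spec):
        # A bare type name (or None) marks the whole value as secret.
        if spec is None or isinstance(spec, str):
            return MASKED
        if isinstance(spec, dict):
            # Mask only the keys the spec names; leave the rest untouched.
            return {k: mask(v, spec[k]) if k in spec else v for k, v in value.items()}
        if isinstance(spec, list):
            # A single item spec applies to every element of the array.
            return [mask(item, spec[0]) for item in value]
        return value
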
diff --git a/st2common/tests/unit/test_util_shell.py b/st2common/tests/unit/test_util_shell.py
index 86c37f2ad1..4a2a00e343 100644
--- a/st2common/tests/unit/test_util_shell.py
+++ b/st2common/tests/unit/test_util_shell.py
@@ -23,38 +23,26 @@
class ShellUtilsTestCase(unittest2.TestCase):
def test_quote_unix(self):
- arguments = [
- 'foo',
- 'foo bar',
- 'foo1 bar1',
- '"foo"',
- '"foo" "bar"',
- "'foo bar'"
- ]
+ arguments = ["foo", "foo bar", "foo1 bar1", '"foo"', '"foo" "bar"', "'foo bar'"]
expected_values = [
"""
foo
""",
-
"""
'foo bar'
""",
-
"""
'foo1 bar1'
""",
-
"""
'"foo"'
""",
-
"""
'"foo" "bar"'
""",
-
"""
''"'"'foo bar'"'"''
- """
+ """,
]
for argument, expected_value in zip(arguments, expected_values):
@@ -63,38 +51,26 @@ def test_quote_unix(self):
self.assertEqual(actual_value, expected_value.strip())
def test_quote_windows(self):
- arguments = [
- 'foo',
- 'foo bar',
- 'foo1 bar1',
- '"foo"',
- '"foo" "bar"',
- "'foo bar'"
- ]
+ arguments = ["foo", "foo bar", "foo1 bar1", '"foo"', '"foo" "bar"', "'foo bar'"]
expected_values = [
"""
foo
""",
-
"""
"foo bar"
""",
-
"""
"foo1 bar1"
""",
-
"""
\\"foo\\"
""",
-
"""
"\\"foo\\" \\"bar\\""
""",
-
"""
"'foo bar'"
- """
+ """,
]
for argument, expected_value in zip(arguments, expected_values):
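
The Unix expectations read like shlex-style quoting; as a hedged cross-check, the stdlib produces the same strings for these inputs (shlex.quote is not the quote_unix helper under test, the outputs merely coincide here):

    import shlex

    for arg in ["foo", "foo bar", '"foo"', "'foo bar'"]:
        print(shlex.quote(arg))
    # foo
    # 'foo bar'
    # '"foo"'
    # ''"'"'foo bar'"'"''
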
diff --git a/st2common/tests/unit/test_util_templating.py b/st2common/tests/unit/test_util_templating.py
index 1756590bc1..c6cd539849 100644
--- a/st2common/tests/unit/test_util_templating.py
+++ b/st2common/tests/unit/test_util_templating.py
@@ -26,41 +26,45 @@ def setUp(self):
super(TemplatingUtilsTestCase, self).setUp()
# Insert mock DB objects
- kvp_1_db = KeyValuePairDB(name='key1', value='valuea')
+ kvp_1_db = KeyValuePairDB(name="key1", value="valuea")
kvp_1_db = KeyValuePair.add_or_update(kvp_1_db)
- kvp_2_db = KeyValuePairDB(name='key2', value='valueb')
+ kvp_2_db = KeyValuePairDB(name="key2", value="valueb")
kvp_2_db = KeyValuePair.add_or_update(kvp_2_db)
- kvp_3_db = KeyValuePairDB(name='stanley:key1', value='valuestanley1', scope=FULL_USER_SCOPE)
+ kvp_3_db = KeyValuePairDB(
+ name="stanley:key1", value="valuestanley1", scope=FULL_USER_SCOPE
+ )
kvp_3_db = KeyValuePair.add_or_update(kvp_3_db)
- kvp_4_db = KeyValuePairDB(name='joe:key1', value='valuejoe1', scope=FULL_USER_SCOPE)
+ kvp_4_db = KeyValuePairDB(
+ name="joe:key1", value="valuejoe1", scope=FULL_USER_SCOPE
+ )
kvp_4_db = KeyValuePair.add_or_update(kvp_4_db)
def test_render_template_with_system_and_user_context(self):
# 1. No reference to the user inside the template
- template = '{{st2kv.system.key1}}'
- user = 'stanley'
+ template = "{{st2kv.system.key1}}"
+ user = "stanley"
result = render_template_with_system_and_user_context(value=template, user=user)
- self.assertEqual(result, 'valuea')
+ self.assertEqual(result, "valuea")
- template = '{{st2kv.system.key2}}'
- user = 'stanley'
+ template = "{{st2kv.system.key2}}"
+ user = "stanley"
result = render_template_with_system_and_user_context(value=template, user=user)
- self.assertEqual(result, 'valueb')
+ self.assertEqual(result, "valueb")
# 2. Reference to the user inside the template
- template = '{{st2kv.user.key1}}'
- user = 'stanley'
+ template = "{{st2kv.user.key1}}"
+ user = "stanley"
result = render_template_with_system_and_user_context(value=template, user=user)
- self.assertEqual(result, 'valuestanley1')
+ self.assertEqual(result, "valuestanley1")
- template = '{{st2kv.user.key1}}'
- user = 'joe'
+ template = "{{st2kv.user.key1}}"
+ user = "joe"
result = render_template_with_system_and_user_context(value=template, user=user)
- self.assertEqual(result, 'valuejoe1')
+ self.assertEqual(result, "valuejoe1")
diff --git a/st2common/tests/unit/test_util_types.py b/st2common/tests/unit/test_util_types.py
index 1213eb69d1..8b7ef78864 100644
--- a/st2common/tests/unit/test_util_types.py
+++ b/st2common/tests/unit/test_util_types.py
@@ -17,9 +17,7 @@
from st2common.util.types import OrderedSet
-__all__ = [
- 'OrderedTestTypeTestCase'
-]
+__all__ = ["OrderedTestTypeTestCase"]
class OrderedTestTypeTestCase(unittest2.TestCase):
diff --git a/st2common/tests/unit/test_util_url.py b/st2common/tests/unit/test_util_url.py
index 551aed3e8c..8b23619593 100644
--- a/st2common/tests/unit/test_util_url.py
+++ b/st2common/tests/unit/test_util_url.py
@@ -23,16 +23,16 @@
class URLUtilsTestCase(unittest2.TestCase):
def test_get_url_without_trailing_slash(self):
values = [
- 'http://localhost:1818/foo/bar/',
- 'http://localhost:1818/foo/bar',
- 'http://localhost:1818/',
- 'http://localhost:1818',
+ "http://localhost:1818/foo/bar/",
+ "http://localhost:1818/foo/bar",
+ "http://localhost:1818/",
+ "http://localhost:1818",
]
expected = [
- 'http://localhost:1818/foo/bar',
- 'http://localhost:1818/foo/bar',
- 'http://localhost:1818',
- 'http://localhost:1818',
+ "http://localhost:1818/foo/bar",
+ "http://localhost:1818/foo/bar",
+ "http://localhost:1818",
+ "http://localhost:1818",
]
for value, expected_result in zip(values, expected):
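
The value/expected pairs imply a single rule for get_url_without_trailing_slash; a sketch under that assumption:

    def strip_trailing_slash(url):
        # Drop at most one trailing "/" (behavior inferred from the pairs above).
        return url[:-1] if url.endswith("/") else url

    assert strip_trailing_slash("http://localhost:1818/") == "http://localhost:1818"
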
diff --git a/st2common/tests/unit/test_versioning_utils.py b/st2common/tests/unit/test_versioning_utils.py
index 73d118aa89..de7bbbfeaf 100644
--- a/st2common/tests/unit/test_versioning_utils.py
+++ b/st2common/tests/unit/test_versioning_utils.py
@@ -23,40 +23,40 @@
class VersioningUtilsTestCase(unittest2.TestCase):
def test_complex_semver_match(self):
# Positive test case
- self.assertTrue(complex_semver_match('1.6.0', '>=1.6.0, <2.2.0'))
- self.assertTrue(complex_semver_match('1.6.1', '>=1.6.0, <2.2.0'))
- self.assertTrue(complex_semver_match('2.0.0', '>=1.6.0, <2.2.0'))
- self.assertTrue(complex_semver_match('2.1.0', '>=1.6.0, <2.2.0'))
- self.assertTrue(complex_semver_match('2.1.9', '>=1.6.0, <2.2.0'))
+ self.assertTrue(complex_semver_match("1.6.0", ">=1.6.0, <2.2.0"))
+ self.assertTrue(complex_semver_match("1.6.1", ">=1.6.0, <2.2.0"))
+ self.assertTrue(complex_semver_match("2.0.0", ">=1.6.0, <2.2.0"))
+ self.assertTrue(complex_semver_match("2.1.0", ">=1.6.0, <2.2.0"))
+ self.assertTrue(complex_semver_match("2.1.9", ">=1.6.0, <2.2.0"))
- self.assertTrue(complex_semver_match('1.6.0', 'all'))
- self.assertTrue(complex_semver_match('1.6.1', 'all'))
- self.assertTrue(complex_semver_match('2.0.0', 'all'))
- self.assertTrue(complex_semver_match('2.1.0', 'all'))
+ self.assertTrue(complex_semver_match("1.6.0", "all"))
+ self.assertTrue(complex_semver_match("1.6.1", "all"))
+ self.assertTrue(complex_semver_match("2.0.0", "all"))
+ self.assertTrue(complex_semver_match("2.1.0", "all"))
- self.assertTrue(complex_semver_match('1.6.0', '>=1.6.0'))
- self.assertTrue(complex_semver_match('1.6.1', '>=1.6.0'))
- self.assertTrue(complex_semver_match('2.1.0', '>=1.6.0'))
+ self.assertTrue(complex_semver_match("1.6.0", ">=1.6.0"))
+ self.assertTrue(complex_semver_match("1.6.1", ">=1.6.0"))
+ self.assertTrue(complex_semver_match("2.1.0", ">=1.6.0"))
# Negative test case
- self.assertFalse(complex_semver_match('1.5.0', '>=1.6.0, <2.2.0'))
- self.assertFalse(complex_semver_match('0.1.0', '>=1.6.0, <2.2.0'))
- self.assertFalse(complex_semver_match('2.2.1', '>=1.6.0, <2.2.0'))
- self.assertFalse(complex_semver_match('2.3.0', '>=1.6.0, <2.2.0'))
- self.assertFalse(complex_semver_match('3.0.0', '>=1.6.0, <2.2.0'))
+ self.assertFalse(complex_semver_match("1.5.0", ">=1.6.0, <2.2.0"))
+ self.assertFalse(complex_semver_match("0.1.0", ">=1.6.0, <2.2.0"))
+ self.assertFalse(complex_semver_match("2.2.1", ">=1.6.0, <2.2.0"))
+ self.assertFalse(complex_semver_match("2.3.0", ">=1.6.0, <2.2.0"))
+ self.assertFalse(complex_semver_match("3.0.0", ">=1.6.0, <2.2.0"))
- self.assertFalse(complex_semver_match('1.5.0', '>=1.6.0'))
- self.assertFalse(complex_semver_match('0.1.0', '>=1.6.0'))
- self.assertFalse(complex_semver_match('1.5.9', '>=1.6.0'))
+ self.assertFalse(complex_semver_match("1.5.0", ">=1.6.0"))
+ self.assertFalse(complex_semver_match("0.1.0", ">=1.6.0"))
+ self.assertFalse(complex_semver_match("1.5.9", ">=1.6.0"))
def test_normalize_pack_version(self):
# Already a valid semver version string
- self.assertEqual(normalize_pack_version('0.2.0'), '0.2.0')
- self.assertEqual(normalize_pack_version('0.2.1'), '0.2.1')
- self.assertEqual(normalize_pack_version('1.2.1'), '1.2.1')
+ self.assertEqual(normalize_pack_version("0.2.0"), "0.2.0")
+ self.assertEqual(normalize_pack_version("0.2.1"), "0.2.1")
+ self.assertEqual(normalize_pack_version("1.2.1"), "1.2.1")
# Not a valid semver version string
- self.assertEqual(normalize_pack_version('0.2'), '0.2.0')
- self.assertEqual(normalize_pack_version('0.3'), '0.3.0')
- self.assertEqual(normalize_pack_version('1.3'), '1.3.0')
- self.assertEqual(normalize_pack_version('2.0'), '2.0.0')
+ self.assertEqual(normalize_pack_version("0.2"), "0.2.0")
+ self.assertEqual(normalize_pack_version("0.3"), "0.3.0")
+ self.assertEqual(normalize_pack_version("1.3"), "1.3.0")
+ self.assertEqual(normalize_pack_version("2.0"), "2.0.0")
diff --git a/st2common/tests/unit/test_virtualenvs.py b/st2common/tests/unit/test_virtualenvs.py
index 90c0f4e989..439801f67a 100644
--- a/st2common/tests/unit/test_virtualenvs.py
+++ b/st2common/tests/unit/test_virtualenvs.py
@@ -30,30 +30,28 @@
from st2common.util.virtualenvs import setup_pack_virtualenv
-__all__ = [
- 'VirtualenvUtilsTestCase'
-]
+__all__ = ["VirtualenvUtilsTestCase"]
# Note: We set base requirements to an empty list to speed up the tests
-@mock.patch('st2common.util.virtualenvs.BASE_PACK_REQUIREMENTS', [])
+@mock.patch("st2common.util.virtualenvs.BASE_PACK_REQUIREMENTS", [])
class VirtualenvUtilsTestCase(CleanFilesTestCase):
def setUp(self):
super(VirtualenvUtilsTestCase, self).setUp()
config.parse_args()
dir_path = tempfile.mkdtemp()
- cfg.CONF.set_override(name='base_path', override=dir_path, group='system')
+ cfg.CONF.set_override(name="base_path", override=dir_path, group="system")
self.base_path = dir_path
- self.virtualenvs_path = os.path.join(self.base_path, 'virtualenvs/')
+ self.virtualenvs_path = os.path.join(self.base_path, "virtualenvs/")
# Make sure dir is deleted on tearDown
self.to_delete_directories.append(self.base_path)
def test_setup_pack_virtualenv_doesnt_exist_yet(self):
# Test a fresh virtualenv creation
- pack_name = 'dummy_pack_1'
+ pack_name = "dummy_pack_1"
pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name)
# Verify virtualenv directory doesn't exist
@@ -61,58 +59,81 @@ def test_setup_pack_virtualenv_doesnt_exist_yet(self):
# Create virtualenv
# Note: This pack has no requirements
- setup_pack_virtualenv(pack_name=pack_name, update=False,
- include_pip=False, include_setuptools=False, include_wheel=False)
+ setup_pack_virtualenv(
+ pack_name=pack_name,
+ update=False,
+ include_pip=False,
+ include_setuptools=False,
+ include_wheel=False,
+ )
# Verify that virtualenv has been created
self.assertVirtualenvExists(pack_virtualenv_dir)
def test_setup_pack_virtualenv_already_exists(self):
# Test a scenario where virtualenv already exists
- pack_name = 'dummy_pack_1'
+ pack_name = "dummy_pack_1"
pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name)
# Verify virtualenv directory doesn't exist
self.assertFalse(os.path.exists(pack_virtualenv_dir))
# Create virtualenv
- setup_pack_virtualenv(pack_name=pack_name, update=False,
- include_pip=False, include_setuptools=False, include_wheel=False)
+ setup_pack_virtualenv(
+ pack_name=pack_name,
+ update=False,
+ include_pip=False,
+ include_setuptools=False,
+ include_wheel=False,
+ )
# Verify that virtualenv has been created
self.assertVirtualenvExists(pack_virtualenv_dir)
# Re-create virtualenv
- setup_pack_virtualenv(pack_name=pack_name, update=False,
- include_pip=False, include_setuptools=False, include_wheel=False)
+ setup_pack_virtualenv(
+ pack_name=pack_name,
+ update=False,
+ include_pip=False,
+ include_setuptools=False,
+ include_wheel=False,
+ )
        # Verify virtualenv is still there
self.assertVirtualenvExists(pack_virtualenv_dir)
def test_setup_virtualenv_update(self):
# Test a virtualenv update with pack which has requirements.txt
- pack_name = 'dummy_pack_2'
+ pack_name = "dummy_pack_2"
pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name)
# Verify virtualenv directory doesn't exist
self.assertFalse(os.path.exists(pack_virtualenv_dir))
# Create virtualenv
- setup_pack_virtualenv(pack_name=pack_name, update=False,
- include_setuptools=False, include_wheel=False)
+ setup_pack_virtualenv(
+ pack_name=pack_name,
+ update=False,
+ include_setuptools=False,
+ include_wheel=False,
+ )
# Verify that virtualenv has been created
self.assertVirtualenvExists(pack_virtualenv_dir)
# Update it
- setup_pack_virtualenv(pack_name=pack_name, update=True,
- include_setuptools=False, include_wheel=False)
+ setup_pack_virtualenv(
+ pack_name=pack_name,
+ update=True,
+ include_setuptools=False,
+ include_wheel=False,
+ )
# Verify virtrualenv is still there
self.assertVirtualenvExists(pack_virtualenv_dir)
def test_setup_virtualenv_invalid_dependency_in_requirements_file(self):
- pack_name = 'pack_invalid_requirements'
+ pack_name = "pack_invalid_requirements"
pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name)
# Verify virtualenv directory doesn't exist
@@ -120,182 +141,240 @@ def test_setup_virtualenv_invalid_dependency_in_requirements_file(self):
# Try to create virtualenv, assert that it fails
try:
- setup_pack_virtualenv(pack_name=pack_name, update=False,
- include_setuptools=False, include_wheel=False)
+ setup_pack_virtualenv(
+ pack_name=pack_name,
+ update=False,
+ include_setuptools=False,
+ include_wheel=False,
+ )
except Exception as e:
- self.assertIn('Failed to install requirements from', six.text_type(e))
- self.assertTrue('No matching distribution found for someinvalidname' in
- six.text_type(e))
+ self.assertIn("Failed to install requirements from", six.text_type(e))
+ self.assertTrue(
+ "No matching distribution found for someinvalidname" in six.text_type(e)
+ )
else:
- self.fail('Exception not thrown')
-
- @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', '')))
- @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command',
- mock.MagicMock(return_value={}))
+ self.fail("Exception not thrown")
+
+ @mock.patch.object(
+ virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", ""))
+ )
+ @mock.patch.object(
+ virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={})
+ )
def test_install_requirement_without_proxy(self):
- pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/'
- requirement = 'six>=1.9.0'
+ pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/"
+ requirement = "six>=1.9.0"
install_requirement(pack_virtualenv_dir, requirement, proxy_config=None)
expected_args = {
- 'cmd': [
- '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
- 'install', 'six>=1.9.0'
+ "cmd": [
+ "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip",
+ "install",
+ "six>=1.9.0",
],
- 'env': {}
+ "env": {},
}
virtualenvs.run_command.assert_called_once_with(**expected_args)
- @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', '')))
- @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command',
- mock.MagicMock(return_value={}))
+ @mock.patch.object(
+ virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", ""))
+ )
+ @mock.patch.object(
+ virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={})
+ )
def test_install_requirement_with_http_proxy(self):
- pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/'
- requirement = 'six>=1.9.0'
- proxy_config = {
- 'http_proxy': 'http://192.168.1.5:8080'
- }
+ pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/"
+ requirement = "six>=1.9.0"
+ proxy_config = {"http_proxy": "http://192.168.1.5:8080"}
install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config)
expected_args = {
- 'cmd': [
- '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
- '--proxy', 'http://192.168.1.5:8080',
- 'install', 'six>=1.9.0'
+ "cmd": [
+ "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip",
+ "--proxy",
+ "http://192.168.1.5:8080",
+ "install",
+ "six>=1.9.0",
],
- 'env': {}
+ "env": {},
}
virtualenvs.run_command.assert_called_once_with(**expected_args)
- @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', '')))
- @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command',
- mock.MagicMock(return_value={}))
+ @mock.patch.object(
+ virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", ""))
+ )
+ @mock.patch.object(
+ virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={})
+ )
def test_install_requirement_with_https_proxy(self):
- pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/'
- requirement = 'six>=1.9.0'
+ pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/"
+ requirement = "six>=1.9.0"
proxy_config = {
- 'https_proxy': 'https://192.168.1.5:8080',
- 'proxy_ca_bundle_path': '/etc/ssl/certs/mitmproxy-ca.pem'
+ "https_proxy": "https://192.168.1.5:8080",
+ "proxy_ca_bundle_path": "/etc/ssl/certs/mitmproxy-ca.pem",
}
install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config)
expected_args = {
- 'cmd': [
- '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
- '--proxy', 'https://192.168.1.5:8080',
- '--cert', '/etc/ssl/certs/mitmproxy-ca.pem',
- 'install', 'six>=1.9.0'
+ "cmd": [
+ "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip",
+ "--proxy",
+ "https://192.168.1.5:8080",
+ "--cert",
+ "/etc/ssl/certs/mitmproxy-ca.pem",
+ "install",
+ "six>=1.9.0",
],
- 'env': {}
+ "env": {},
}
virtualenvs.run_command.assert_called_once_with(**expected_args)
- @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', '')))
- @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command',
- mock.MagicMock(return_value={}))
+ @mock.patch.object(
+ virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", ""))
+ )
+ @mock.patch.object(
+ virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={})
+ )
def test_install_requirement_with_https_proxy_no_cert(self):
- pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/'
- requirement = 'six>=1.9.0'
+ pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/"
+ requirement = "six>=1.9.0"
proxy_config = {
- 'https_proxy': 'https://192.168.1.5:8080',
+ "https_proxy": "https://192.168.1.5:8080",
}
install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config)
expected_args = {
- 'cmd': [
- '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
- '--proxy', 'https://192.168.1.5:8080',
- 'install', 'six>=1.9.0'
+ "cmd": [
+ "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip",
+ "--proxy",
+ "https://192.168.1.5:8080",
+ "install",
+ "six>=1.9.0",
],
- 'env': {}
+ "env": {},
}
virtualenvs.run_command.assert_called_once_with(**expected_args)
- @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', '')))
- @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command',
- mock.MagicMock(return_value={}))
+ @mock.patch.object(
+ virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", ""))
+ )
+ @mock.patch.object(
+ virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={})
+ )
def test_install_requirements_without_proxy(self):
- pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/'
- requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt'
- install_requirements(pack_virtualenv_dir, requirements_file_path, proxy_config=None)
+ pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/"
+ requirements_file_path = (
+ "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt"
+ )
+ install_requirements(
+ pack_virtualenv_dir, requirements_file_path, proxy_config=None
+ )
expected_args = {
- 'cmd': [
- '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
- 'install', '-U',
- '-r', requirements_file_path
+ "cmd": [
+ "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip",
+ "install",
+ "-U",
+ "-r",
+ requirements_file_path,
],
- 'env': {}
+ "env": {},
}
virtualenvs.run_command.assert_called_once_with(**expected_args)
- @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', '')))
- @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command',
- mock.MagicMock(return_value={}))
+ @mock.patch.object(
+ virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", ""))
+ )
+ @mock.patch.object(
+ virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={})
+ )
def test_install_requirements_with_http_proxy(self):
- pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/'
- requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt'
- proxy_config = {
- 'http_proxy': 'http://192.168.1.5:8080'
- }
- install_requirements(pack_virtualenv_dir, requirements_file_path,
- proxy_config=proxy_config)
+ pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/"
+ requirements_file_path = (
+ "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt"
+ )
+ proxy_config = {"http_proxy": "http://192.168.1.5:8080"}
+ install_requirements(
+ pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config
+ )
expected_args = {
- 'cmd': [
- '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
- '--proxy', 'http://192.168.1.5:8080',
- 'install', '-U',
- '-r', requirements_file_path
+ "cmd": [
+ "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip",
+ "--proxy",
+ "http://192.168.1.5:8080",
+ "install",
+ "-U",
+ "-r",
+ requirements_file_path,
],
- 'env': {}
+ "env": {},
}
virtualenvs.run_command.assert_called_once_with(**expected_args)
- @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', '')))
- @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command',
- mock.MagicMock(return_value={}))
+ @mock.patch.object(
+ virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", ""))
+ )
+ @mock.patch.object(
+ virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={})
+ )
def test_install_requirements_with_https_proxy(self):
- pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/'
- requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt'
+ pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/"
+ requirements_file_path = (
+ "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt"
+ )
proxy_config = {
- 'https_proxy': 'https://192.168.1.5:8080',
- 'proxy_ca_bundle_path': '/etc/ssl/certs/mitmproxy-ca.pem'
+ "https_proxy": "https://192.168.1.5:8080",
+ "proxy_ca_bundle_path": "/etc/ssl/certs/mitmproxy-ca.pem",
}
- install_requirements(pack_virtualenv_dir, requirements_file_path,
- proxy_config=proxy_config)
+ install_requirements(
+ pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config
+ )
expected_args = {
- 'cmd': [
- '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
- '--proxy', 'https://192.168.1.5:8080',
- '--cert', '/etc/ssl/certs/mitmproxy-ca.pem',
- 'install', '-U',
- '-r', requirements_file_path
+ "cmd": [
+ "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip",
+ "--proxy",
+ "https://192.168.1.5:8080",
+ "--cert",
+ "/etc/ssl/certs/mitmproxy-ca.pem",
+ "install",
+ "-U",
+ "-r",
+ requirements_file_path,
],
- 'env': {}
+ "env": {},
}
virtualenvs.run_command.assert_called_once_with(**expected_args)
- @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', '')))
- @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command',
- mock.MagicMock(return_value={}))
+ @mock.patch.object(
+ virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", ""))
+ )
+ @mock.patch.object(
+ virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={})
+ )
def test_install_requirements_with_https_proxy_no_cert(self):
- pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/'
- requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt'
+ pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/"
+ requirements_file_path = (
+ "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt"
+ )
proxy_config = {
- 'https_proxy': 'https://192.168.1.5:8080',
+ "https_proxy": "https://192.168.1.5:8080",
}
- install_requirements(pack_virtualenv_dir, requirements_file_path,
- proxy_config=proxy_config)
+ install_requirements(
+ pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config
+ )
expected_args = {
- 'cmd': [
- '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
- '--proxy', 'https://192.168.1.5:8080',
- 'install', '-U',
- '-r', requirements_file_path
+ "cmd": [
+ "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip",
+ "--proxy",
+ "https://192.168.1.5:8080",
+ "install",
+ "-U",
+ "-r",
+ requirements_file_path,
],
- 'env': {}
+ "env": {},
}
virtualenvs.run_command.assert_called_once_with(**expected_args)
def assertVirtualenvExists(self, virtualenv_dir):
self.assertTrue(os.path.exists(virtualenv_dir))
self.assertTrue(os.path.isdir(virtualenv_dir))
- self.assertTrue(os.path.isdir(os.path.join(virtualenv_dir, 'bin/')))
+ self.assertTrue(os.path.isdir(os.path.join(virtualenv_dir, "bin/")))
return True
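
All of the proxy-related expectations above follow one pattern: --proxy, and --cert when a CA bundle is configured, are spliced in before the install arguments. A compact sketch of that command assembly (a hypothetical helper; the precedence between the two proxy kinds is assumed, since no test sets both):

    import os

    def build_pip_cmd(venv_dir, install_args, proxy_config=None):
        # Mirrors the expected "cmd" lists asserted in the tests above.
        cmd = [os.path.join(venv_dir, "bin/pip")]
        proxy_config = proxy_config or {}
        proxy = proxy_config.get("https_proxy") or proxy_config.get("http_proxy")
        if proxy:
            cmd += ["--proxy", proxy]
        if proxy_config.get("proxy_ca_bundle_path"):
            cmd += ["--cert", proxy_config["proxy_ca_bundle_path"]]
        return cmd + ["install"] + install_args

    build_pip_cmd(
        "/opt/stackstorm/virtualenvs/dummy_pack_tests/",
        ["six>=1.9.0"],
        proxy_config={"http_proxy": "http://192.168.1.5:8080"},
    )
    # -> ['/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip',
    #     '--proxy', 'http://192.168.1.5:8080', 'install', 'six>=1.9.0']
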
diff --git a/st2exporter/dist_utils.py b/st2exporter/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2exporter/dist_utils.py
+++ b/st2exporter/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
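
To make the _get_link() helper above concrete, a runnable demonstration of its two regexes
on made-up requirement lines:

    import re

    line = "-e git+https://example.com/repo.git#egg=mypkg&subdirectory=src"
    # First pattern: "#egg=<name>" followed by a "&" or "@" qualifier.
    print(re.findall(".*#egg=(.+)([&|@]).*$", line))  # [('mypkg', '&')]

    line = "git+https://example.com/repo.git#egg=mypkg"
    # Fallback pattern: a bare "#egg=<name>" suffix.
    print(re.findall(".*#egg=(.+?)$", line))  # ['mypkg']

    # The returned link is the line with any editable prefix stripped:
    # line.replace("-e ", "").strip()
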
diff --git a/st2exporter/setup.py b/st2exporter/setup.py
index bfd01f7061..afaae79cac 100644
--- a/st2exporter/setup.py
+++ b/st2exporter/setup.py
@@ -22,9 +22,9 @@
from dist_utils import apply_vagrant_workaround
from st2exporter import __version__
-ST2_COMPONENT = 'st2exporter'
+ST2_COMPONENT = "st2exporter"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
@@ -32,18 +32,18 @@
setup(
name=ST2_COMPONENT,
version=__version__,
- description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="{} StackStorm event-driven automation platform component".format(
+ ST2_COMPONENT
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- scripts=[
- 'bin/st2exporter'
- ]
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ scripts=["bin/st2exporter"],
)
diff --git a/st2exporter/st2exporter/cmd/st2exporter_starter.py b/st2exporter/st2exporter/cmd/st2exporter_starter.py
index c5ce157e24..2b86ef2707 100644
--- a/st2exporter/st2exporter/cmd/st2exporter_starter.py
+++ b/st2exporter/st2exporter/cmd/st2exporter_starter.py
@@ -14,6 +14,7 @@
# limitations under the License.
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -25,26 +26,29 @@
from st2exporter import config
from st2exporter import worker
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _setup():
- common_setup(service='exporter', config=config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True)
+ common_setup(
+ service="exporter",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ )
def _run_worker():
- LOG.info('(PID=%s) Exporter started.', os.getpid())
+ LOG.info("(PID=%s) Exporter started.", os.getpid())
export_worker = worker.get_worker()
try:
export_worker.start(wait=True)
except (KeyboardInterrupt, SystemExit):
- LOG.info('(PID=%s) Exporter stopped.', os.getpid())
+ LOG.info("(PID=%s) Exporter stopped.", os.getpid())
export_worker.shutdown()
except:
return 1
@@ -62,7 +66,7 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except:
- LOG.exception('(PID=%s) Exporter quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) Exporter quit due to exception.", os.getpid())
return 1
finally:
_teardown()
diff --git a/st2exporter/st2exporter/config.py b/st2exporter/st2exporter/config.py
index 456b09e365..83f4f45d5d 100644
--- a/st2exporter/st2exporter/config.py
+++ b/st2exporter/st2exporter/config.py
@@ -31,8 +31,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def get_logging_config_path():
@@ -51,16 +54,20 @@ def _register_common_opts():
def _register_app_opts():
dump_opts = [
cfg.StrOpt(
- 'dump_dir', default='/opt/stackstorm/exports/',
- help='Directory to dump data to.')
+ "dump_dir",
+ default="/opt/stackstorm/exports/",
+ help="Directory to dump data to.",
+ )
]
- CONF.register_opts(dump_opts, group='exporter')
+ CONF.register_opts(dump_opts, group="exporter")
logging_opts = [
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.exporter.conf',
- help='location of the logging.exporter.conf file')
+ "logging",
+ default="/etc/st2/logging.exporter.conf",
+ help="location of the logging.exporter.conf file",
+ )
]
- CONF.register_opts(logging_opts, group='exporter')
+ CONF.register_opts(logging_opts, group="exporter")
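
Once parse_args() has run, oslo.config exposes both options on the global CONF object under
the exporter group; a minimal usage sketch:

    from oslo_config import cfg

    dump_dir = cfg.CONF.exporter.dump_dir  # "/opt/stackstorm/exports/" by default
    log_conf = cfg.CONF.exporter.logging   # "/etc/st2/logging.exporter.conf" by default
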
diff --git a/st2exporter/st2exporter/exporter/dumper.py b/st2exporter/st2exporter/exporter/dumper.py
index 2059557420..12fbeb4f83 100644
--- a/st2exporter/st2exporter/exporter/dumper.py
+++ b/st2exporter/st2exporter/exporter/dumper.py
@@ -26,40 +26,43 @@
from st2common.util import date as date_utils
from st2common.util import isotime
-__all__ = [
- 'Dumper'
-]
+__all__ = ["Dumper"]
-ALLOWED_EXTENSIONS = ['json']
+ALLOWED_EXTENSIONS = ["json"]
-CONVERTERS = {
- 'json': JsonConverter
-}
+CONVERTERS = {"json": JsonConverter}
LOG = logging.getLogger(__name__)
class Dumper(object):
-
- def __init__(self, queue, export_dir, file_format='json',
- file_prefix='st2-executions-',
- batch_size=1000, sleep_interval=60,
- max_files_per_sleep=5,
- file_writer=None):
+ def __init__(
+ self,
+ queue,
+ export_dir,
+ file_format="json",
+ file_prefix="st2-executions-",
+ batch_size=1000,
+ sleep_interval=60,
+ max_files_per_sleep=5,
+ file_writer=None,
+ ):
if not queue:
- raise Exception('Need a queue to consume data from.')
+ raise Exception("Need a queue to consume data from.")
if not export_dir:
- raise Exception('Export dir needed to dump files to.')
+ raise Exception("Export dir needed to dump files to.")
self._export_dir = export_dir
if not os.path.exists(self._export_dir):
- raise Exception('Dir path %s does not exist. Create one before using exporter.' %
- self._export_dir)
+ raise Exception(
+ "Dir path %s does not exist. Create one before using exporter."
+ % self._export_dir
+ )
self._file_format = file_format.lower()
if self._file_format not in ALLOWED_EXTENSIONS:
- raise ValueError('Disallowed extension %s.' % file_format)
+ raise ValueError("Disallowed extension %s." % file_format)
self._file_prefix = file_prefix
self._batch_size = batch_size
@@ -99,8 +102,8 @@ def _get_batch(self):
else:
executions_to_write.append(item)
- LOG.debug('Returning %d items in batch.', len(executions_to_write))
- LOG.debug('Remaining items in queue: %d', self._queue.qsize())
+ LOG.debug("Returning %d items in batch.", len(executions_to_write))
+ LOG.debug("Remaining items in queue: %d", self._queue.qsize())
return executions_to_write
def _flush(self):
@@ -111,7 +114,7 @@ def _flush(self):
try:
self._write_to_disk()
except:
- LOG.error('Failed writing data to disk.')
+ LOG.error("Failed writing data to disk.")
def _write_to_disk(self):
count = 0
@@ -128,7 +131,7 @@ def _write_to_disk(self):
self._update_marker(batch)
count += 1
except:
- LOG.exception('Writing batch to disk failed.')
+ LOG.exception("Writing batch to disk failed.")
return count
def _create_date_folder(self):
@@ -139,7 +142,7 @@ def _create_date_folder(self):
try:
os.makedirs(folder_path)
except:
- LOG.exception('Unable to create sub-folder %s for export.', folder_name)
+ LOG.exception("Unable to create sub-folder %s for export.", folder_name)
raise
def _write_batch_to_disk(self, batch):
@@ -147,42 +150,44 @@ def _write_batch_to_disk(self, batch):
self._file_writer.write_text(doc_to_write, self._get_file_name())
def _get_file_name(self):
- timestring = date_utils.get_datetime_utc_now().strftime('%Y-%m-%dT%H:%M:%S.%fZ')
- file_name = self._file_prefix + timestring + '.' + self._file_format
+ timestring = date_utils.get_datetime_utc_now().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
+ file_name = self._file_prefix + timestring + "." + self._file_format
file_name = os.path.join(self._export_dir, self._get_date_folder(), file_name)
return file_name
def _get_date_folder(self):
- return date_utils.get_datetime_utc_now().strftime('%Y-%m-%d')
+ return date_utils.get_datetime_utc_now().strftime("%Y-%m-%d")
def _update_marker(self, batch):
timestamps = [isotime.parse(item.end_timestamp) for item in batch]
new_marker = max(timestamps)
if self._persisted_marker and self._persisted_marker > new_marker:
- LOG.warn('Older executions are being exported. Perhaps out of order messages.')
+            LOG.warn(
+                "Older executions are being exported; messages may have arrived out of order."
+            )
try:
self._write_marker_to_db(new_marker)
except:
- LOG.exception('Failed persisting dumper marker to db.')
+ LOG.exception("Failed persisting dumper marker to db.")
else:
self._persisted_marker = new_marker
return self._persisted_marker
def _write_marker_to_db(self, new_marker):
- LOG.info('Updating marker in db to: %s', new_marker)
+ LOG.info("Updating marker in db to: %s", new_marker)
markers = DumperMarker.get_all()
if len(markers) > 1:
- LOG.exception('More than one dumper marker found. Using first found one.')
+ LOG.exception("More than one dumper marker found. Using first found one.")
marker = isotime.format(new_marker, offset=False)
updated_at = date_utils.get_datetime_utc_now()
if markers:
- marker_id = markers[0]['id']
+ marker_id = markers[0]["id"]
else:
marker_id = None
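
The marker bookkeeping above boils down to one invariant; a simplified standalone rendering
(without the DB write):

    def advance_marker(persisted_marker, batch_timestamps):
        # As in Dumper._update_marker(): the marker always moves to the newest
        # end_timestamp in the batch; a batch older than the persisted marker
        # is only warned about, not rejected.
        new_marker = max(batch_timestamps)
        if persisted_marker is not None and persisted_marker > new_marker:
            print("WARNING: older executions are being exported")
        return new_marker
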
diff --git a/st2exporter/st2exporter/exporter/file_writer.py b/st2exporter/st2exporter/exporter/file_writer.py
index ec7e4d876c..49b5b4d63a 100644
--- a/st2exporter/st2exporter/exporter/file_writer.py
+++ b/st2exporter/st2exporter/exporter/file_writer.py
@@ -18,15 +18,11 @@
import abc
import six
-__all__ = [
- 'FileWriter',
- 'TextFileWriter'
-]
+__all__ = ["FileWriter", "TextFileWriter"]
@six.add_metaclass(abc.ABCMeta)
class FileWriter(object):
-
@abc.abstractmethod
def write(self, data, file_path, replace=False):
"""
@@ -40,13 +36,13 @@ class TextFileWriter(FileWriter):
def write_text(self, text_data, file_path, replace=False, compressed=False):
if compressed:
- return Exception('Compression not supported.')
+ return Exception("Compression not supported.")
self.write(text_data, file_path, replace=replace)
def write(self, data, file_path, replace=False):
if os.path.exists(file_path) and not replace:
- raise Exception('File %s already exists.' % file_path)
+ raise Exception("File %s already exists." % file_path)
- with open(file_path, 'w') as f:
+ with open(file_path, "w") as f:
f.write(data)
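
A short usage note for the writer above: write() refuses to overwrite an existing file
unless replace=True is passed, and note that compressed=True currently returns (rather than
raises) an Exception. A minimal sketch:

    from st2exporter.exporter.file_writer import TextFileWriter

    writer = TextFileWriter()
    writer.write_text('{"executions": []}', "/tmp/example.json")
    writer.write_text('{"executions": []}', "/tmp/example.json", replace=True)
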
diff --git a/st2exporter/st2exporter/exporter/json_converter.py b/st2exporter/st2exporter/exporter/json_converter.py
index a288197d41..ba7e95c0a5 100644
--- a/st2exporter/st2exporter/exporter/json_converter.py
+++ b/st2exporter/st2exporter/exporter/json_converter.py
@@ -15,15 +15,12 @@
from st2common.util.jsonify import json_encode
-__all__ = [
- 'JsonConverter'
-]
+__all__ = ["JsonConverter"]
class JsonConverter(object):
-
def convert(self, items_list):
if not isinstance(items_list, list):
- raise ValueError('Items to be converted should be a list.')
+ raise ValueError("Items to be converted should be a list.")
json_doc = json_encode(items_list)
return json_doc
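
Equivalent behaviour in two lines; non-list input raises ValueError, which the unit test
later in this patch exercises:

    from st2exporter.exporter.json_converter import JsonConverter

    converter = JsonConverter()
    json_doc = converter.convert([{"id": "abc123", "status": "succeeded"}])
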
diff --git a/st2exporter/st2exporter/worker.py b/st2exporter/st2exporter/worker.py
index 13273fd587..a5557ee41f 100644
--- a/st2exporter/st2exporter/worker.py
+++ b/st2exporter/st2exporter/worker.py
@@ -18,8 +18,11 @@
from oslo_config import cfg
from st2common import log as logging
-from st2common.constants.action import (LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED,
- LIVEACTION_STATUS_CANCELED)
+from st2common.constants.action import (
+ LIVEACTION_STATUS_SUCCEEDED,
+ LIVEACTION_STATUS_FAILED,
+ LIVEACTION_STATUS_CANCELED,
+)
from st2common.models.api.execution import ActionExecutionAPI
from st2common.models.db.execution import ActionExecutionDB
from st2common.persistence.execution import ActionExecution
@@ -30,13 +33,13 @@
from st2exporter.exporter.dumper import Dumper
from st2common.transport.queues import EXPORTER_WORK_QUEUE
-__all__ = [
- 'ExecutionsExporter',
- 'get_worker'
-]
+__all__ = ["ExecutionsExporter", "get_worker"]
-COMPLETION_STATUSES = [LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED,
- LIVEACTION_STATUS_CANCELED]
+COMPLETION_STATUSES = [
+ LIVEACTION_STATUS_SUCCEEDED,
+ LIVEACTION_STATUS_FAILED,
+ LIVEACTION_STATUS_CANCELED,
+]
LOG = logging.getLogger(__name__)
@@ -46,18 +49,21 @@ class ExecutionsExporter(consumers.MessageHandler):
def __init__(self, connection, queues):
super(ExecutionsExporter, self).__init__(connection, queues)
self.pending_executions = queue.Queue()
- self._dumper = Dumper(queue=self.pending_executions,
- export_dir=cfg.CONF.exporter.dump_dir)
+ self._dumper = Dumper(
+ queue=self.pending_executions, export_dir=cfg.CONF.exporter.dump_dir
+ )
self._consumer_thread = None
def start(self, wait=False):
- LOG.info('Bootstrapping executions from db...')
+ LOG.info("Bootstrapping executions from db...")
try:
self._bootstrap()
except:
- LOG.exception('Unable to bootstrap executions from db. Aborting.')
+ LOG.exception("Unable to bootstrap executions from db. Aborting.")
raise
- self._consumer_thread = eventlet.spawn(super(ExecutionsExporter, self).start, wait=True)
+ self._consumer_thread = eventlet.spawn(
+ super(ExecutionsExporter, self).start, wait=True
+ )
self._dumper.start()
if wait:
self.wait()
@@ -71,7 +77,7 @@ def shutdown(self):
super(ExecutionsExporter, self).shutdown()
def process(self, execution):
- LOG.debug('Got execution from queue: %s', execution)
+ LOG.debug("Got execution from queue: %s", execution)
if execution.status not in COMPLETION_STATUSES:
return
execution_api = ActionExecutionAPI.from_model(execution, mask_secrets=True)
@@ -80,21 +86,23 @@ def process(self, execution):
def _bootstrap(self):
marker = self._get_export_marker_from_db()
- LOG.info('Using marker %s...' % marker)
+ LOG.info("Using marker %s..." % marker)
missed_executions = self._get_missed_executions_from_db(export_marker=marker)
- LOG.info('Found %d executions not exported yet...', len(missed_executions))
+ LOG.info("Found %d executions not exported yet...", len(missed_executions))
for missed_execution in missed_executions:
if missed_execution.status not in COMPLETION_STATUSES:
continue
- execution_api = ActionExecutionAPI.from_model(missed_execution, mask_secrets=True)
+ execution_api = ActionExecutionAPI.from_model(
+ missed_execution, mask_secrets=True
+ )
try:
- LOG.debug('Missed execution %s', execution_api)
+ LOG.debug("Missed execution %s", execution_api)
self.pending_executions.put_nowait(execution_api)
except:
- LOG.exception('Failed adding execution to in-memory queue.')
+ LOG.exception("Failed adding execution to in-memory queue.")
continue
- LOG.info('Bootstrapped executions...')
+ LOG.info("Bootstrapped executions...")
def _get_export_marker_from_db(self):
try:
@@ -114,8 +122,8 @@ def _get_missed_executions_from_db(self, export_marker=None):
# XXX: Should adapt this query to get only executions with status
# in COMPLETION_STATUSES.
- filters = {'end_timestamp__gt': export_marker}
- LOG.info('Querying for executions with filters: %s', filters)
+ filters = {"end_timestamp__gt": export_marker}
+ LOG.info("Querying for executions with filters: %s", filters)
return ActionExecution.query(**filters)
def _get_all_executions_from_db(self):
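
Conceptually, the bootstrap path above resumes from the persisted marker with a single
mongoengine-style filter; a placeholder sketch using the names from the code above:

    from st2common.persistence.execution import ActionExecution

    def missed_executions(exporter):
        marker = exporter._get_export_marker_from_db()  # datetime or None
        if marker:
            return ActionExecution.query(end_timestamp__gt=marker)
        return exporter._get_all_executions_from_db()  # full backfill
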
diff --git a/st2exporter/tests/integration/test_dumper_integration.py b/st2exporter/tests/integration/test_dumper_integration.py
index bdb87b1249..0de7b91ed0 100644
--- a/st2exporter/tests/integration/test_dumper_integration.py
+++ b/st2exporter/tests/integration/test_dumper_integration.py
@@ -28,21 +28,30 @@
from st2tests.base import DbTestCase
from st2tests.fixturesloader import FixturesLoader
-DESCENDANTS_PACK = 'descendants'
+DESCENDANTS_PACK = "descendants"
DESCENDANTS_FIXTURES = {
- 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml',
- 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml',
- 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml']
+ "executions": [
+ "root_execution.yaml",
+ "child1_level1.yaml",
+ "child2_level1.yaml",
+ "child1_level2.yaml",
+ "child2_level2.yaml",
+ "child3_level2.yaml",
+ "child1_level3.yaml",
+ "child2_level3.yaml",
+ "child3_level3.yaml",
+ ]
}
class TestDumper(DbTestCase):
fixtures_loader = FixturesLoader()
- loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK,
- fixtures_dict=DESCENDANTS_FIXTURES)
- loaded_executions = loaded_fixtures['executions']
+ loaded_fixtures = fixtures_loader.load_fixtures(
+ fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES
+ )
+ loaded_executions = loaded_fixtures["executions"]
execution_apis = []
for execution in loaded_executions.values():
execution_apis.append(ActionExecutionAPI(**execution))
@@ -54,31 +63,45 @@ def get_queue(self):
executions_queue.put(execution)
return executions_queue
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_write_marker_to_db(self):
executions_queue = self.get_queue()
- dumper = Dumper(queue=executions_queue,
- export_dir='/tmp', batch_size=5,
- max_files_per_sleep=1,
- file_prefix='st2-stuff-', file_format='json')
- timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis]
+ dumper = Dumper(
+ queue=executions_queue,
+ export_dir="/tmp",
+ batch_size=5,
+ max_files_per_sleep=1,
+ file_prefix="st2-stuff-",
+ file_format="json",
+ )
+ timestamps = [
+ isotime.parse(execution.end_timestamp) for execution in self.execution_apis
+ ]
max_timestamp = max(timestamps)
marker_db = dumper._write_marker_to_db(max_timestamp)
persisted_marker = marker_db.marker
self.assertIsInstance(persisted_marker, six.string_types)
self.assertEqual(isotime.parse(persisted_marker), max_timestamp)
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_write_marker_to_db_marker_exists(self):
executions_queue = self.get_queue()
- dumper = Dumper(queue=executions_queue,
- export_dir='/tmp', batch_size=5,
- max_files_per_sleep=1,
- file_prefix='st2-stuff-', file_format='json')
- timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis]
+ dumper = Dumper(
+ queue=executions_queue,
+ export_dir="/tmp",
+ batch_size=5,
+ max_files_per_sleep=1,
+ file_prefix="st2-stuff-",
+ file_format="json",
+ )
+ timestamps = [
+ isotime.parse(execution.end_timestamp) for execution in self.execution_apis
+ ]
max_timestamp = max(timestamps)
first_marker_db = dumper._write_marker_to_db(max_timestamp)
- second_marker_db = dumper._write_marker_to_db(max_timestamp + datetime.timedelta(hours=1))
+ second_marker_db = dumper._write_marker_to_db(
+ max_timestamp + datetime.timedelta(hours=1)
+ )
markers = DumperMarker.get_all()
self.assertEqual(len(markers), 1)
final_marker_id = markers[0].id
diff --git a/st2exporter/tests/integration/test_export_worker.py b/st2exporter/tests/integration/test_export_worker.py
index 8b0caf7d86..9237aab0e8 100644
--- a/st2exporter/tests/integration/test_export_worker.py
+++ b/st2exporter/tests/integration/test_export_worker.py
@@ -27,75 +27,92 @@
from st2tests.base import DbTestCase
from st2tests.fixturesloader import FixturesLoader
import st2tests.config as tests_config
+
tests_config.parse_args()
-DESCENDANTS_PACK = 'descendants'
+DESCENDANTS_PACK = "descendants"
DESCENDANTS_FIXTURES = {
- 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml',
- 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml',
- 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml']
+ "executions": [
+ "root_execution.yaml",
+ "child1_level1.yaml",
+ "child2_level1.yaml",
+ "child1_level2.yaml",
+ "child2_level2.yaml",
+ "child3_level2.yaml",
+ "child1_level3.yaml",
+ "child2_level3.yaml",
+ "child3_level3.yaml",
+ ]
}
class TestExportWorker(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(TestExportWorker, cls).setUpClass()
fixtures_loader = FixturesLoader()
- loaded_fixtures = fixtures_loader.save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK,
- fixtures_dict=DESCENDANTS_FIXTURES)
- TestExportWorker.saved_executions = loaded_fixtures['executions']
+ loaded_fixtures = fixtures_loader.save_fixtures_to_db(
+ fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES
+ )
+ TestExportWorker.saved_executions = loaded_fixtures["executions"]
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_get_marker_from_db(self):
marker_dt = date_utils.get_datetime_utc_now() - datetime.timedelta(minutes=5)
- marker_db = DumperMarkerDB(marker=isotime.format(marker_dt, offset=False),
- updated_at=date_utils.get_datetime_utc_now())
+ marker_db = DumperMarkerDB(
+ marker=isotime.format(marker_dt, offset=False),
+ updated_at=date_utils.get_datetime_utc_now(),
+ )
DumperMarker.add_or_update(marker_db)
exec_exporter = ExecutionsExporter(None, None)
export_marker = exec_exporter._get_export_marker_from_db()
self.assertEqual(export_marker, date_utils.add_utc_tz(marker_dt))
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_get_missed_executions_from_db_no_marker(self):
exec_exporter = ExecutionsExporter(None, None)
all_execs = exec_exporter._get_missed_executions_from_db(export_marker=None)
self.assertEqual(len(all_execs), len(self.saved_executions.values()))
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_get_missed_executions_from_db_with_marker(self):
exec_exporter = ExecutionsExporter(None, None)
all_execs = exec_exporter._get_missed_executions_from_db(export_marker=None)
min_timestamp = min([item.end_timestamp for item in all_execs])
marker = min_timestamp + datetime.timedelta(seconds=1)
- execs_greater_than_marker = [item for item in all_execs if item.end_timestamp > marker]
+ execs_greater_than_marker = [
+ item for item in all_execs if item.end_timestamp > marker
+ ]
all_execs = exec_exporter._get_missed_executions_from_db(export_marker=marker)
self.assertTrue(len(all_execs) > 0)
self.assertTrue(len(all_execs) == len(execs_greater_than_marker))
for item in all_execs:
self.assertTrue(item.end_timestamp > marker)
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_bootstrap(self):
exec_exporter = ExecutionsExporter(None, None)
exec_exporter._bootstrap()
- self.assertEqual(exec_exporter.pending_executions.qsize(), len(self.saved_executions))
+ self.assertEqual(
+ exec_exporter.pending_executions.qsize(), len(self.saved_executions)
+ )
count = 0
while count < exec_exporter.pending_executions.qsize():
- self.assertIsInstance(exec_exporter.pending_executions.get(), ActionExecutionAPI)
+ self.assertIsInstance(
+ exec_exporter.pending_executions.get(), ActionExecutionAPI
+ )
count += 1
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_process(self):
some_execution = list(self.saved_executions.values())[5]
exec_exporter = ExecutionsExporter(None, None)
self.assertEqual(exec_exporter.pending_executions.qsize(), 0)
exec_exporter.process(some_execution)
self.assertEqual(exec_exporter.pending_executions.qsize(), 1)
- some_execution.status = 'scheduled'
+ some_execution.status = "scheduled"
exec_exporter.process(some_execution)
self.assertEqual(exec_exporter.pending_executions.qsize(), 1)
diff --git a/st2exporter/tests/unit/test_dumper.py b/st2exporter/tests/unit/test_dumper.py
index 98e42e60f1..0ddec72e3b 100644
--- a/st2exporter/tests/unit/test_dumper.py
+++ b/st2exporter/tests/unit/test_dumper.py
@@ -28,21 +28,30 @@
from st2tests.fixturesloader import FixturesLoader
from st2common.util import date as date_utils
-DESCENDANTS_PACK = 'descendants'
+DESCENDANTS_PACK = "descendants"
DESCENDANTS_FIXTURES = {
- 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml',
- 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml',
- 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml']
+ "executions": [
+ "root_execution.yaml",
+ "child1_level1.yaml",
+ "child2_level1.yaml",
+ "child1_level2.yaml",
+ "child2_level2.yaml",
+ "child3_level2.yaml",
+ "child1_level3.yaml",
+ "child2_level3.yaml",
+ "child3_level3.yaml",
+ ]
}
class TestDumper(EventletTestCase):
fixtures_loader = FixturesLoader()
- loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK,
- fixtures_dict=DESCENDANTS_FIXTURES)
- loaded_executions = loaded_fixtures['executions']
+ loaded_fixtures = fixtures_loader.load_fixtures(
+ fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES
+ )
+ loaded_executions = loaded_fixtures["executions"]
execution_apis = []
for execution in loaded_executions.values():
execution_apis.append(ActionExecutionAPI(**execution))
@@ -54,81 +63,101 @@ def get_queue(self):
executions_queue.put(execution)
return executions_queue
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_get_batch_batch_size_greater_than_actual(self):
executions_queue = self.get_queue()
qsize = executions_queue.qsize()
self.assertTrue(qsize > 0)
- dumper = Dumper(queue=executions_queue, batch_size=2 * qsize,
- export_dir='/tmp')
+ dumper = Dumper(queue=executions_queue, batch_size=2 * qsize, export_dir="/tmp")
batch = dumper._get_batch()
self.assertEqual(len(batch), qsize)
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_get_batch_batch_size_lesser_than_actual(self):
executions_queue = self.get_queue()
qsize = executions_queue.qsize()
self.assertTrue(qsize > 0)
expected_batch_size = int(qsize / 2)
- dumper = Dumper(queue=executions_queue,
- batch_size=expected_batch_size,
- export_dir='/tmp')
+ dumper = Dumper(
+ queue=executions_queue, batch_size=expected_batch_size, export_dir="/tmp"
+ )
batch = dumper._get_batch()
self.assertEqual(len(batch), expected_batch_size)
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_get_file_name(self):
- dumper = Dumper(queue=self.get_queue(),
- export_dir='/tmp',
- file_prefix='st2-stuff-', file_format='json')
+ dumper = Dumper(
+ queue=self.get_queue(),
+ export_dir="/tmp",
+ file_prefix="st2-stuff-",
+ file_format="json",
+ )
file_name = dumper._get_file_name()
- export_date = date_utils.get_datetime_utc_now().strftime('%Y-%m-%d')
- self.assertTrue(file_name.startswith('/tmp/' + export_date + '/st2-stuff-'))
- self.assertTrue(file_name.endswith('json'))
+ export_date = date_utils.get_datetime_utc_now().strftime("%Y-%m-%d")
+ self.assertTrue(file_name.startswith("/tmp/" + export_date + "/st2-stuff-"))
+ self.assertTrue(file_name.endswith("json"))
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_write_to_disk_empty_queue(self):
- dumper = Dumper(queue=queue.Queue(),
- export_dir='/tmp',
- file_prefix='st2-stuff-', file_format='json')
+ dumper = Dumper(
+ queue=queue.Queue(),
+ export_dir="/tmp",
+ file_prefix="st2-stuff-",
+ file_format="json",
+ )
# We just make sure this doesn't blow up.
ret = dumper._write_to_disk()
self.assertEqual(ret, 0)
- @mock.patch.object(TextFileWriter, 'write_text', mock.MagicMock(return_value=True))
- @mock.patch.object(Dumper, '_update_marker', mock.MagicMock(return_value=None))
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
+ @mock.patch.object(TextFileWriter, "write_text", mock.MagicMock(return_value=True))
+ @mock.patch.object(Dumper, "_update_marker", mock.MagicMock(return_value=None))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
def test_write_to_disk(self):
executions_queue = self.get_queue()
max_files_per_sleep = 5
- dumper = Dumper(queue=executions_queue,
- export_dir='/tmp', batch_size=1, max_files_per_sleep=max_files_per_sleep,
- file_prefix='st2-stuff-', file_format='json')
+ dumper = Dumper(
+ queue=executions_queue,
+ export_dir="/tmp",
+ batch_size=1,
+ max_files_per_sleep=max_files_per_sleep,
+ file_prefix="st2-stuff-",
+ file_format="json",
+ )
# We just make sure this doesn't blow up.
ret = dumper._write_to_disk()
self.assertEqual(ret, max_files_per_sleep)
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
- @mock.patch.object(TextFileWriter, 'write_text', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
+ @mock.patch.object(TextFileWriter, "write_text", mock.MagicMock(return_value=True))
def test_start_stop_dumper(self):
executions_queue = self.get_queue()
sleep_interval = 0.01
- dumper = Dumper(queue=executions_queue, sleep_interval=sleep_interval,
- export_dir='/tmp', batch_size=1, max_files_per_sleep=5,
- file_prefix='st2-stuff-', file_format='json')
+ dumper = Dumper(
+ queue=executions_queue,
+ sleep_interval=sleep_interval,
+ export_dir="/tmp",
+ batch_size=1,
+ max_files_per_sleep=5,
+ file_prefix="st2-stuff-",
+ file_format="json",
+ )
dumper.start()
# Call stop after at least one batch was written to disk.
eventlet.sleep(10 * sleep_interval)
dumper.stop()
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
- @mock.patch.object(Dumper, '_write_marker_to_db', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
+ @mock.patch.object(Dumper, "_write_marker_to_db", mock.MagicMock(return_value=True))
def test_update_marker(self):
executions_queue = self.get_queue()
- dumper = Dumper(queue=executions_queue,
- export_dir='/tmp', batch_size=5,
- max_files_per_sleep=1,
- file_prefix='st2-stuff-', file_format='json')
+ dumper = Dumper(
+ queue=executions_queue,
+ export_dir="/tmp",
+ batch_size=5,
+ max_files_per_sleep=1,
+ file_prefix="st2-stuff-",
+ file_format="json",
+ )
# Batch 1
batch = self.execution_apis[0:5]
new_marker = dumper._update_marker(batch)
@@ -145,15 +174,21 @@ def test_update_marker(self):
self.assertEqual(new_marker, max_timestamp)
dumper._write_marker_to_db.assert_called_with(new_marker)
- @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True))
- @mock.patch.object(Dumper, '_write_marker_to_db', mock.MagicMock(return_value=True))
+ @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True))
+ @mock.patch.object(Dumper, "_write_marker_to_db", mock.MagicMock(return_value=True))
def test_update_marker_out_of_order_batch(self):
executions_queue = self.get_queue()
- dumper = Dumper(queue=executions_queue,
- export_dir='/tmp', batch_size=5,
- max_files_per_sleep=1,
- file_prefix='st2-stuff-', file_format='json')
- timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis]
+ dumper = Dumper(
+ queue=executions_queue,
+ export_dir="/tmp",
+ batch_size=5,
+ max_files_per_sleep=1,
+ file_prefix="st2-stuff-",
+ file_format="json",
+ )
+ timestamps = [
+ isotime.parse(execution.end_timestamp) for execution in self.execution_apis
+ ]
max_timestamp = max(timestamps)
# set dumper persisted timestamp to something less than min timestamp in the batch
diff --git a/st2exporter/tests/unit/test_json_converter.py b/st2exporter/tests/unit/test_json_converter.py
index ce2f484bca..07f82a8bf0 100644
--- a/st2exporter/tests/unit/test_json_converter.py
+++ b/st2exporter/tests/unit/test_json_converter.py
@@ -20,34 +20,43 @@
from st2tests.fixturesloader import FixturesLoader
from st2exporter.exporter.json_converter import JsonConverter
-DESCENDANTS_PACK = 'descendants'
+DESCENDANTS_PACK = "descendants"
DESCENDANTS_FIXTURES = {
- 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml',
- 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml',
- 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml']
+ "executions": [
+ "root_execution.yaml",
+ "child1_level1.yaml",
+ "child2_level1.yaml",
+ "child1_level2.yaml",
+ "child2_level2.yaml",
+ "child3_level2.yaml",
+ "child1_level3.yaml",
+ "child2_level3.yaml",
+ "child3_level3.yaml",
+ ]
}
class TestJsonConverter(unittest2.TestCase):
fixtures_loader = FixturesLoader()
- loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK,
- fixtures_dict=DESCENDANTS_FIXTURES)
+ loaded_fixtures = fixtures_loader.load_fixtures(
+ fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES
+ )
def test_convert(self):
- executions_list = list(self.loaded_fixtures['executions'].values())
+ executions_list = list(self.loaded_fixtures["executions"].values())
converter = JsonConverter()
converted_doc = converter.convert(executions_list)
- self.assertTrue(type(converted_doc), 'string')
+ self.assertTrue(type(converted_doc), "string")
reversed_doc = json.loads(converted_doc)
self.assertListEqual(executions_list, reversed_doc)
def test_convert_non_list(self):
- executions_dict = self.loaded_fixtures['executions']
+ executions_dict = self.loaded_fixtures["executions"]
converter = JsonConverter()
try:
converter.convert(executions_dict)
- self.fail('Should have thrown exception.')
+ self.fail("Should have thrown exception.")
except ValueError:
pass
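
Incidentally, the assertTrue(type(converted_doc), "string") call above always passes, since
the second argument is treated as the failure message. The try/fail/except pattern in
test_convert_non_list also predates the context-manager idiom; an equivalent sketch using
the standard assertRaises:

    def test_convert_non_list(self):
        converter = JsonConverter()
        with self.assertRaises(ValueError):
            converter.convert(self.loaded_fixtures["executions"])
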
diff --git a/st2reactor/Makefile b/st2reactor/Makefile
index cd3eb75a3e..232abed4dd 100644
--- a/st2reactor/Makefile
+++ b/st2reactor/Makefile
@@ -7,7 +7,7 @@ VER=0.4.0
COMPONENTS := st2reactor
.PHONY: rpm
-rpm: 
+rpm:
pushd ~ && rpmdev-setuptree && popd
tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS)
cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/
diff --git a/st2reactor/dist_utils.py b/st2reactor/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2reactor/dist_utils.py
+++ b/st2reactor/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
diff --git a/st2reactor/setup.py b/st2reactor/setup.py
index 0379240b8f..adb3e7accc 100644
--- a/st2reactor/setup.py
+++ b/st2reactor/setup.py
@@ -23,9 +23,9 @@
from dist_utils import apply_vagrant_workaround
from st2reactor import __version__
-ST2_COMPONENT = 'st2reactor'
+ST2_COMPONENT = "st2reactor"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
@@ -33,23 +33,25 @@
setup(
name=ST2_COMPONENT,
version=__version__,
- description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="{} StackStorm event-driven automation platform component".format(
+ ST2_COMPONENT
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
+ packages=find_packages(exclude=["setuptools", "tests"]),
scripts=[
- 'bin/st2-rule-tester',
- 'bin/st2-trigger-refire',
- 'bin/st2rulesengine',
- 'bin/st2sensorcontainer',
- 'bin/st2garbagecollector',
- 'bin/st2timersengine',
- ]
+ "bin/st2-rule-tester",
+ "bin/st2-trigger-refire",
+ "bin/st2rulesengine",
+ "bin/st2sensorcontainer",
+ "bin/st2garbagecollector",
+ "bin/st2timersengine",
+ ],
)
diff --git a/st2reactor/st2reactor/__init__.py b/st2reactor/st2reactor/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/st2reactor/st2reactor/__init__.py
+++ b/st2reactor/st2reactor/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/st2reactor/st2reactor/cmd/garbagecollector.py b/st2reactor/st2reactor/cmd/garbagecollector.py
index ab3c64409b..b4be4dfa8b 100644
--- a/st2reactor/st2reactor/cmd/garbagecollector.py
+++ b/st2reactor/st2reactor/cmd/garbagecollector.py
@@ -16,6 +16,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -31,9 +32,7 @@
from st2reactor.garbage_collector import config
from st2reactor.garbage_collector.base import GarbageCollectorService
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOGGER_NAME = get_logger_name_for_module(sys.modules[__name__])
@@ -41,14 +40,17 @@
def _setup():
- capabilities = {
- 'name': 'garbagecollector',
- 'type': 'passive'
- }
- common_setup(service='garbagecollector', config=config, setup_db=True,
- register_mq_exchanges=True, register_signal_handlers=True,
- register_runners=False, service_registry=True,
- capabilities=capabilities)
+ capabilities = {"name": "garbagecollector", "type": "passive"}
+ common_setup(
+ service="garbagecollector",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ register_runners=False,
+ service_registry=True,
+ capabilities=capabilities,
+ )
def _teardown():
@@ -61,13 +63,14 @@ def main():
collection_interval = cfg.CONF.garbagecollector.collection_interval
sleep_delay = cfg.CONF.garbagecollector.sleep_delay
- garbage_collector = GarbageCollectorService(collection_interval=collection_interval,
- sleep_delay=sleep_delay)
+ garbage_collector = GarbageCollectorService(
+ collection_interval=collection_interval, sleep_delay=sleep_delay
+ )
exit_code = garbage_collector.run()
except SystemExit as exit_code:
return exit_code
except:
- LOG.exception('(PID:%s) GarbageCollector quit due to exception.', os.getpid())
+ LOG.exception("(PID:%s) GarbageCollector quit due to exception.", os.getpid())
return FAILURE_EXIT_CODE
finally:
_teardown()
diff --git a/st2reactor/st2reactor/cmd/rule_tester.py b/st2reactor/st2reactor/cmd/rule_tester.py
index 926a27a4ff..b346168cb5 100644
--- a/st2reactor/st2reactor/cmd/rule_tester.py
+++ b/st2reactor/st2reactor/cmd/rule_tester.py
@@ -25,23 +25,27 @@
from st2common.script_setup import teardown as common_teardown
from st2reactor.rules.tester import RuleTester
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOG = logging.getLogger(__name__)
def _register_cli_opts():
cli_opts = [
- cfg.StrOpt('rule', default=None,
- help='Path to the file containing rule definition.'),
- cfg.StrOpt('rule-ref', default=None,
- help='Ref of the rule.'),
- cfg.StrOpt('trigger-instance', default=None,
- help='Path to the file containing trigger instance definition'),
- cfg.StrOpt('trigger-instance-id', default=None,
- help='Id of the Trigger Instance to use for validation.')
+ cfg.StrOpt(
+ "rule", default=None, help="Path to the file containing rule definition."
+ ),
+ cfg.StrOpt("rule-ref", default=None, help="Ref of the rule."),
+ cfg.StrOpt(
+ "trigger-instance",
+ default=None,
+ help="Path to the file containing trigger instance definition",
+ ),
+ cfg.StrOpt(
+ "trigger-instance-id",
+ default=None,
+ help="Id of the Trigger Instance to use for validation.",
+ ),
]
do_register_cli_opts(cli_opts)
@@ -51,17 +55,19 @@ def main():
common_setup(config=config, setup_db=True, register_mq_exchanges=False)
try:
- tester = RuleTester(rule_file_path=cfg.CONF.rule,
- rule_ref=cfg.CONF.rule_ref,
- trigger_instance_file_path=cfg.CONF.trigger_instance,
- trigger_instance_id=cfg.CONF.trigger_instance_id)
+ tester = RuleTester(
+ rule_file_path=cfg.CONF.rule,
+ rule_ref=cfg.CONF.rule_ref,
+ trigger_instance_file_path=cfg.CONF.trigger_instance,
+ trigger_instance_id=cfg.CONF.trigger_instance_id,
+ )
matches = tester.evaluate()
finally:
common_teardown()
if matches:
- LOG.info('=== RULE MATCHES ===')
+ LOG.info("=== RULE MATCHES ===")
sys.exit(0)
else:
- LOG.info('=== RULE DOES NOT MATCH ===')
+ LOG.info("=== RULE DOES NOT MATCH ===")
sys.exit(1)
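
Assuming oslo.config's standard --<opt-name> CLI mapping, a typical invocation of the tester
above is `st2-rule-tester --rule=/path/to/rule.yaml
--trigger-instance=/path/to/trigger_instance.json`, exiting 0 when the rule matches and 1
when it does not.
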
diff --git a/st2reactor/st2reactor/cmd/rulesengine.py b/st2reactor/st2reactor/cmd/rulesengine.py
index f372cc252e..895fbe42d9 100644
--- a/st2reactor/st2reactor/cmd/rulesengine.py
+++ b/st2reactor/st2reactor/cmd/rulesengine.py
@@ -16,6 +16,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -34,13 +35,18 @@
def _setup():
- capabilities = {
- 'name': 'rulesengine',
- 'type': 'passive'
- }
- common_setup(service='rulesengine', config=config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True, register_internal_trigger_types=True,
- register_runners=False, service_registry=True, capabilities=capabilities)
+ capabilities = {"name": "rulesengine", "type": "passive"}
+ common_setup(
+ service="rulesengine",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ register_internal_trigger_types=True,
+ register_runners=False,
+ service_registry=True,
+ capabilities=capabilities,
+ )
def _teardown():
@@ -48,7 +54,7 @@ def _teardown():
def _run_worker():
- LOG.info('(PID=%s) RulesEngine started.', os.getpid())
+ LOG.info("(PID=%s) RulesEngine started.", os.getpid())
rules_engine_worker = worker.get_worker()
@@ -56,10 +62,10 @@ def _run_worker():
rules_engine_worker.start()
return rules_engine_worker.wait()
except (KeyboardInterrupt, SystemExit):
- LOG.info('(PID=%s) RulesEngine stopped.', os.getpid())
+ LOG.info("(PID=%s) RulesEngine stopped.", os.getpid())
rules_engine_worker.shutdown()
except:
- LOG.exception('(PID:%s) RulesEngine quit due to exception.', os.getpid())
+ LOG.exception("(PID:%s) RulesEngine quit due to exception.", os.getpid())
return 1
return 0
@@ -72,7 +78,7 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except:
- LOG.exception('(PID=%s) RulesEngine quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) RulesEngine quit due to exception.", os.getpid())
return 1
finally:
_teardown()
diff --git a/st2reactor/st2reactor/cmd/sensormanager.py b/st2reactor/st2reactor/cmd/sensormanager.py
index df2be8e7ac..f3d27afb5b 100644
--- a/st2reactor/st2reactor/cmd/sensormanager.py
+++ b/st2reactor/st2reactor/cmd/sensormanager.py
@@ -16,6 +16,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -33,9 +34,7 @@
from st2reactor.container.manager import SensorContainerManager
from st2reactor.container.partitioner_lookup import get_sensors_partitioner
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
LOGGER_NAME = get_logger_name_for_module(sys.modules[__name__])
@@ -43,13 +42,17 @@
def _setup():
- capabilities = {
- 'name': 'sensorcontainer',
- 'type': 'passive'
- }
- common_setup(service='sensorcontainer', config=config, setup_db=True,
- register_mq_exchanges=True, register_signal_handlers=True,
- register_runners=False, service_registry=True, capabilities=capabilities)
+ capabilities = {"name": "sensorcontainer", "type": "passive"}
+ common_setup(
+ service="sensorcontainer",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ register_runners=False,
+ service_registry=True,
+ capabilities=capabilities,
+ )
def _teardown():
@@ -60,16 +63,21 @@ def main():
try:
_setup()
- single_sensor_mode = (cfg.CONF.single_sensor_mode or
- cfg.CONF.sensorcontainer.single_sensor_mode)
+ single_sensor_mode = (
+ cfg.CONF.single_sensor_mode or cfg.CONF.sensorcontainer.single_sensor_mode
+ )
if single_sensor_mode and not cfg.CONF.sensor_ref:
- raise ValueError('--sensor-ref argument must be provided when running in single '
- 'sensor mode')
+ raise ValueError(
+ "--sensor-ref argument must be provided when running in single "
+ "sensor mode"
+ )
sensors_partitioner = get_sensors_partitioner()
- container_manager = SensorContainerManager(sensors_partitioner=sensors_partitioner,
- single_sensor_mode=single_sensor_mode)
+ container_manager = SensorContainerManager(
+ sensors_partitioner=sensors_partitioner,
+ single_sensor_mode=single_sensor_mode,
+ )
return container_manager.run_sensors()
except SystemExit as exit_code:
return exit_code
@@ -77,7 +85,7 @@ def main():
LOG.exception(e)
return 1
except:
- LOG.exception('(PID:%s) SensorContainer quit due to exception.', os.getpid())
+ LOG.exception("(PID:%s) SensorContainer quit due to exception.", os.getpid())
return FAILURE_EXIT_CODE
finally:
_teardown()
diff --git a/st2reactor/st2reactor/cmd/timersengine.py b/st2reactor/st2reactor/cmd/timersengine.py
index 0b0cc4b5dd..9b4edd52b5 100644
--- a/st2reactor/st2reactor/cmd/timersengine.py
+++ b/st2reactor/st2reactor/cmd/timersengine.py
@@ -16,6 +16,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -38,12 +39,16 @@
def _setup():
- capabilities = {
- 'name': 'timerengine',
- 'type': 'passive'
- }
- common_setup(service='timer_engine', config=config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True, service_registry=True, capabilities=capabilities)
+ capabilities = {"name": "timerengine", "type": "passive"}
+ common_setup(
+ service="timer_engine",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ service_registry=True,
+ capabilities=capabilities,
+ )
def _teardown():
@@ -55,14 +60,16 @@ def _kickoff_timer(timer):
def _run_worker():
- LOG.info('(PID=%s) TimerEngine started.', os.getpid())
+ LOG.info("(PID=%s) TimerEngine started.", os.getpid())
timer = None
try:
timer_thread = None
if cfg.CONF.timer.enable or cfg.CONF.timersengine.enable:
- local_tz = cfg.CONF.timer.local_timezone or cfg.CONF.timersengine.local_timezone
+ local_tz = (
+ cfg.CONF.timer.local_timezone or cfg.CONF.timersengine.local_timezone
+ )
timer = St2Timer(local_timezone=local_tz)
timer_thread = concurrency.spawn(_kickoff_timer, timer)
LOG.info(TIMER_ENABLED_LOG_LINE)
@@ -70,9 +77,9 @@ def _run_worker():
else:
LOG.info(TIMER_DISABLED_LOG_LINE)
except (KeyboardInterrupt, SystemExit):
- LOG.info('(PID=%s) TimerEngine stopped.', os.getpid())
+ LOG.info("(PID=%s) TimerEngine stopped.", os.getpid())
except:
- LOG.exception('(PID:%s) TimerEngine quit due to exception.', os.getpid())
+ LOG.exception("(PID:%s) TimerEngine quit due to exception.", os.getpid())
return 1
finally:
if timer:
@@ -88,7 +95,7 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except Exception:
- LOG.exception('(PID=%s) TimerEngine quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) TimerEngine quit due to exception.", os.getpid())
return 1
finally:
_teardown()
diff --git a/st2reactor/st2reactor/cmd/trigger_re_fire.py b/st2reactor/st2reactor/cmd/trigger_re_fire.py
index 4f2c8f9ca1..8282a5decf 100644
--- a/st2reactor/st2reactor/cmd/trigger_re_fire.py
+++ b/st2reactor/st2reactor/cmd/trigger_re_fire.py
@@ -27,24 +27,23 @@
from st2common.persistence.trigger import TriggerInstance
from st2common.transport.reactor import TriggerDispatcher
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
CONF = cfg.CONF
def _parse_config():
cli_opts = [
- cfg.BoolOpt('verbose',
- short='v',
- default=False,
- help='Print more verbose output'),
- cfg.StrOpt('trigger-instance-id',
- short='t',
- required=True,
- dest='trigger_instance_id',
- help='Id of trigger instance'),
+ cfg.BoolOpt(
+ "verbose", short="v", default=False, help="Print more verbose output"
+ ),
+ cfg.StrOpt(
+ "trigger-instance-id",
+ short="t",
+ required=True,
+ dest="trigger_instance_id",
+ help="Id of trigger instance",
+ ),
]
CONF.register_cli_opts(cli_opts)
st2cfg.register_opts(ignore_errors=False)
@@ -54,22 +53,17 @@ def _parse_config():
def _setup_logging():
logging_config = {
- 'version': 1,
- 'disable_existing_loggers': False,
- 'formatters': {
- 'default': {
- 'format': '%(asctime)s %(levelname)s %(name)s %(message)s'
- },
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {
+ "default": {"format": "%(asctime)s %(levelname)s %(name)s %(message)s"},
},
- 'handlers': {
- 'console': {
- '()': std_logging.StreamHandler,
- 'formatter': 'default'
- }
+ "handlers": {
+ "console": {"()": std_logging.StreamHandler, "formatter": "default"}
},
- 'root': {
- 'handlers': ['console'],
- 'level': 'DEBUG',
+ "root": {
+ "handlers": ["console"],
+ "level": "DEBUG",
},
}
std_logging.config.dictConfig(logging_config)
@@ -82,8 +76,9 @@ def _setup_db():
def _refire_trigger_instance(trigger_instance_id, log_):
trigger_instance = TriggerInstance.get_by_id(trigger_instance_id)
trigger_dispatcher = TriggerDispatcher(log_)
- trigger_dispatcher.dispatch(trigger=trigger_instance.trigger,
- payload=trigger_instance.payload)
+ trigger_dispatcher.dispatch(
+ trigger=trigger_instance.trigger, payload=trigger_instance.payload
+ )
def main():
@@ -94,7 +89,8 @@ def main():
else:
output = pprint.pprint
_setup_db()
- _refire_trigger_instance(trigger_instance_id=CONF.trigger_instance_id,
- log_=logging.getLogger(__name__))
- output('Trigger re-fired')
+ _refire_trigger_instance(
+ trigger_instance_id=CONF.trigger_instance_id, log_=logging.getLogger(__name__)
+ )
+ output("Trigger re-fired")
db_teardown()
diff --git a/st2reactor/st2reactor/container/hash_partitioner.py b/st2reactor/st2reactor/container/hash_partitioner.py
index 9ed0cb78be..b9e5e46658 100644
--- a/st2reactor/st2reactor/container/hash_partitioner.py
+++ b/st2reactor/st2reactor/container/hash_partitioner.py
@@ -17,25 +17,25 @@
import ctypes
import hashlib
-from st2reactor.container.partitioners import DefaultPartitioner, get_all_enabled_sensors
+from st2reactor.container.partitioners import (
+ DefaultPartitioner,
+ get_all_enabled_sensors,
+)
-__all__ = [
- 'HashPartitioner',
- 'Range'
-]
+__all__ = ["HashPartitioner", "Range"]
# The serialized range expression is of the form `RANGE_START..RANGE_END|RANGE_START..RANGE_END ...`
-SUB_RANGE_SEPARATOR = '|'
-RANGE_BOUNDARY_SEPARATOR = '..'
+SUB_RANGE_SEPARATOR = "|"
+RANGE_BOUNDARY_SEPARATOR = ".."
class Range(object):
- RANGE_MIN_ENUM = 'min'
+ RANGE_MIN_ENUM = "min"
RANGE_MIN_VALUE = 0
- RANGE_MAX_ENUM = 'max'
- RANGE_MAX_VALUE = 2**32
+ RANGE_MAX_ENUM = "max"
+ RANGE_MAX_VALUE = 2 ** 32
def __init__(self, range_repr):
self.range_start, self.range_end = self._get_range_boundaries(range_repr)
@@ -44,15 +44,17 @@ def __contains__(self, item):
return item >= self.range_start and item < self.range_end
def _get_range_boundaries(self, range_repr):
- range_repr = [value.strip() for value in range_repr.split(RANGE_BOUNDARY_SEPARATOR)]
+ range_repr = [
+ value.strip() for value in range_repr.split(RANGE_BOUNDARY_SEPARATOR)
+ ]
if len(range_repr) != 2:
- raise ValueError('Unsupported sub-range format %s.' % range_repr)
+ raise ValueError("Unsupported sub-range format %s." % range_repr)
range_start = self._get_valid_range_boundary(range_repr[0])
range_end = self._get_valid_range_boundary(range_repr[1])
if range_start > range_end:
- raise ValueError('Misconfigured range [%d..%d]' % (range_start, range_end))
+ raise ValueError("Misconfigured range [%d..%d]" % (range_start, range_end))
return (range_start, range_end)
def _get_valid_range_boundary(self, boundary_value):
@@ -73,7 +75,6 @@ def _get_valid_range_boundary(self, boundary_value):
class HashPartitioner(DefaultPartitioner):
-
def __init__(self, sensor_node_name, hash_ranges):
super(HashPartitioner, self).__init__(sensor_node_name=sensor_node_name)
self._hash_ranges = self._create_hash_ranges(hash_ranges)
@@ -112,7 +113,7 @@ def _hash_sensor_ref(self, sensor_ref):
h = ctypes.c_uint(0)
for d in reversed(str(md5_hash_int_repr)):
d = ctypes.c_uint(int(d))
- higherorder = ctypes.c_uint(h.value & 0xf8000000)
+ higherorder = ctypes.c_uint(h.value & 0xF8000000)
h = ctypes.c_uint(h.value << 5)
h = ctypes.c_uint(h.value ^ (higherorder.value >> 27))
h = ctypes.c_uint(h.value ^ d.value)
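
To make the bit folding above concrete, here is a self-contained re-implementation of the rolling hash, assuming md5_hash_int_repr is the integer form of the sensor ref's MD5 digest (the full method is not shown in this hunk, so that detail is an assumption):

import ctypes
import hashlib

def hash_sensor_ref(sensor_ref):
    # Assumption: the integer representation is the MD5 hex digest parsed as base 16.
    md5_hash_int_repr = int(hashlib.md5(sensor_ref.encode("utf-8")).hexdigest(), 16)
    h = ctypes.c_uint(0)
    for d in reversed(str(md5_hash_int_repr)):
        d = ctypes.c_uint(int(d))
        higherorder = ctypes.c_uint(h.value & 0xF8000000)  # top 5 bits
        h = ctypes.c_uint(h.value << 5)                    # c_uint wraps mod 2**32
        h = ctypes.c_uint(h.value ^ (higherorder.value >> 27))
        h = ctypes.c_uint(h.value ^ d.value)
    return h.value

# The value always lands in [RANGE_MIN_VALUE, RANGE_MAX_VALUE), so a node owns a
# sensor when the value falls inside one of its sub-ranges, e.g. "0..2147483648".
assert 0 <= hash_sensor_ref("examples.SampleSensor") < 2 ** 32
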
diff --git a/st2reactor/st2reactor/container/manager.py b/st2reactor/st2reactor/container/manager.py
index 694d3ce337..e9f251aebc 100644
--- a/st2reactor/st2reactor/container/manager.py
+++ b/st2reactor/st2reactor/container/manager.py
@@ -27,16 +27,13 @@
LOG = logging.getLogger(__name__)
-__all__ = [
- 'SensorContainerManager'
-]
+__all__ = ["SensorContainerManager"]
class SensorContainerManager(object):
-
def __init__(self, sensors_partitioner, single_sensor_mode=False):
if not sensors_partitioner:
- raise ValueError('sensors_partitioner should be non-None.')
+ raise ValueError("sensors_partitioner should be non-None.")
self._sensors_partitioner = sensors_partitioner
self._single_sensor_mode = single_sensor_mode
@@ -44,10 +41,12 @@ def __init__(self, sensors_partitioner, single_sensor_mode=False):
self._sensor_container = None
self._container_thread = None
- self._sensors_watcher = SensorWatcher(create_handler=self._handle_create_sensor,
- update_handler=self._handle_update_sensor,
- delete_handler=self._handle_delete_sensor,
- queue_suffix='sensor_container')
+ self._sensors_watcher = SensorWatcher(
+ create_handler=self._handle_create_sensor,
+ update_handler=self._handle_update_sensor,
+ delete_handler=self._handle_delete_sensor,
+ queue_suffix="sensor_container",
+ )
def run_sensors(self):
"""
@@ -55,15 +54,18 @@ def run_sensors(self):
"""
sensors = self._sensors_partitioner.get_sensors()
if sensors:
- LOG.info('Setting up container to run %d sensors.', len(sensors))
- LOG.info('\tSensors list - %s.', [self._get_sensor_ref(sensor) for sensor in sensors])
+ LOG.info("Setting up container to run %d sensors.", len(sensors))
+ LOG.info(
+ "\tSensors list - %s.",
+ [self._get_sensor_ref(sensor) for sensor in sensors],
+ )
sensors_to_run = []
for sensor in sensors:
# TODO: Directly pass DB object to the ProcessContainer
sensors_to_run.append(self._to_sensor_object(sensor))
- LOG.info('(PID:%s) SensorContainer started.', os.getpid())
+ LOG.info("(PID:%s) SensorContainer started.", os.getpid())
self._setup_sigterm_handler()
exit_code = self._spin_container_and_wait(sensors_to_run)
@@ -74,22 +76,25 @@ def _spin_container_and_wait(self, sensors):
try:
self._sensor_container = ProcessSensorContainer(
- sensors=sensors,
- single_sensor_mode=self._single_sensor_mode)
+ sensors=sensors, single_sensor_mode=self._single_sensor_mode
+ )
self._container_thread = concurrency.spawn(self._sensor_container.run)
- LOG.debug('Starting sensor CUD watcher...')
+ LOG.debug("Starting sensor CUD watcher...")
self._sensors_watcher.start()
exit_code = self._container_thread.wait()
- LOG.error('Process container quit with exit_code %d.', exit_code)
- LOG.error('(PID:%s) SensorContainer stopped.', os.getpid())
+ LOG.error("Process container quit with exit_code %d.", exit_code)
+ LOG.error("(PID:%s) SensorContainer stopped.", os.getpid())
except (KeyboardInterrupt, SystemExit):
self._sensor_container.shutdown()
self._sensors_watcher.stop()
- LOG.info('(PID:%s) SensorContainer stopped. Reason - %s', os.getpid(),
- sys.exc_info()[0].__name__)
+ LOG.info(
+ "(PID:%s) SensorContainer stopped. Reason - %s",
+ os.getpid(),
+ sys.exc_info()[0].__name__,
+ )
concurrency.kill(self._container_thread)
self._container_thread = None
@@ -99,7 +104,6 @@ def _spin_container_and_wait(self, sensors):
return exit_code
def _setup_sigterm_handler(self):
-
def sigterm_handler(signum=None, frame=None):
# This will cause SystemExit to be thrown and sensor_container.shutdown()
# will be called there, which cleans things up.
@@ -110,16 +114,16 @@ def sigterm_handler(signum=None, frame=None):
signal.signal(signal.SIGTERM, sigterm_handler)
def _to_sensor_object(self, sensor_db):
- file_path = sensor_db.artifact_uri.replace('file://', '')
- class_name = sensor_db.entry_point.split('.')[-1]
+ file_path = sensor_db.artifact_uri.replace("file://", "")
+ class_name = sensor_db.entry_point.split(".")[-1]
sensor_obj = {
- 'pack': sensor_db.pack,
- 'file_path': file_path,
- 'class_name': class_name,
- 'trigger_types': sensor_db.trigger_types,
- 'poll_interval': sensor_db.poll_interval,
- 'ref': self._get_sensor_ref(sensor_db)
+ "pack": sensor_db.pack,
+ "file_path": file_path,
+ "class_name": class_name,
+ "trigger_types": sensor_db.trigger_types,
+ "poll_interval": sensor_db.poll_interval,
+ "ref": self._get_sensor_ref(sensor_db),
}
return sensor_obj
@@ -130,42 +134,50 @@ def _to_sensor_object(self, sensor_db):
def _handle_create_sensor(self, sensor):
if not self._sensors_partitioner.is_sensor_owner(sensor):
- LOG.info('sensor %s is not supported. Ignoring create.', self._get_sensor_ref(sensor))
+ LOG.info(
+ "sensor %s is not supported. Ignoring create.",
+ self._get_sensor_ref(sensor),
+ )
return
if not sensor.enabled:
- LOG.info('sensor %s is not enabled.', self._get_sensor_ref(sensor))
+ LOG.info("sensor %s is not enabled.", self._get_sensor_ref(sensor))
return
- LOG.info('Adding sensor %s.', self._get_sensor_ref(sensor))
+ LOG.info("Adding sensor %s.", self._get_sensor_ref(sensor))
self._sensor_container.add_sensor(sensor=self._to_sensor_object(sensor))
def _handle_update_sensor(self, sensor):
if not self._sensors_partitioner.is_sensor_owner(sensor):
- LOG.info('sensor %s is not assigned to this partition. Ignoring update. ',
- self._get_sensor_ref(sensor))
+ LOG.info(
+ "sensor %s is not assigned to this partition. Ignoring update. ",
+ self._get_sensor_ref(sensor),
+ )
return
sensor_ref = self._get_sensor_ref(sensor)
sensor_obj = self._to_sensor_object(sensor)
# Handle disabling sensor
if not sensor.enabled:
- LOG.info('Sensor %s disabled. Unloading sensor.', sensor_ref)
+ LOG.info("Sensor %s disabled. Unloading sensor.", sensor_ref)
self._sensor_container.remove_sensor(sensor=sensor_obj)
return
- LOG.info('Sensor %s updated. Reloading sensor.', sensor_ref)
+ LOG.info("Sensor %s updated. Reloading sensor.", sensor_ref)
try:
self._sensor_container.remove_sensor(sensor=sensor_obj)
except:
- LOG.exception('Failed to reload sensor %s', sensor_ref)
+ LOG.exception("Failed to reload sensor %s", sensor_ref)
else:
self._sensor_container.add_sensor(sensor=sensor_obj)
- LOG.info('Sensor %s reloaded.', sensor_ref)
+ LOG.info("Sensor %s reloaded.", sensor_ref)
def _handle_delete_sensor(self, sensor):
if not self._sensors_partitioner.is_sensor_owner(sensor):
- LOG.info('sensor %s is not supported. Ignoring delete.', self._get_sensor_ref(sensor))
+ LOG.info(
+ "sensor %s is not supported. Ignoring delete.",
+ self._get_sensor_ref(sensor),
+ )
return
- LOG.info('Unloading sensor %s.', self._get_sensor_ref(sensor))
+ LOG.info("Unloading sensor %s.", self._get_sensor_ref(sensor))
self._sensor_container.remove_sensor(sensor=self._to_sensor_object(sensor))
def _get_sensor_ref(self, sensor):
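
For orientation, this is the rough shape of the dict _to_sensor_object() hands to the process container; only the keys mirror the code above, every value is hypothetical:

# Illustrative sensor dict; field values are invented, not taken from a real
# database record.
sensor_obj = {
    "pack": "examples",
    "file_path": "/opt/stackstorm/packs/examples/sensors/sample_sensor.py",
    "class_name": "SampleSensor",
    "trigger_types": ["examples.event"],
    "poll_interval": 30,
    "ref": "examples.SampleSensor",  # produced by _get_sensor_ref()
}
assert sensor_obj["file_path"].endswith("sample_sensor.py")
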
diff --git a/st2reactor/st2reactor/container/partitioner_lookup.py b/st2reactor/st2reactor/container/partitioner_lookup.py
index c4f43db6da..1469b3c63c 100644
--- a/st2reactor/st2reactor/container/partitioner_lookup.py
+++ b/st2reactor/st2reactor/container/partitioner_lookup.py
@@ -18,16 +18,22 @@
from oslo_config import cfg
from st2common import log as logging
-from st2common.constants.sensors import DEFAULT_PARTITION_LOADER, KVSTORE_PARTITION_LOADER, \
- FILE_PARTITION_LOADER, HASH_PARTITION_LOADER
+from st2common.constants.sensors import (
+ DEFAULT_PARTITION_LOADER,
+ KVSTORE_PARTITION_LOADER,
+ FILE_PARTITION_LOADER,
+ HASH_PARTITION_LOADER,
+)
from st2common.exceptions.sensors import SensorPartitionerNotSupportedException
-from st2reactor.container.partitioners import DefaultPartitioner, KVStorePartitioner, \
- FileBasedPartitioner, SingleSensorPartitioner
+from st2reactor.container.partitioners import (
+ DefaultPartitioner,
+ KVStorePartitioner,
+ FileBasedPartitioner,
+ SingleSensorPartitioner,
+)
from st2reactor.container.hash_partitioner import HashPartitioner
-__all__ = [
- 'get_sensors_partitioner'
-]
+__all__ = ["get_sensors_partitioner"]
LOG = logging.getLogger(__name__)
@@ -35,25 +41,28 @@
DEFAULT_PARTITION_LOADER: DefaultPartitioner,
KVSTORE_PARTITION_LOADER: KVStorePartitioner,
FILE_PARTITION_LOADER: FileBasedPartitioner,
- HASH_PARTITION_LOADER: HashPartitioner
+ HASH_PARTITION_LOADER: HashPartitioner,
}
def get_sensors_partitioner():
if cfg.CONF.sensor_ref:
- LOG.info('Running in single sensor mode, using a single sensor partitioner...')
+ LOG.info("Running in single sensor mode, using a single sensor partitioner...")
return SingleSensorPartitioner(sensor_ref=cfg.CONF.sensor_ref)
partition_provider_config = copy.copy(cfg.CONF.sensorcontainer.partition_provider)
- partition_provider = partition_provider_config.pop('name')
+ partition_provider = partition_provider_config.pop("name")
sensor_node_name = cfg.CONF.sensorcontainer.sensor_node_name
provider = PROVIDERS.get(partition_provider.lower(), None)
if not provider:
- raise SensorPartitionerNotSupportedException('Partition provider %s not found.' %
- (partition_provider))
+ raise SensorPartitionerNotSupportedException(
+ "Partition provider %s not found." % (partition_provider)
+ )
- LOG.info('Using partitioner %s with sensornode %s.', partition_provider, sensor_node_name)
+ LOG.info(
+ "Using partitioner %s with sensornode %s.", partition_provider, sensor_node_name
+ )
# pass in extra config with no analysis
return provider(sensor_node_name=sensor_node_name, **partition_provider_config)
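
The lookup itself is a small name-to-class dispatch; below is a minimal standalone sketch of the same pattern, using a stand-in provider class instead of the real partitioners:

class FakePartitioner(object):
    # Stand-in for DefaultPartitioner and friends, for illustration only.
    def __init__(self, sensor_node_name, **kwargs):
        self.sensor_node_name = sensor_node_name
        self.extra = kwargs

def lookup_partitioner(sensor_node_name, partition_provider_config, providers):
    config = dict(partition_provider_config)  # copy so pop() can't mutate the original
    provider_cls = providers.get(config.pop("name").lower())
    if provider_cls is None:
        raise ValueError("Partition provider not found.")
    # Remaining config keys pass straight through to the provider constructor.
    return provider_cls(sensor_node_name=sensor_node_name, **config)

p = lookup_partitioner(
    "node-1",
    {"name": "fake", "partition_file": "/tmp/map.yaml"},
    {"fake": FakePartitioner},
)
assert p.extra == {"partition_file": "/tmp/map.yaml"}
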
diff --git a/st2reactor/st2reactor/container/partitioners.py b/st2reactor/st2reactor/container/partitioners.py
index 12a17f9081..02a6d6137b 100644
--- a/st2reactor/st2reactor/container/partitioners.py
+++ b/st2reactor/st2reactor/container/partitioners.py
@@ -18,18 +18,20 @@
import yaml
from st2common import log as logging
-from st2common.exceptions.sensors import SensorNotFoundException, \
- SensorPartitionMapMissingException
+from st2common.exceptions.sensors import (
+ SensorNotFoundException,
+ SensorPartitionMapMissingException,
+)
from st2common.persistence.keyvalue import KeyValuePair
from st2common.persistence.sensor import SensorType
__all__ = [
- 'get_all_enabled_sensors',
- 'DefaultPartitioner',
- 'KVStorePartitioner',
- 'FileBasedPartitioner',
- 'SingleSensorPartitioner'
+ "get_all_enabled_sensors",
+ "DefaultPartitioner",
+ "KVStorePartitioner",
+ "FileBasedPartitioner",
+ "SingleSensorPartitioner",
]
LOG = logging.getLogger(__name__)
@@ -38,12 +40,11 @@
def get_all_enabled_sensors():
# only query for enabled sensors.
sensors = SensorType.query(enabled=True)
- LOG.info('Found %d registered sensors in db scan.', len(sensors))
+ LOG.info("Found %d registered sensors in db scan.", len(sensors))
return sensors
class DefaultPartitioner(object):
-
def __init__(self, sensor_node_name):
self.sensor_node_name = sensor_node_name
@@ -78,7 +79,6 @@ def get_required_sensor_refs(self):
class KVStorePartitioner(DefaultPartitioner):
-
def __init__(self, sensor_node_name):
super(KVStorePartitioner, self).__init__(sensor_node_name=sensor_node_name)
self._supported_sensor_refs = None
@@ -90,46 +90,51 @@ def get_required_sensor_refs(self):
partition_lookup_key = self._get_partition_lookup_key(self.sensor_node_name)
kvp = KeyValuePair.get_by_name(partition_lookup_key)
- sensor_refs_str = kvp.value if kvp.value else ''
- self._supported_sensor_refs = set([
- sensor_ref.strip() for sensor_ref in sensor_refs_str.split(',')])
+ sensor_refs_str = kvp.value if kvp.value else ""
+ self._supported_sensor_refs = set(
+ [sensor_ref.strip() for sensor_ref in sensor_refs_str.split(",")]
+ )
return list(self._supported_sensor_refs)
def _get_partition_lookup_key(self, sensor_node_name):
- return '{}.sensor_partition'.format(sensor_node_name)
+ return "{}.sensor_partition".format(sensor_node_name)
class FileBasedPartitioner(DefaultPartitioner):
-
def __init__(self, sensor_node_name, partition_file):
super(FileBasedPartitioner, self).__init__(sensor_node_name=sensor_node_name)
self.partition_file = partition_file
self._supported_sensor_refs = None
def is_sensor_owner(self, sensor_db):
- return sensor_db.get_reference().ref in self._supported_sensor_refs and sensor_db.enabled
+ return (
+ sensor_db.get_reference().ref in self._supported_sensor_refs
+ and sensor_db.enabled
+ )
def get_required_sensor_refs(self):
- with open(self.partition_file, 'r') as f:
+ with open(self.partition_file, "r") as f:
partition_map = yaml.safe_load(f)
sensor_refs = partition_map.get(self.sensor_node_name, None)
if sensor_refs is None:
- raise SensorPartitionMapMissingException('Sensor partition not found for %s in %s.'
- % (self.sensor_node_name,
- self.partition_file))
+ raise SensorPartitionMapMissingException(
+ "Sensor partition not found for %s in %s."
+ % (self.sensor_node_name, self.partition_file)
+ )
self._supported_sensor_refs = set(sensor_refs)
return list(self._supported_sensor_refs)
class SingleSensorPartitioner(object):
-
def __init__(self, sensor_ref):
self._sensor_ref = sensor_ref
def get_sensors(self):
sensor = SensorType.get_by_ref(self._sensor_ref)
if not sensor:
- raise SensorNotFoundException('Sensor %s not found in db.' % self._sensor_ref)
+ raise SensorNotFoundException(
+ "Sensor %s not found in db." % self._sensor_ref
+ )
return [sensor]
def is_sensor_owner(self, sensor_db):
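
A hypothetical partition map in the format get_required_sensor_refs() expects: one top-level key per sensor node, each listing the sensor refs that node owns (node and sensor names are invented):

import yaml

partition_map = yaml.safe_load(
    """
sensor-node-1:
  - examples.SampleSensor
sensor-node-2:
  - linux.FileWatchSensor
"""
)
supported = set(partition_map["sensor-node-1"])
assert "examples.SampleSensor" in supported
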
diff --git a/st2reactor/st2reactor/container/process_container.py b/st2reactor/st2reactor/container/process_container.py
index f8f1638d71..890bcccbb9 100644
--- a/st2reactor/st2reactor/container/process_container.py
+++ b/st2reactor/st2reactor/container/process_container.py
@@ -31,7 +31,7 @@
from st2common.constants.error_messages import PACK_VIRTUALENV_DOESNT_EXIST
from st2common.constants.system import API_URL_ENV_VARIABLE_NAME
from st2common.constants.system import AUTH_TOKEN_ENV_VARIABLE_NAME
-from st2common.constants.triggers import (SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER)
+from st2common.constants.triggers import SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER
from st2common.constants.exit_codes import SUCCESS_EXIT_CODE
from st2common.constants.exit_codes import FAILURE_EXIT_CODE
from st2common.models.system.common import ResourceReference
@@ -44,14 +44,12 @@
from st2common.util.sandboxing import get_sandbox_python_binary_path
from st2common.util.sandboxing import get_sandbox_virtualenv_path
-__all__ = [
- 'ProcessSensorContainer'
-]
+__all__ = ["ProcessSensorContainer"]
-LOG = logging.getLogger('st2reactor.process_sensor_container')
+LOG = logging.getLogger("st2reactor.process_sensor_container")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-WRAPPER_SCRIPT_NAME = 'sensor_wrapper.py'
+WRAPPER_SCRIPT_NAME = "sensor_wrapper.py"
WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, WRAPPER_SCRIPT_NAME)
# How many times to subsequently try to respawn a sensor after a non-zero exit before giving up
@@ -78,8 +76,15 @@ class ProcessSensorContainer(object):
Sensor container which runs sensors in a separate process.
"""
- def __init__(self, sensors, poll_interval=5, single_sensor_mode=False, dispatcher=None,
- wrapper_script_path=WRAPPER_SCRIPT_PATH, create_token=True):
+ def __init__(
+ self,
+ sensors,
+ poll_interval=5,
+ single_sensor_mode=False,
+ dispatcher=None,
+ wrapper_script_path=WRAPPER_SCRIPT_PATH,
+ create_token=True,
+ ):
"""
:param sensors: A list of sensor dicts.
:type sensors: ``list`` of ``dict``
@@ -119,7 +124,9 @@ def __init__(self, sensors, poll_interval=5, single_sensor_mode=False, dispatche
# Stores information needed for respawning dead sensors
self._sensor_start_times = {} # maps sensor_id -> sensor start time
- self._sensor_respawn_counts = defaultdict(int) # maps sensor_id -> number of respawns
+ self._sensor_respawn_counts = defaultdict(
+ int
+ ) # maps sensor_id -> number of respawns
# A list of all the instance variables which hold internal state information about a
# particular_sensor
@@ -144,10 +151,10 @@ def run(self):
sensor_ids = list(self._sensors.keys())
if len(sensor_ids) >= 1:
- LOG.debug('%d active sensor(s)' % (len(sensor_ids)))
+ LOG.debug("%d active sensor(s)" % (len(sensor_ids)))
self._poll_sensors_for_results(sensor_ids)
else:
- LOG.debug('No active sensors')
+ LOG.debug("No active sensors")
concurrency.sleep(self._poll_interval)
except success_exception_cls:
@@ -157,12 +164,12 @@ def run(self):
self._stopped = True
return SUCCESS_EXIT_CODE
except:
- LOG.exception('Container failed to run sensors.')
+ LOG.exception("Container failed to run sensors.")
self._stopped = True
return FAILURE_EXIT_CODE
self._stopped = True
- LOG.error('Process container stopped.')
+ LOG.error("Process container stopped.")
exit_code = self._exit_code or SUCCESS_EXIT_CODE
return exit_code
@@ -179,23 +186,29 @@ def _poll_sensors_for_results(self, sensor_ids):
if status is not None:
# Dead process detected
- LOG.info('Process for sensor %s has exited with code %s', sensor_id, status)
+ LOG.info(
+ "Process for sensor %s has exited with code %s", sensor_id, status
+ )
sensor = self._sensors[sensor_id]
self._delete_sensor(sensor_id)
- self._dispatch_trigger_for_sensor_exit(sensor=sensor,
- exit_code=status)
+ self._dispatch_trigger_for_sensor_exit(sensor=sensor, exit_code=status)
# Try to respawn a dead process (maybe it was a simple failure which can be
# resolved with a restart)
- concurrency.spawn(self._respawn_sensor, sensor_id=sensor_id, sensor=sensor,
- exit_code=status)
+ concurrency.spawn(
+ self._respawn_sensor,
+ sensor_id=sensor_id,
+ sensor=sensor,
+ exit_code=status,
+ )
else:
sensor_start_time = self._sensor_start_times[sensor_id]
sensor_respawn_count = self._sensor_respawn_counts[sensor_id]
- successfully_started = ((now - sensor_start_time) >=
- SENSOR_SUCCESSFUL_START_THRESHOLD)
+ successfully_started = (
+ now - sensor_start_time
+ ) >= SENSOR_SUCCESSFUL_START_THRESHOLD
if successfully_started and sensor_respawn_count >= 1:
# Sensor has been successfully running more than threshold seconds, clear the
@@ -209,7 +222,7 @@ def stopped(self):
return self._stopped
def shutdown(self, force=False):
- LOG.info('Container shutting down. Invoking cleanup on sensors.')
+ LOG.info("Container shutting down. Invoking cleanup on sensors.")
self._stopped = True
if force:
@@ -221,7 +234,7 @@ def shutdown(self, force=False):
for sensor_id in sensor_ids:
self._stop_sensor_process(sensor_id=sensor_id, exit_timeout=exit_timeout)
- LOG.info('All sensors are shut down.')
+ LOG.info("All sensors are shut down.")
self._sensors = {}
self._processes = {}
@@ -235,11 +248,11 @@ def add_sensor(self, sensor):
sensor_id = self._get_sensor_id(sensor=sensor)
if sensor_id in self._sensors:
- LOG.warning('Sensor %s already exists and running.', sensor_id)
+ LOG.warning("Sensor %s already exists and running.", sensor_id)
return False
self._spawn_sensor_process(sensor=sensor)
- LOG.debug('Sensor %s started.', sensor_id)
+ LOG.debug("Sensor %s started.", sensor_id)
self._sensors[sensor_id] = sensor
return True
@@ -252,11 +265,11 @@ def remove_sensor(self, sensor):
sensor_id = self._get_sensor_id(sensor=sensor)
if sensor_id not in self._sensors:
- LOG.warning('Sensor %s isn\'t running in this container.', sensor_id)
+ LOG.warning("Sensor %s isn't running in this container.", sensor_id)
return False
self._stop_sensor_process(sensor_id=sensor_id)
- LOG.debug('Sensor %s stopped.', sensor_id)
+ LOG.debug("Sensor %s stopped.", sensor_id)
return True
def _run_all_sensors(self):
@@ -264,7 +277,7 @@ def _run_all_sensors(self):
for sensor_id in sensor_ids:
sensor_obj = self._sensors[sensor_id]
- LOG.info('Running sensor %s', sensor_id)
+ LOG.info("Running sensor %s", sensor_id)
try:
self._spawn_sensor_process(sensor=sensor_obj)
@@ -275,7 +288,7 @@ def _run_all_sensors(self):
del self._sensors[sensor_id]
continue
- LOG.info('Sensor %s started' % sensor_id)
+ LOG.info("Sensor %s started" % sensor_id)
def _spawn_sensor_process(self, sensor):
"""
@@ -285,45 +298,53 @@ def _spawn_sensor_process(self, sensor):
belonging to the sensor pack.
"""
sensor_id = self._get_sensor_id(sensor=sensor)
- pack_ref = sensor['pack']
+ pack_ref = sensor["pack"]
virtualenv_path = get_sandbox_virtualenv_path(pack=pack_ref)
python_path = get_sandbox_python_binary_path(pack=pack_ref)
if virtualenv_path and not os.path.isdir(virtualenv_path):
- format_values = {'pack': sensor['pack'], 'virtualenv_path': virtualenv_path}
+ format_values = {"pack": sensor["pack"], "virtualenv_path": virtualenv_path}
msg = PACK_VIRTUALENV_DOESNT_EXIST % format_values
raise Exception(msg)
- args = self._get_args_for_wrapper_script(python_binary=python_path, sensor=sensor)
+ args = self._get_args_for_wrapper_script(
+ python_binary=python_path, sensor=sensor
+ )
if self._enable_common_pack_libs:
- pack_common_libs_path = get_pack_common_libs_path_for_pack_ref(pack_ref=pack_ref)
+ pack_common_libs_path = get_pack_common_libs_path_for_pack_ref(
+ pack_ref=pack_ref
+ )
else:
pack_common_libs_path = None
env = os.environ.copy()
- sandbox_python_path = get_sandbox_python_path(inherit_from_parent=True,
- inherit_parent_virtualenv=True)
+ sandbox_python_path = get_sandbox_python_path(
+ inherit_from_parent=True, inherit_parent_virtualenv=True
+ )
if self._enable_common_pack_libs and pack_common_libs_path:
- env['PYTHONPATH'] = pack_common_libs_path + ':' + sandbox_python_path
+ env["PYTHONPATH"] = pack_common_libs_path + ":" + sandbox_python_path
else:
- env['PYTHONPATH'] = sandbox_python_path
+ env["PYTHONPATH"] = sandbox_python_path
if self._create_token:
# Include full api URL and API token specific to that sensor
- LOG.debug('Creating temporary auth token for sensor %s' % (sensor['class_name']))
+ LOG.debug(
+ "Creating temporary auth token for sensor %s" % (sensor["class_name"])
+ )
ttl = cfg.CONF.auth.service_token_ttl
metadata = {
- 'service': 'sensors_container',
- 'sensor_path': sensor['file_path'],
- 'sensor_class': sensor['class_name']
+ "service": "sensors_container",
+ "sensor_path": sensor["file_path"],
+ "sensor_class": sensor["class_name"],
}
- temporary_token = create_token(username='sensors_container', ttl=ttl, metadata=metadata,
- service=True)
+ temporary_token = create_token(
+ username="sensors_container", ttl=ttl, metadata=metadata, service=True
+ )
env[API_URL_ENV_VARIABLE_NAME] = get_full_public_api_url()
env[AUTH_TOKEN_ENV_VARIABLE_NAME] = temporary_token.token
@@ -332,18 +353,27 @@ def _spawn_sensor_process(self, sensor):
# TODO 2: Store metadata (wrapper process id) with the token and delete
# tokens for old, dead processes on startup
- cmd = ' '.join(args)
+ cmd = " ".join(args)
LOG.debug('Running sensor subprocess (cmd="%s")', cmd)
# TODO: Intercept stdout and stderr for aggregated logging purposes
try:
- process = subprocess.Popen(args=args, stdin=None, stdout=None,
- stderr=None, shell=False, env=env,
- preexec_fn=on_parent_exit('SIGTERM'))
+ process = subprocess.Popen(
+ args=args,
+ stdin=None,
+ stdout=None,
+ stderr=None,
+ shell=False,
+ env=env,
+ preexec_fn=on_parent_exit("SIGTERM"),
+ )
except Exception as e:
- cmd = ' '.join(args)
- message = ('Failed to spawn process for sensor %s ("%s"): %s' %
- (sensor_id, cmd, six.text_type(e)))
+ cmd = " ".join(args)
+ message = 'Failed to spawn process for sensor %s ("%s"): %s' % (
+ sensor_id,
+ cmd,
+ six.text_type(e),
+ )
raise Exception(message)
self._processes[sensor_id] = process
@@ -397,32 +427,35 @@ def _respawn_sensor(self, sensor_id, sensor, exit_code):
"""
Method for respawning a sensor which died with a non-zero exit code.
"""
- extra = {'sensor_id': sensor_id, 'sensor': sensor}
+ extra = {"sensor_id": sensor_id, "sensor": sensor}
if self._single_sensor_mode:
# In single sensor mode we want to exit immediately on failure
- LOG.info('Not respawning a sensor since running in single sensor mode',
- extra=extra)
+ LOG.info(
+ "Not respawning a sensor since running in single sensor mode",
+ extra=extra,
+ )
self._stopped = True
self._exit_code = exit_code
return
if self._stopped:
- LOG.debug('Stopped, not respawning a dead sensor', extra=extra)
+ LOG.debug("Stopped, not respawning a dead sensor", extra=extra)
return
- should_respawn = self._should_respawn_sensor(sensor_id=sensor_id, sensor=sensor,
- exit_code=exit_code)
+ should_respawn = self._should_respawn_sensor(
+ sensor_id=sensor_id, sensor=sensor, exit_code=exit_code
+ )
if not should_respawn:
- LOG.debug('Not respawning a dead sensor', extra=extra)
+ LOG.debug("Not respawning a dead sensor", extra=extra)
return
- LOG.debug('Respawning dead sensor', extra=extra)
+ LOG.debug("Respawning dead sensor", extra=extra)
self._sensor_respawn_counts[sensor_id] += 1
- sleep_delay = (SENSOR_RESPAWN_DELAY * self._sensor_respawn_counts[sensor_id])
+ sleep_delay = SENSOR_RESPAWN_DELAY * self._sensor_respawn_counts[sensor_id]
concurrency.sleep(sleep_delay)
try:
@@ -443,7 +476,7 @@ def _should_respawn_sensor(self, sensor_id, sensor, exit_code):
respawn_count = self._sensor_respawn_counts[sensor_id]
if respawn_count >= SENSOR_MAX_RESPAWN_COUNTS:
- LOG.debug('Sensor has already been respawned max times, giving up')
+ LOG.debug("Sensor has already been respawned max times, giving up")
return False
return True
@@ -460,23 +493,23 @@ def _get_args_for_wrapper_script(self, python_binary, sensor):
:rtype: ``list``
"""
- trigger_type_refs = sensor['trigger_types'] or []
- trigger_type_refs = ','.join(trigger_type_refs)
+ trigger_type_refs = sensor["trigger_types"] or []
+ trigger_type_refs = ",".join(trigger_type_refs)
parent_args = json.dumps(sys.argv[1:])
args = [
python_binary,
self._wrapper_script_path,
- '--pack=%s' % (sensor['pack']),
- '--file-path=%s' % (sensor['file_path']),
- '--class-name=%s' % (sensor['class_name']),
- '--trigger-type-refs=%s' % (trigger_type_refs),
- '--parent-args=%s' % (parent_args)
+ "--pack=%s" % (sensor["pack"]),
+ "--file-path=%s" % (sensor["file_path"]),
+ "--class-name=%s" % (sensor["class_name"]),
+ "--trigger-type-refs=%s" % (trigger_type_refs),
+ "--parent-args=%s" % (parent_args),
]
- if sensor['poll_interval']:
- args.append('--poll-interval=%s' % (sensor['poll_interval']))
+ if sensor["poll_interval"]:
+ args.append("--poll-interval=%s" % (sensor["poll_interval"]))
return args
@@ -486,32 +519,28 @@ def _get_sensor_id(self, sensor):
:type sensor: ``dict``
"""
- sensor_id = sensor['ref']
+ sensor_id = sensor["ref"]
return sensor_id
def _dispatch_trigger_for_sensor_spawn(self, sensor, process, cmd):
trigger = ResourceReference.to_string_reference(
- name=SENSOR_SPAWN_TRIGGER['name'],
- pack=SENSOR_SPAWN_TRIGGER['pack'])
+ name=SENSOR_SPAWN_TRIGGER["name"], pack=SENSOR_SPAWN_TRIGGER["pack"]
+ )
now = int(time.time())
payload = {
- 'id': sensor['class_name'],
- 'timestamp': now,
- 'pid': process.pid,
- 'cmd': cmd
+ "id": sensor["class_name"],
+ "timestamp": now,
+ "pid": process.pid,
+ "cmd": cmd,
}
self._dispatcher.dispatch(trigger, payload=payload)
def _dispatch_trigger_for_sensor_exit(self, sensor, exit_code):
trigger = ResourceReference.to_string_reference(
- name=SENSOR_EXIT_TRIGGER['name'],
- pack=SENSOR_EXIT_TRIGGER['pack'])
+ name=SENSOR_EXIT_TRIGGER["name"], pack=SENSOR_EXIT_TRIGGER["pack"]
+ )
now = int(time.time())
- payload = {
- 'id': sensor['class_name'],
- 'timestamp': now,
- 'exit_code': exit_code
- }
+ payload = {"id": sensor["class_name"], "timestamp": now, "exit_code": exit_code}
self._dispatcher.dispatch(trigger, payload=payload)
def _delete_sensor(self, sensor_id):
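
Putting _get_args_for_wrapper_script() together, the argv for a hypothetical polling sensor would look roughly like this; every path and ref below is invented for the example:

args = [
    "/opt/stackstorm/virtualenvs/examples/bin/python",  # sandbox python for the pack
    "/path/to/st2reactor/container/sensor_wrapper.py",  # WRAPPER_SCRIPT_PATH
    "--pack=examples",
    "--file-path=/opt/stackstorm/packs/examples/sensors/sample_sensor.py",
    "--class-name=SampleSensor",
    "--trigger-type-refs=examples.event",
    "--parent-args=[]",
    "--poll-interval=30",  # only appended when the sensor defines poll_interval
]
cmd = " ".join(args)  # the string logged before subprocess.Popen(...) above
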
diff --git a/st2reactor/st2reactor/container/sensor_wrapper.py b/st2reactor/st2reactor/container/sensor_wrapper.py
index 56a37707d2..c605b47291 100644
--- a/st2reactor/st2reactor/container/sensor_wrapper.py
+++ b/st2reactor/st2reactor/container/sensor_wrapper.py
@@ -25,6 +25,7 @@
# for details.
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -51,10 +52,7 @@
from st2common.services.datastore import SensorDatastoreService
from st2common.util.monkey_patch import use_select_poll_workaround
-__all__ = [
- 'SensorWrapper',
- 'SensorService'
-]
+__all__ = ["SensorWrapper", "SensorService"]
use_select_poll_workaround(nose_only=False)
@@ -69,12 +67,15 @@ def __init__(self, sensor_wrapper):
self._sensor_wrapper = sensor_wrapper
self._logger = self._sensor_wrapper._logger
- self._trigger_dispatcher_service = TriggerDispatcherService(logger=sensor_wrapper._logger)
+ self._trigger_dispatcher_service = TriggerDispatcherService(
+ logger=sensor_wrapper._logger
+ )
self._datastore_service = SensorDatastoreService(
logger=self._logger,
pack_name=self._sensor_wrapper._pack,
class_name=self._sensor_wrapper._class_name,
- api_username='sensor_service')
+ api_username="sensor_service",
+ )
self._client = None
@@ -86,7 +87,7 @@ def get_logger(self, name):
"""
Retrieve an instance of a logger to be used by the sensor class.
"""
- logger_name = '%s.%s' % (self._sensor_wrapper._logger.name, name)
+ logger_name = "%s.%s" % (self._sensor_wrapper._logger.name, name)
logger = logging.getLogger(logger_name)
logger.propagate = True
@@ -105,9 +106,12 @@ def get_user_info(self):
def dispatch(self, trigger, payload=None, trace_tag=None):
# Provided by the parent BaseTriggerDispatcherService class
- return self._trigger_dispatcher_service.dispatch(trigger=trigger, payload=payload,
- trace_tag=trace_tag,
- throw_on_validation_error=False)
+ return self._trigger_dispatcher_service.dispatch(
+ trigger=trigger,
+ payload=payload,
+ trace_tag=trace_tag,
+ throw_on_validation_error=False,
+ )
def dispatch_with_context(self, trigger, payload=None, trace_context=None):
"""
@@ -123,10 +127,12 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None):
:type trace_context: ``st2common.api.models.api.trace.TraceContext``
"""
# Provided by the parent BaseTriggerDispatcherService class
- return self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger,
+ return self._trigger_dispatcher_service.dispatch_with_context(
+ trigger=trigger,
payload=payload,
trace_context=trace_context,
- throw_on_validation_error=False)
+ throw_on_validation_error=False,
+ )
##################################
# Methods for datastore management
@@ -136,20 +142,31 @@ def list_values(self, local=True, prefix=None):
return self.datastore_service.list_values(local=local, prefix=prefix)
def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False):
- return self.datastore_service.get_value(name=name, local=local, scope=scope,
- decrypt=decrypt)
+ return self.datastore_service.get_value(
+ name=name, local=local, scope=scope, decrypt=decrypt
+ )
- def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False):
- return self.datastore_service.set_value(name=name, value=value, ttl=ttl, local=local,
- scope=scope, encrypt=encrypt)
+ def set_value(
+ self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False
+ ):
+ return self.datastore_service.set_value(
+ name=name, value=value, ttl=ttl, local=local, scope=scope, encrypt=encrypt
+ )
def delete_value(self, name, local=True, scope=SYSTEM_SCOPE):
return self.datastore_service.delete_value(name=name, local=local, scope=scope)
class SensorWrapper(object):
- def __init__(self, pack, file_path, class_name, trigger_types,
- poll_interval=None, parent_args=None):
+ def __init__(
+ self,
+ pack,
+ file_path,
+ class_name,
+ trigger_types,
+ poll_interval=None,
+ parent_args=None,
+ ):
"""
:param pack: Name of the pack this sensor belongs to.
:type pack: ``str``
@@ -185,32 +202,48 @@ def __init__(self, pack, file_path, class_name, trigger_types,
pass
# 2. Establish DB connection
- username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None
- password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None
- db_setup_with_retry(cfg.CONF.database.db_name, cfg.CONF.database.host,
- cfg.CONF.database.port, username=username, password=password,
- ssl=cfg.CONF.database.ssl, ssl_keyfile=cfg.CONF.database.ssl_keyfile,
- ssl_certfile=cfg.CONF.database.ssl_certfile,
- ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs,
- ssl_ca_certs=cfg.CONF.database.ssl_ca_certs,
- authentication_mechanism=cfg.CONF.database.authentication_mechanism,
- ssl_match_hostname=cfg.CONF.database.ssl_match_hostname)
+ username = (
+ cfg.CONF.database.username
+ if hasattr(cfg.CONF.database, "username")
+ else None
+ )
+ password = (
+ cfg.CONF.database.password
+ if hasattr(cfg.CONF.database, "password")
+ else None
+ )
+ db_setup_with_retry(
+ cfg.CONF.database.db_name,
+ cfg.CONF.database.host,
+ cfg.CONF.database.port,
+ username=username,
+ password=password,
+ ssl=cfg.CONF.database.ssl,
+ ssl_keyfile=cfg.CONF.database.ssl_keyfile,
+ ssl_certfile=cfg.CONF.database.ssl_certfile,
+ ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs,
+ ssl_ca_certs=cfg.CONF.database.ssl_ca_certs,
+ authentication_mechanism=cfg.CONF.database.authentication_mechanism,
+ ssl_match_hostname=cfg.CONF.database.ssl_match_hostname,
+ )
# 3. Instantiate the watcher
- self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger,
- update_handler=self._handle_update_trigger,
- delete_handler=self._handle_delete_trigger,
- trigger_types=self._trigger_types,
- queue_suffix='sensorwrapper_%s_%s' %
- (self._pack, self._class_name),
- exclusive=True)
+ self._trigger_watcher = TriggerWatcher(
+ create_handler=self._handle_create_trigger,
+ update_handler=self._handle_update_trigger,
+ delete_handler=self._handle_delete_trigger,
+ trigger_types=self._trigger_types,
+ queue_suffix="sensorwrapper_%s_%s" % (self._pack, self._class_name),
+ exclusive=True,
+ )
# 4. Set up logging
- self._logger = logging.getLogger('SensorWrapper.%s.%s' %
- (self._pack, self._class_name))
+ self._logger = logging.getLogger(
+ "SensorWrapper.%s.%s" % (self._pack, self._class_name)
+ )
logging.setup(cfg.CONF.sensorcontainer.logging)
- if '--debug' in parent_args:
+ if "--debug" in parent_args:
set_log_level_for_all_loggers()
else:
# NOTE: statsd logger logs everything by default under INFO so we ignore those log
@@ -223,16 +256,17 @@ def run(self):
atexit.register(self.stop)
self._trigger_watcher.start()
- self._logger.info('Watcher started')
+ self._logger.info("Watcher started")
- self._logger.info('Running sensor initialization code')
+ self._logger.info("Running sensor initialization code")
self._sensor_instance.setup()
if self._poll_interval:
- message = ('Running sensor in active mode (poll interval=%ss)' %
- (self._poll_interval))
+ message = "Running sensor in active mode (poll interval=%ss)" % (
+ self._poll_interval
+ )
else:
- message = 'Running sensor in passive mode'
+ message = "Running sensor in passive mode"
self._logger.info(message)
@@ -240,18 +274,20 @@ def run(self):
self._sensor_instance.run()
except Exception as e:
# Include traceback
- msg = ('Sensor "%s" run method raised an exception: %s.' %
- (self._class_name, six.text_type(e)))
+ msg = 'Sensor "%s" run method raised an exception: %s.' % (
+ self._class_name,
+ six.text_type(e),
+ )
self._logger.warn(msg, exc_info=True)
raise Exception(msg)
def stop(self):
# Stop watcher
- self._logger.info('Stopping trigger watcher')
+ self._logger.info("Stopping trigger watcher")
self._trigger_watcher.stop()
# Run sensor cleanup code
- self._logger.info('Invoking cleanup on sensor')
+ self._logger.info("Invoking cleanup on sensor")
self._sensor_instance.cleanup()
##############################################
@@ -259,16 +295,18 @@ def stop(self):
##############################################
def _handle_create_trigger(self, trigger):
- self._logger.debug('Calling sensor "add_trigger" method (trigger.type=%s)' %
- (trigger.type))
+ self._logger.debug(
+ 'Calling sensor "add_trigger" method (trigger.type=%s)' % (trigger.type)
+ )
self._trigger_names[str(trigger.id)] = trigger
trigger = self._sanitize_trigger(trigger=trigger)
self._sensor_instance.add_trigger(trigger=trigger)
def _handle_update_trigger(self, trigger):
- self._logger.debug('Calling sensor "update_trigger" method (trigger.type=%s)' %
- (trigger.type))
+ self._logger.debug(
+ 'Calling sensor "update_trigger" method (trigger.type=%s)' % (trigger.type)
+ )
self._trigger_names[str(trigger.id)] = trigger
trigger = self._sanitize_trigger(trigger=trigger)
@@ -279,8 +317,9 @@ def _handle_delete_trigger(self, trigger):
if trigger_id not in self._trigger_names:
return
- self._logger.debug('Calling sensor "remove_trigger" method (trigger.type=%s)' %
- (trigger.type))
+ self._logger.debug(
+ 'Calling sensor "remove_trigger" method (trigger.type=%s)' % (trigger.type)
+ )
del self._trigger_names[trigger_id]
trigger = self._sanitize_trigger(trigger=trigger)
@@ -294,35 +333,45 @@ def _get_sensor_instance(self):
module_name, _ = os.path.splitext(filename)
try:
- sensor_class = loader.register_plugin_class(base_class=Sensor,
- file_path=self._file_path,
- class_name=self._class_name)
+ sensor_class = loader.register_plugin_class(
+ base_class=Sensor,
+ file_path=self._file_path,
+ class_name=self._class_name,
+ )
except Exception as e:
tb_msg = traceback.format_exc()
- msg = ('Failed to load sensor class from file "%s" (sensor file most likely doesn\'t '
- 'exist or contains invalid syntax): %s' % (self._file_path, six.text_type(e)))
- msg += '\n\n' + tb_msg
+ msg = (
+ 'Failed to load sensor class from file "%s" (sensor file most likely doesn\'t '
+ "exist or contains invalid syntax): %s"
+ % (self._file_path, six.text_type(e))
+ )
+ msg += "\n\n" + tb_msg
exc_cls = type(e)
raise exc_cls(msg)
if not sensor_class:
- raise ValueError('Sensor module is missing a class with name "%s"' %
- (self._class_name))
+ raise ValueError(
+ 'Sensor module is missing a class with name "%s"' % (self._class_name)
+ )
sensor_class_kwargs = {}
- sensor_class_kwargs['sensor_service'] = SensorService(sensor_wrapper=self)
+ sensor_class_kwargs["sensor_service"] = SensorService(sensor_wrapper=self)
sensor_config = self._get_sensor_config()
- sensor_class_kwargs['config'] = sensor_config
+ sensor_class_kwargs["config"] = sensor_config
if self._poll_interval and issubclass(sensor_class, PollingSensor):
- sensor_class_kwargs['poll_interval'] = self._poll_interval
+ sensor_class_kwargs["poll_interval"] = self._poll_interval
try:
sensor_instance = sensor_class(**sensor_class_kwargs)
except Exception:
- self._logger.exception('Failed to instantiate "%s" sensor class' % (self._class_name))
- raise Exception('Failed to instantiate "%s" sensor class' % (self._class_name))
+ self._logger.exception(
+ 'Failed to instantiate "%s" sensor class' % (self._class_name)
+ )
+ raise Exception(
+ 'Failed to instantiate "%s" sensor class' % (self._class_name)
+ )
return sensor_instance
@@ -342,31 +391,43 @@ def _sanitize_trigger(self, trigger):
return sanitized
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='Sensor runner wrapper')
- parser.add_argument('--pack', required=True,
- help='Name of the pack this sensor belongs to')
- parser.add_argument('--file-path', required=True,
- help='Path to the sensor module')
- parser.add_argument('--class-name', required=True,
- help='Name of the sensor class')
- parser.add_argument('--trigger-type-refs', required=False,
- help='Comma delimited string of trigger type references')
- parser.add_argument('--poll-interval', type=int, default=None, required=False,
- help='Sensor poll interval')
- parser.add_argument('--parent-args', required=False,
- help='Command line arguments passed to the parent process')
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Sensor runner wrapper")
+ parser.add_argument(
+ "--pack", required=True, help="Name of the pack this sensor belongs to"
+ )
+ parser.add_argument("--file-path", required=True, help="Path to the sensor module")
+ parser.add_argument("--class-name", required=True, help="Name of the sensor class")
+ parser.add_argument(
+ "--trigger-type-refs",
+ required=False,
+ help="Comma delimited string of trigger type references",
+ )
+ parser.add_argument(
+ "--poll-interval",
+ type=int,
+ default=None,
+ required=False,
+ help="Sensor poll interval",
+ )
+ parser.add_argument(
+ "--parent-args",
+ required=False,
+ help="Command line arguments passed to the parent process",
+ )
args = parser.parse_args()
trigger_types = args.trigger_type_refs
- trigger_types = trigger_types.split(',') if trigger_types else []
+ trigger_types = trigger_types.split(",") if trigger_types else []
parent_args = json.loads(args.parent_args) if args.parent_args else []
assert isinstance(parent_args, list)
- obj = SensorWrapper(pack=args.pack,
- file_path=args.file_path,
- class_name=args.class_name,
- trigger_types=trigger_types,
- poll_interval=args.poll_interval,
- parent_args=parent_args)
+ obj = SensorWrapper(
+ pack=args.pack,
+ file_path=args.file_path,
+ class_name=args.class_name,
+ trigger_types=trigger_types,
+ poll_interval=args.poll_interval,
+ parent_args=parent_args,
+ )
obj.run()
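
For completeness, a minimal sensor matching the interface this wrapper loads, assuming the standard st2reactor.sensor.base classes; the pack and trigger names are hypothetical:

from st2reactor.sensor.base import PollingSensor


class SampleSensor(PollingSensor):
    """Hypothetical sensor; the wrapper above would load it with
    --class-name=SampleSensor."""

    def setup(self):
        self._logger = self.sensor_service.get_logger(name=self.__class__.__name__)

    def poll(self):
        # dispatch() is the SensorService method shown earlier in this patch
        self.sensor_service.dispatch(trigger="examples.event", payload={"ok": True})

    def cleanup(self):
        pass

    def add_trigger(self, trigger):
        pass

    def update_trigger(self, trigger):
        pass

    def remove_trigger(self, trigger):
        pass
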
diff --git a/st2reactor/st2reactor/container/utils.py b/st2reactor/st2reactor/container/utils.py
index a156d209b0..6b05904627 100644
--- a/st2reactor/st2reactor/container/utils.py
+++ b/st2reactor/st2reactor/container/utils.py
@@ -22,10 +22,12 @@
from st2common.persistence.trigger import TriggerInstance
from st2common.services.triggers import get_trigger_db_by_ref_or_dict
-LOG = logging.getLogger('st2reactor.sensor.container_utils')
+LOG = logging.getLogger("st2reactor.sensor.container_utils")
-def create_trigger_instance(trigger, payload, occurrence_time, raise_on_no_trigger=False):
+def create_trigger_instance(
+ trigger, payload, occurrence_time, raise_on_no_trigger=False
+):
"""
This creates a trigger instance object given trigger and payload.
Trigger can be just a string reference (pack.name) or a ``dict`` containing 'id' or
@@ -40,9 +42,9 @@ def create_trigger_instance(trigger, payload, occurrence_time, raise_on_no_trigg
trigger_db = get_trigger_db_by_ref_or_dict(trigger=trigger)
if not trigger_db:
- LOG.debug('No trigger in db for %s', trigger)
+ LOG.debug("No trigger in db for %s", trigger)
if raise_on_no_trigger:
- raise StackStormDBObjectNotFoundError('Trigger not found for %s' % trigger)
+ raise StackStormDBObjectNotFoundError("Trigger not found for %s" % trigger)
return None
trigger_ref = trigger_db.get_reference().ref
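
The docstring above allows two trigger forms; both are shown below with illustrative values (an actual call also needs a configured database connection, so it is left commented out):

import datetime

string_ref = "examples.event"     # plain "pack.name" reference
dict_ref = {"id": "5f0e0c21abc"}  # dict form carrying the trigger id
occurrence_time = datetime.datetime.utcnow()
# create_trigger_instance(string_ref, {"ok": True}, occurrence_time)
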
diff --git a/st2reactor/st2reactor/garbage_collector/base.py b/st2reactor/st2reactor/garbage_collector/base.py
index 3261458677..bb963e8e51 100644
--- a/st2reactor/st2reactor/garbage_collector/base.py
+++ b/st2reactor/st2reactor/garbage_collector/base.py
@@ -42,16 +42,17 @@
from st2common.garbage_collection.inquiries import purge_inquiries
from st2common.garbage_collection.trigger_instances import purge_trigger_instances
-__all__ = [
- 'GarbageCollectorService'
-]
+__all__ = ["GarbageCollectorService"]
LOG = logging.getLogger(__name__)
class GarbageCollectorService(object):
- def __init__(self, collection_interval=DEFAULT_COLLECTION_INTERVAL,
- sleep_delay=DEFAULT_SLEEP_DELAY):
+ def __init__(
+ self,
+ collection_interval=DEFAULT_COLLECTION_INTERVAL,
+ sleep_delay=DEFAULT_SLEEP_DELAY,
+ ):
"""
:param collection_interval: How often to check database for old data and perform garbage
collection.
@@ -64,7 +65,9 @@ def __init__(self, collection_interval=DEFAULT_COLLECTION_INTERVAL,
self._collection_interval = collection_interval
self._action_executions_ttl = cfg.CONF.garbagecollector.action_executions_ttl
- self._action_executions_output_ttl = cfg.CONF.garbagecollector.action_executions_output_ttl
+ self._action_executions_output_ttl = (
+ cfg.CONF.garbagecollector.action_executions_output_ttl
+ )
self._trigger_instances_ttl = cfg.CONF.garbagecollector.trigger_instances_ttl
self._purge_inquiries = cfg.CONF.garbagecollector.purge_inquiries
self._workflow_execution_max_idle = cfg.CONF.workflow_engine.gc_max_idle_sec
@@ -91,7 +94,7 @@ def run(self):
self._running = False
return SUCCESS_EXIT_CODE
except Exception as e:
- LOG.exception('Exception in the garbage collector: %s' % (six.text_type(e)))
+ LOG.exception("Exception in the garbage collector: %s" % (six.text_type(e)))
self._running = False
return FAILURE_EXIT_CODE
@@ -101,7 +104,7 @@ def _register_signal_handlers(self):
signal.signal(signal.SIGUSR2, self.handle_sigusr2)
def handle_sigusr2(self, signal_number, stack_frame):
- LOG.info('Forcing garbage collection...')
+ LOG.info("Forcing garbage collection...")
self._perform_garbage_collection()
def shutdown(self):
@@ -111,61 +114,88 @@ def _main_loop(self):
while self._running:
self._perform_garbage_collection()
- LOG.info('Sleeping for %s seconds before next garbage collection...' %
- (self._collection_interval))
+ LOG.info(
+ "Sleeping for %s seconds before next garbage collection..."
+ % (self._collection_interval)
+ )
concurrency.sleep(self._collection_interval)
def _validate_ttl_values(self):
"""
Validate that a user has supplied reasonable TTL values.
"""
- if self._action_executions_ttl and self._action_executions_ttl < MINIMUM_TTL_DAYS:
- raise ValueError('Minimum possible TTL for action_executions_ttl in days is %s' %
- (MINIMUM_TTL_DAYS))
-
- if self._trigger_instances_ttl and self._trigger_instances_ttl < MINIMUM_TTL_DAYS:
- raise ValueError('Minimum possible TTL for trigger_instances_ttl in days is %s' %
- (MINIMUM_TTL_DAYS))
-
- if self._action_executions_output_ttl and \
- self._action_executions_output_ttl < MINIMUM_TTL_DAYS_EXECUTION_OUTPUT:
- raise ValueError(('Minimum possible TTL for action_executions_output_ttl in days '
- 'is %s') % (MINIMUM_TTL_DAYS_EXECUTION_OUTPUT))
+ if (
+ self._action_executions_ttl
+ and self._action_executions_ttl < MINIMUM_TTL_DAYS
+ ):
+ raise ValueError(
+ "Minimum possible TTL for action_executions_ttl in days is %s"
+ % (MINIMUM_TTL_DAYS)
+ )
+
+ if (
+ self._trigger_instances_ttl
+ and self._trigger_instances_ttl < MINIMUM_TTL_DAYS
+ ):
+ raise ValueError(
+ "Minimum possible TTL for trigger_instances_ttl in days is %s"
+ % (MINIMUM_TTL_DAYS)
+ )
+
+ if (
+ self._action_executions_output_ttl
+ and self._action_executions_output_ttl < MINIMUM_TTL_DAYS_EXECUTION_OUTPUT
+ ):
+ raise ValueError(
+ (
+ "Minimum possible TTL for action_executions_output_ttl in days "
+ "is %s"
+ )
+ % (MINIMUM_TTL_DAYS_EXECUTION_OUTPUT)
+ )
def _perform_garbage_collection(self):
- LOG.info('Performing garbage collection...')
+ LOG.info("Performing garbage collection...")
proc_message = "Performing garbage collection for %s."
skip_message = "Skipping garbage collection for %s since it's not configured."
# Note: We sleep for a bit between garbage collection of each object type to prevent busy
# waiting
- obj_type = 'action executions'
- if self._action_executions_ttl and self._action_executions_ttl >= MINIMUM_TTL_DAYS:
+ obj_type = "action executions"
+ if (
+ self._action_executions_ttl
+ and self._action_executions_ttl >= MINIMUM_TTL_DAYS
+ ):
LOG.info(proc_message, obj_type)
self._purge_action_executions()
concurrency.sleep(self._sleep_delay)
else:
LOG.debug(skip_message, obj_type)
- obj_type = 'action executions output'
- if self._action_executions_output_ttl and \
- self._action_executions_output_ttl >= MINIMUM_TTL_DAYS_EXECUTION_OUTPUT:
+ obj_type = "action executions output"
+ if (
+ self._action_executions_output_ttl
+ and self._action_executions_output_ttl >= MINIMUM_TTL_DAYS_EXECUTION_OUTPUT
+ ):
LOG.info(proc_message, obj_type)
self._purge_action_executions_output()
concurrency.sleep(self._sleep_delay)
else:
LOG.debug(skip_message, obj_type)
- obj_type = 'trigger instances'
- if self._trigger_instances_ttl and self._trigger_instances_ttl >= MINIMUM_TTL_DAYS:
+ obj_type = "trigger instances"
+ if (
+ self._trigger_instances_ttl
+ and self._trigger_instances_ttl >= MINIMUM_TTL_DAYS
+ ):
LOG.info(proc_message, obj_type)
self._purge_trigger_instances()
concurrency.sleep(self._sleep_delay)
else:
LOG.debug(skip_message, obj_type)
- obj_type = 'inquiries'
+ obj_type = "inquiries"
if self._purge_inquiries:
LOG.info(proc_message, obj_type)
self._timeout_inquiries()
@@ -173,7 +203,7 @@ def _perform_garbage_collection(self):
else:
LOG.debug(skip_message, obj_type)
- obj_type = 'orphaned workflow executions'
+ obj_type = "orphaned workflow executions"
if self._workflow_execution_max_idle > 0:
LOG.info(proc_message, obj_type)
self._purge_orphaned_workflow_executions()
@@ -187,41 +217,53 @@ def _purge_action_executions(self):
the criteria defined in the config.
"""
utc_now = get_datetime_utc_now()
- timestamp = (utc_now - datetime.timedelta(days=self._action_executions_ttl))
+ timestamp = utc_now - datetime.timedelta(days=self._action_executions_ttl)
# Another sanity check to make sure we don't delete new executions
if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS)):
- raise ValueError('Calculated timestamp would violate the minimum TTL constraint')
+ raise ValueError(
+ "Calculated timestamp would violate the minimum TTL constraint"
+ )
timestamp_str = isotime.format(dt=timestamp)
- LOG.info('Deleting action executions older than: %s' % (timestamp_str))
+ LOG.info("Deleting action executions older than: %s" % (timestamp_str))
assert timestamp < utc_now
try:
purge_executions(logger=LOG, timestamp=timestamp)
except Exception as e:
- LOG.exception('Failed to delete executions: %s' % (six.text_type(e)))
+ LOG.exception("Failed to delete executions: %s" % (six.text_type(e)))
return True
def _purge_action_executions_output(self):
utc_now = get_datetime_utc_now()
- timestamp = (utc_now - datetime.timedelta(days=self._action_executions_output_ttl))
+ timestamp = utc_now - datetime.timedelta(
+ days=self._action_executions_output_ttl
+ )
# Another sanity check to make sure we don't delete new objects
- if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS_EXECUTION_OUTPUT)):
- raise ValueError('Calculated timestamp would violate the minimum TTL constraint')
+ if timestamp > (
+ utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS_EXECUTION_OUTPUT)
+ ):
+ raise ValueError(
+ "Calculated timestamp would violate the minimum TTL constraint"
+ )
timestamp_str = isotime.format(dt=timestamp)
- LOG.info('Deleting action executions output objects older than: %s' % (timestamp_str))
+ LOG.info(
+ "Deleting action executions output objects older than: %s" % (timestamp_str)
+ )
assert timestamp < utc_now
try:
purge_execution_output_objects(logger=LOG, timestamp=timestamp)
except Exception as e:
- LOG.exception('Failed to delete execution output objects: %s' % (six.text_type(e)))
+ LOG.exception(
+ "Failed to delete execution output objects: %s" % (six.text_type(e))
+ )
return True
@@ -230,31 +272,32 @@ def _purge_trigger_instances(self):
Purge trigger instances which match the criteria defined in the config.
"""
utc_now = get_datetime_utc_now()
- timestamp = (utc_now - datetime.timedelta(days=self._trigger_instances_ttl))
+ timestamp = utc_now - datetime.timedelta(days=self._trigger_instances_ttl)
# Another sanity check to make sure we don't delete new executions
if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS)):
- raise ValueError('Calculated timestamp would violate the minimum TTL constraint')
+ raise ValueError(
+ "Calculated timestamp would violate the minimum TTL constraint"
+ )
timestamp_str = isotime.format(dt=timestamp)
- LOG.info('Deleting trigger instances older than: %s' % (timestamp_str))
+ LOG.info("Deleting trigger instances older than: %s" % (timestamp_str))
assert timestamp < utc_now
try:
purge_trigger_instances(logger=LOG, timestamp=timestamp)
except Exception as e:
- LOG.exception('Failed to trigger instances: %s' % (six.text_type(e)))
+ LOG.exception("Failed to trigger instances: %s" % (six.text_type(e)))
return True
def _timeout_inquiries(self):
- """Mark Inquiries as "timeout" that have exceeded their TTL
- """
+ """Mark Inquiries as "timeout" that have exceeded their TTL"""
try:
purge_inquiries(logger=LOG)
except Exception as e:
- LOG.exception('Failed to purge inquiries: %s' % (six.text_type(e)))
+ LOG.exception("Failed to purge inquiries: %s" % (six.text_type(e)))
return True
@@ -265,6 +308,8 @@ def _purge_orphaned_workflow_executions(self):
try:
purge_orphaned_workflow_executions(logger=LOG)
except Exception as e:
- LOG.exception('Failed to purge orphaned workflow executions: %s' % (six.text_type(e)))
+ LOG.exception(
+ "Failed to purge orphaned workflow executions: %s" % (six.text_type(e))
+ )
return True
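
A worked example of the purge cutoff arithmetic and the sanity check above; the 7-day minimum here is illustrative, the real constant lives in st2common:

import datetime

MINIMUM_TTL_DAYS = 7        # illustrative; see st2common constants for the real value
action_executions_ttl = 30  # hypothetical operator-configured TTL

utc_now = datetime.datetime.utcnow()
timestamp = utc_now - datetime.timedelta(days=action_executions_ttl)

# The sanity check rejects a cutoff newer than the minimum TTL allows; with a
# 30-day TTL the cutoff is older than the 7-day floor, so purging proceeds.
assert not (timestamp > utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS))
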
diff --git a/st2reactor/st2reactor/garbage_collector/config.py b/st2reactor/st2reactor/garbage_collector/config.py
index 19cf53362e..9a0faf0dec 100644
--- a/st2reactor/st2reactor/garbage_collector/config.py
+++ b/st2reactor/st2reactor/garbage_collector/config.py
@@ -29,8 +29,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
@@ -49,48 +52,62 @@ def _register_common_opts():
def _register_garbage_collector_opts():
logging_opts = [
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.garbagecollector.conf',
- help='Location of the logging configuration file.')
+ "logging",
+ default="/etc/st2/logging.garbagecollector.conf",
+ help="Location of the logging configuration file.",
+ )
]
- CONF.register_opts(logging_opts, group='garbagecollector')
+ CONF.register_opts(logging_opts, group="garbagecollector")
common_opts = [
cfg.IntOpt(
- 'collection_interval', default=DEFAULT_COLLECTION_INTERVAL,
- help='How often to check database for old data and perform garbage collection.'),
+ "collection_interval",
+ default=DEFAULT_COLLECTION_INTERVAL,
+ help="How often to check database for old data and perform garbage collection.",
+ ),
cfg.FloatOpt(
- 'sleep_delay', default=DEFAULT_SLEEP_DELAY,
- help='How long to wait / sleep (in seconds) between '
- 'collection of different object types.')
+ "sleep_delay",
+ default=DEFAULT_SLEEP_DELAY,
+ help="How long to wait / sleep (in seconds) between "
+ "collection of different object types.",
+ ),
]
- CONF.register_opts(common_opts, group='garbagecollector')
+ CONF.register_opts(common_opts, group="garbagecollector")
ttl_opts = [
cfg.IntOpt(
- 'action_executions_ttl', default=None,
- help='Action executions and related objects (live actions, action output '
- 'objects) older than this value (days) will be automatically deleted.'),
+ "action_executions_ttl",
+ default=None,
+ help="Action executions and related objects (live actions, action output "
+ "objects) older than this value (days) will be automatically deleted.",
+ ),
cfg.IntOpt(
- 'action_executions_output_ttl', default=7,
- help='Action execution output objects (ones generated by action output '
- 'streaming) older than this value (days) will be automatically deleted.'),
+ "action_executions_output_ttl",
+ default=7,
+ help="Action execution output objects (ones generated by action output "
+ "streaming) older than this value (days) will be automatically deleted.",
+ ),
cfg.IntOpt(
- 'trigger_instances_ttl', default=None,
- help='Trigger instances older than this value (days) will be automatically deleted.')
+ "trigger_instances_ttl",
+ default=None,
+ help="Trigger instances older than this value (days) will be automatically deleted.",
+ ),
]
- CONF.register_opts(ttl_opts, group='garbagecollector')
+ CONF.register_opts(ttl_opts, group="garbagecollector")
inquiry_opts = [
cfg.BoolOpt(
- 'purge_inquiries', default=False,
- help='Set to True to perform garbage collection on Inquiries (based on '
- 'the TTL value per Inquiry)')
+ "purge_inquiries",
+ default=False,
+ help="Set to True to perform garbage collection on Inquiries (based on "
+ "the TTL value per Inquiry)",
+ )
]
- CONF.register_opts(inquiry_opts, group='garbagecollector')
+ CONF.register_opts(inquiry_opts, group="garbagecollector")
register_opts()
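
A standalone round-trip of the oslo.config registration pattern used throughout this file; the default value below is illustrative rather than the real DEFAULT_COLLECTION_INTERVAL:

from oslo_config import cfg

conf = cfg.ConfigOpts()
conf.register_opts(
    [
        cfg.IntOpt(
            "collection_interval",
            default=600,  # hypothetical default, for the example only
            help="How often to check database for old data.",
        )
    ],
    group="garbagecollector",
)
conf([])  # parse an empty argv so defaults resolve
assert conf.garbagecollector.collection_interval == 600
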
diff --git a/st2reactor/st2reactor/rules/config.py b/st2reactor/st2reactor/rules/config.py
index 004c45b870..637ef4e457 100644
--- a/st2reactor/st2reactor/rules/config.py
+++ b/st2reactor/st2reactor/rules/config.py
@@ -27,8 +27,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
@@ -47,11 +50,13 @@ def _register_common_opts():
def _register_rules_engine_opts():
logging_opts = [
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.rulesengine.conf',
- help='Location of the logging configuration file.')
+ "logging",
+ default="/etc/st2/logging.rulesengine.conf",
+ help="Location of the logging configuration file.",
+ )
]
- CONF.register_opts(logging_opts, group='rulesengine')
+ CONF.register_opts(logging_opts, group="rulesengine")
register_opts()
diff --git a/st2reactor/st2reactor/rules/enforcer.py b/st2reactor/st2reactor/rules/enforcer.py
index 594f157482..4d34b86ce2 100644
--- a/st2reactor/st2reactor/rules/enforcer.py
+++ b/st2reactor/st2reactor/rules/enforcer.py
@@ -40,15 +40,15 @@
from st2common.exceptions import param as param_exc
from st2common.exceptions import apivalidation as validation_exc
-__all__ = [
- 'RuleEnforcer'
-]
+__all__ = ["RuleEnforcer"]
-LOG = logging.getLogger('st2reactor.ruleenforcement.enforce')
+LOG = logging.getLogger("st2reactor.ruleenforcement.enforce")
-EXEC_KICKED_OFF_STATES = [action_constants.LIVEACTION_STATUS_SCHEDULED,
- action_constants.LIVEACTION_STATUS_REQUESTED]
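+# Execution statuses which indicate the request was accepted by the scheduler,
+# i.e. the rule enforcement successfully kicked off an execution.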
+EXEC_KICKED_OFF_STATES = [
+ action_constants.LIVEACTION_STATUS_SCHEDULED,
+ action_constants.LIVEACTION_STATUS_REQUESTED,
+]
class RuleEnforcer(object):
@@ -58,95 +58,117 @@ def __init__(self, trigger_instance, rule):
def get_action_execution_context(self, action_db, trace_context=None):
context = {
- 'trigger_instance': reference.get_ref_from_model(self.trigger_instance),
- 'rule': reference.get_ref_from_model(self.rule),
- 'user': get_system_username(),
- 'pack': action_db.pack,
+ "trigger_instance": reference.get_ref_from_model(self.trigger_instance),
+ "rule": reference.get_ref_from_model(self.rule),
+ "user": get_system_username(),
+ "pack": action_db.pack,
}
if trace_context is not None:
context[TRACE_CONTEXT] = trace_context
# Additional non-action / global context
- additional_context = {
- TRIGGER_PAYLOAD_PREFIX: self.trigger_instance.payload
- }
+ additional_context = {TRIGGER_PAYLOAD_PREFIX: self.trigger_instance.payload}
return context, additional_context
- def get_resolved_parameters(self, action_db, runnertype_db, params, context=None,
- additional_contexts=None):
+ def get_resolved_parameters(
+ self, action_db, runnertype_db, params, context=None, additional_contexts=None
+ ):
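+        # Render any Jinja expressions in the rule's action parameters against the
+        # runner and action parameter definitions plus the supplied contexts
+        # (which include the trigger payload).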
resolved_params = param_utils.render_live_params(
runner_parameters=runnertype_db.runner_parameters,
action_parameters=action_db.parameters,
params=params,
action_context=context,
- additional_contexts=additional_contexts)
+ additional_contexts=additional_contexts,
+ )
return resolved_params
def enforce(self):
- rule_spec = {'ref': self.rule.ref, 'id': str(self.rule.id), 'uid': self.rule.uid}
- enforcement_db = RuleEnforcementDB(trigger_instance_id=str(self.trigger_instance.id),
- rule=rule_spec)
- extra = {
- 'trigger_instance_db': self.trigger_instance,
- 'rule_db': self.rule
+ rule_spec = {
+ "ref": self.rule.ref,
+ "id": str(self.rule.id),
+ "uid": self.rule.uid,
}
+ enforcement_db = RuleEnforcementDB(
+ trigger_instance_id=str(self.trigger_instance.id), rule=rule_spec
+ )
+ extra = {"trigger_instance_db": self.trigger_instance, "rule_db": self.rule}
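+        # A RuleEnforcement record is persisted in the finally block below whether
+        # or not the execution kicks off, so failures stay auditable.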
execution_db = None
try:
execution_db = self._do_enforce()
# pylint: disable=no-member
enforcement_db.execution_id = str(execution_db.id)
enforcement_db.status = RULE_ENFORCEMENT_STATUS_SUCCEEDED
- extra['execution_db'] = execution_db
+ extra["execution_db"] = execution_db
except Exception as e:
# Record the failure reason in the RuleEnforcement.
enforcement_db.status = RULE_ENFORCEMENT_STATUS_FAILED
enforcement_db.failure_reason = six.text_type(e)
- LOG.exception('Failed kicking off execution for rule %s.', self.rule, extra=extra)
+ LOG.exception(
+ "Failed kicking off execution for rule %s.", self.rule, extra=extra
+ )
finally:
self._update_enforcement(enforcement_db)
# pylint: disable=no-member
if not execution_db or execution_db.status not in EXEC_KICKED_OFF_STATES:
- LOG.audit('Rule enforcement failed. Execution of Action %s failed. '
- 'TriggerInstance: %s and Rule: %s',
- self.rule.action.ref, self.trigger_instance, self.rule,
- extra=extra)
+ LOG.audit(
+ "Rule enforcement failed. Execution of Action %s failed. "
+ "TriggerInstance: %s and Rule: %s",
+ self.rule.action.ref,
+ self.trigger_instance,
+ self.rule,
+ extra=extra,
+ )
else:
- LOG.audit('Rule enforced. Execution %s, TriggerInstance %s and Rule %s.',
- execution_db, self.trigger_instance, self.rule, extra=extra)
+ LOG.audit(
+ "Rule enforced. Execution %s, TriggerInstance %s and Rule %s.",
+ execution_db,
+ self.trigger_instance,
+ self.rule,
+ extra=extra,
+ )
return execution_db
def _do_enforce(self):
# TODO: Refactor this to avoid additional lookup in cast_params
- action_ref = self.rule.action['ref']
+ action_ref = self.rule.action["ref"]
# Verify action referenced in the rule exists in the database
action_db = action_utils.get_action_by_ref(action_ref)
if not action_db:
raise ValueError('Action "%s" doesn\'t exist' % (action_ref))
- runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name'])
+ runnertype_db = action_utils.get_runnertype_by_name(
+ action_db.runner_type["name"]
+ )
params = self.rule.action.parameters
- LOG.info('Invoking action %s for trigger_instance %s with params %s.',
- self.rule.action.ref, self.trigger_instance.id,
- json.dumps(params))
+ LOG.info(
+ "Invoking action %s for trigger_instance %s with params %s.",
+ self.rule.action.ref,
+ self.trigger_instance.id,
+ json.dumps(params),
+ )
# update trace before invoking the action.
trace_context = self._update_trace()
- LOG.debug('Updated trace %s with rule %s.', trace_context, self.rule.id)
+ LOG.debug("Updated trace %s with rule %s.", trace_context, self.rule.id)
context, additional_contexts = self.get_action_execution_context(
- action_db=action_db,
- trace_context=trace_context)
+ action_db=action_db, trace_context=trace_context
+ )
- return self._invoke_action(action_db=action_db, runnertype_db=runnertype_db, params=params,
- context=context,
- additional_contexts=additional_contexts)
+ return self._invoke_action(
+ action_db=action_db,
+ runnertype_db=runnertype_db,
+ params=params,
+ context=context,
+ additional_contexts=additional_contexts,
+ )
def _update_trace(self):
"""
@@ -154,9 +176,13 @@ def _update_trace(self):
"""
trace_db = None
try:
- trace_db = trace_service.get_trace_db_by_trigger_instance(self.trigger_instance)
+ trace_db = trace_service.get_trace_db_by_trigger_instance(
+ self.trigger_instance
+ )
except:
- LOG.exception('No Trace found for TriggerInstance %s.', self.trigger_instance.id)
+ LOG.exception(
+ "No Trace found for TriggerInstance %s.", self.trigger_instance.id
+ )
return None
# This would signify some sort of coding error so assert.
@@ -165,19 +191,23 @@ def _update_trace(self):
trace_db = trace_service.add_or_update_given_trace_db(
trace_db=trace_db,
rules=[
- trace_service.get_trace_component_for_rule(self.rule, self.trigger_instance)
- ])
+ trace_service.get_trace_component_for_rule(
+ self.rule, self.trigger_instance
+ )
+ ],
+ )
return vars(TraceContext(id_=str(trace_db.id), trace_tag=trace_db.trace_tag))
def _update_enforcement(self, enforcement_db):
try:
RuleEnforcement.add_or_update(enforcement_db)
except:
- extra = {'enforcement_db': enforcement_db}
- LOG.exception('Failed writing enforcement model to db.', extra=extra)
+ extra = {"enforcement_db": enforcement_db}
+ LOG.exception("Failed writing enforcement model to db.", extra=extra)
- def _invoke_action(self, action_db, runnertype_db, params, context=None,
- additional_contexts=None):
+ def _invoke_action(
+ self, action_db, runnertype_db, params, context=None, additional_contexts=None
+ ):
"""
Schedule an action execution.
@@ -189,9 +219,13 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None,
:rtype: :class:`LiveActionDB` on successful scheduling, None otherwise.
"""
action_ref = action_db.ref
- runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name'])
+ runnertype_db = action_utils.get_runnertype_by_name(
+ action_db.runner_type["name"]
+ )
- liveaction_db = LiveActionDB(action=action_ref, context=context, parameters=params)
+ liveaction_db = LiveActionDB(
+ action=action_ref, context=context, parameters=params
+ )
try:
liveaction_db.parameters = self.get_resolved_parameters(
@@ -199,7 +233,8 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None,
action_db=action_db,
params=liveaction_db.parameters,
context=liveaction_db.context,
- additional_contexts=additional_contexts)
+ additional_contexts=additional_contexts,
+ )
except param_exc.ParamException as e:
# We still need to create a request, so liveaction_db is assigned an ID
liveaction_db, execution_db = action_service.create_request(liveaction_db)
@@ -209,8 +244,11 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None,
action_service.update_status(
liveaction=liveaction_db,
new_status=action_constants.LIVEACTION_STATUS_FAILED,
- result={'error': six.text_type(e),
- 'traceback': ''.join(traceback.format_tb(tb, 20))})
+ result={
+ "error": six.text_type(e),
+ "traceback": "".join(traceback.format_tb(tb, 20)),
+ },
+ )
# Might be a good idea to return the actual ActionExecution rather than bubble up
# the exception.
diff --git a/st2reactor/st2reactor/rules/engine.py b/st2reactor/st2reactor/rules/engine.py
index 1d50d01c9e..453a0457da 100644
--- a/st2reactor/st2reactor/rules/engine.py
+++ b/st2reactor/st2reactor/rules/engine.py
@@ -21,11 +21,9 @@
from st2reactor.rules.matcher import RulesMatcher
from st2common.metrics.base import get_driver
-LOG = logging.getLogger('st2reactor.rules.RulesEngine')
+LOG = logging.getLogger("st2reactor.rules.RulesEngine")
-__all__ = [
- 'RulesEngine'
-]
+__all__ = ["RulesEngine"]
class RulesEngine(object):
@@ -40,7 +38,10 @@ def handle_trigger_instance(self, trigger_instance):
# Enforce the rules.
self.enforce_rules(enforcers)
else:
- LOG.info('No matching rules found for trigger instance %s.', trigger_instance['id'])
+ LOG.info(
+ "No matching rules found for trigger instance %s.",
+ trigger_instance["id"],
+ )
def get_matching_rules_for_trigger(self, trigger_instance):
trigger = trigger_instance.trigger
@@ -48,23 +49,34 @@ def get_matching_rules_for_trigger(self, trigger_instance):
trigger_db = get_trigger_db_by_ref(trigger_instance.trigger)
if not trigger_db:
- LOG.error('No matching trigger found in db for trigger instance %s.', trigger_instance)
+ LOG.error(
+ "No matching trigger found in db for trigger instance %s.",
+ trigger_instance,
+ )
return None
rules = get_rules_given_trigger(trigger=trigger)
- LOG.info('Found %d rules defined for trigger %s', len(rules),
- trigger_db.get_reference().ref)
+ LOG.info(
+ "Found %d rules defined for trigger %s",
+ len(rules),
+ trigger_db.get_reference().ref,
+ )
if len(rules) < 1:
return rules
- matcher = RulesMatcher(trigger_instance=trigger_instance,
- trigger=trigger_db, rules=rules)
+ matcher = RulesMatcher(
+ trigger_instance=trigger_instance, trigger=trigger_db, rules=rules
+ )
matching_rules = matcher.get_matching_rules()
- LOG.info('Matched %s rule(s) for trigger_instance %s (trigger=%s)', len(matching_rules),
- trigger_instance['id'], trigger_db.ref)
+ LOG.info(
+ "Matched %s rule(s) for trigger_instance %s (trigger=%s)",
+ len(matching_rules),
+ trigger_instance["id"],
+ trigger_db.ref,
+ )
return matching_rules
def create_rule_enforcers(self, trigger_instance, matching_rules):
@@ -78,8 +90,8 @@ def create_rule_enforcers(self, trigger_instance, matching_rules):
enforcers = []
for matching_rule in matching_rules:
- metrics_driver.inc_counter('rule.matched')
- metrics_driver.inc_counter('rule.%s.matched' % (matching_rule.ref))
+ metrics_driver.inc_counter("rule.matched")
+ metrics_driver.inc_counter("rule.%s.matched" % (matching_rule.ref))
enforcers.append(RuleEnforcer(trigger_instance, matching_rule))
return enforcers
@@ -89,4 +101,4 @@ def enforce_rules(self, enforcers):
try:
enforcer.enforce() # Should this happen in an eventlet pool?
except:
- LOG.exception('Exception enforcing rule %s.', enforcer.rule)
+ LOG.exception("Exception enforcing rule %s.", enforcer.rule)
diff --git a/st2reactor/st2reactor/rules/filter.py b/st2reactor/st2reactor/rules/filter.py
index 1c67538198..700d072c31 100644
--- a/st2reactor/st2reactor/rules/filter.py
+++ b/st2reactor/st2reactor/rules/filter.py
@@ -31,12 +31,10 @@
from st2common.util.payload import PayloadLookup
from st2common.util.templating import render_template_with_system_context
-__all__ = [
- 'RuleFilter'
-]
+__all__ = ["RuleFilter"]
-LOG = logging.getLogger('st2reactor.ruleenforcement.filter')
+LOG = logging.getLogger("st2reactor.ruleenforcement.filter")
class RuleFilter(object):
@@ -58,9 +56,9 @@ def __init__(self, trigger_instance, trigger, rule, extra_info=False):
# Base context used with a logger
self._base_logger_context = {
- 'rule': self.rule,
- 'trigger': self.trigger,
- 'trigger_instance': self.trigger_instance
+ "rule": self.rule,
+ "trigger": self.trigger,
+ "trigger_instance": self.trigger_instance,
}
def filter(self):
@@ -69,12 +67,18 @@ def filter(self):
:rtype: ``bool``
"""
- LOG.info('Validating rule %s for %s.', self.rule.ref, self.trigger['name'],
- extra=self._base_logger_context)
+ LOG.info(
+ "Validating rule %s for %s.",
+ self.rule.ref,
+ self.trigger["name"],
+ extra=self._base_logger_context,
+ )
if not self.rule.enabled:
if self.extra_info:
- LOG.info('Validation failed for rule %s as it is disabled.', self.rule.ref)
+ LOG.info(
+ "Validation failed for rule %s as it is disabled.", self.rule.ref
+ )
return False
criteria = self.rule.criteria
@@ -85,52 +89,66 @@ def filter(self):
payload_lookup = PayloadLookup(self.trigger_instance.payload)
- LOG.debug('Trigger payload: %s', self.trigger_instance.payload,
- extra=self._base_logger_context)
+ LOG.debug(
+ "Trigger payload: %s",
+ self.trigger_instance.payload,
+ extra=self._base_logger_context,
+ )
for (criterion_k, criterion_v) in six.iteritems(criteria):
- is_rule_applicable, payload_value, criterion_pattern = self._check_criterion(
- criterion_k,
- criterion_v,
- payload_lookup
- )
+ (
+ is_rule_applicable,
+ payload_value,
+ criterion_pattern,
+ ) = self._check_criterion(criterion_k, criterion_v, payload_lookup)
if not is_rule_applicable:
if self.extra_info:
- criteria_extra_info = '\n'.join([
- ' key: %s' % criterion_k,
- ' pattern: %s' % criterion_pattern,
- ' type: %s' % criterion_v['type'],
- ' payload: %s' % payload_value
- ])
- LOG.info('Validation for rule %s failed on criteria -\n%s', self.rule.ref,
- criteria_extra_info,
- extra=self._base_logger_context)
+ criteria_extra_info = "\n".join(
+ [
+ " key: %s" % criterion_k,
+ " pattern: %s" % criterion_pattern,
+ " type: %s" % criterion_v["type"],
+ " payload: %s" % payload_value,
+ ]
+ )
+ LOG.info(
+ "Validation for rule %s failed on criteria -\n%s",
+ self.rule.ref,
+ criteria_extra_info,
+ extra=self._base_logger_context,
+ )
break
if not is_rule_applicable:
- LOG.debug('Rule %s not applicable for %s.', self.rule.id, self.trigger['name'],
- extra=self._base_logger_context)
+ LOG.debug(
+ "Rule %s not applicable for %s.",
+ self.rule.id,
+ self.trigger["name"],
+ extra=self._base_logger_context,
+ )
return is_rule_applicable
def _check_criterion(self, criterion_k, criterion_v, payload_lookup):
- if 'type' not in criterion_v:
+ if "type" not in criterion_v:
# Comparison operator type not specified, can't perform a comparison
return (False, None, None)
- criteria_operator = criterion_v['type']
- criteria_condition = criterion_v.get('condition', None)
- criteria_pattern = criterion_v.get('pattern', None)
+ criteria_operator = criterion_v["type"]
+ criteria_condition = criterion_v.get("condition", None)
+ criteria_pattern = criterion_v.get("pattern", None)
        # Render the pattern (it can contain jinja expressions)
try:
criteria_pattern = self._render_criteria_pattern(
criteria_pattern=criteria_pattern,
- criteria_context=payload_lookup.context
+ criteria_context=payload_lookup.context,
)
except Exception as e:
- msg = ('Failed to render pattern value "%s" for key "%s"' % (criteria_pattern,
- criterion_k))
+ msg = 'Failed to render pattern value "%s" for key "%s"' % (
+ criteria_pattern,
+ criterion_k,
+ )
LOG.exception(msg, extra=self._base_logger_context)
self._create_rule_enforcement(failure_reason=msg, exc=e)
@@ -144,7 +162,7 @@ def _check_criterion(self, criterion_k, criterion_v, payload_lookup):
else:
payload_value = None
except Exception as e:
- msg = ('Failed transforming criteria key %s' % criterion_k)
+ msg = "Failed transforming criteria key %s" % criterion_k
LOG.exception(msg, extra=self._base_logger_context)
self._create_rule_enforcement(failure_reason=msg, exc=e)
@@ -154,13 +172,18 @@ def _check_criterion(self, criterion_k, criterion_v, payload_lookup):
try:
if criteria_operator == criteria_operators.SEARCH:
- result = op_func(value=payload_value, criteria_pattern=criteria_pattern,
- criteria_condition=criteria_condition,
- check_function=self._bool_criterion)
+ result = op_func(
+ value=payload_value,
+ criteria_pattern=criteria_pattern,
+ criteria_condition=criteria_condition,
+ check_function=self._bool_criterion,
+ )
else:
result = op_func(value=payload_value, criteria_pattern=criteria_pattern)
except Exception as e:
- msg = ('There might be a problem with the criteria in rule %s' % (self.rule.ref))
+ msg = "There might be a problem with the criteria in rule %s" % (
+ self.rule.ref
+ )
LOG.exception(msg, extra=self._base_logger_context)
self._create_rule_enforcement(failure_reason=msg, exc=e)
@@ -185,9 +208,9 @@ def _render_criteria_pattern(self, criteria_pattern, criteria_context):
return criteria_pattern
LOG.debug(
- 'Rendering criteria pattern (%s) with context: %s',
+ "Rendering criteria pattern (%s) with context: %s",
criteria_pattern,
- criteria_context
+ criteria_context,
)
to_complex = False
@@ -197,30 +220,24 @@ def _render_criteria_pattern(self, criteria_pattern, criteria_context):
if len(re.findall(MATCH_CRITERIA, criteria_pattern)) > 0:
LOG.debug("Rendering Complex")
complex_criteria_pattern = re.sub(
- MATCH_CRITERIA, r'\1\2 | to_complex\3',
- criteria_pattern
+ MATCH_CRITERIA, r"\1\2 | to_complex\3", criteria_pattern
)
try:
criteria_rendered = render_template_with_system_context(
- value=complex_criteria_pattern,
- context=criteria_context
+ value=complex_criteria_pattern, context=criteria_context
)
criteria_rendered = json.loads(criteria_rendered)
to_complex = True
except ValueError as error:
- LOG.debug('Criteria pattern not valid JSON: %s', error)
+ LOG.debug("Criteria pattern not valid JSON: %s", error)
if not to_complex:
criteria_rendered = render_template_with_system_context(
- value=criteria_pattern,
- context=criteria_context
+ value=criteria_pattern, context=criteria_context
)
- LOG.debug(
- 'Rendered criteria pattern: %s',
- criteria_rendered
- )
+ LOG.debug("Rendered criteria pattern: %s", criteria_rendered)
return criteria_rendered
@@ -231,19 +248,32 @@ def _create_rule_enforcement(self, failure_reason, exc):
        Without that, the only way for users to find out about those failed matches is by
        inspecting the logs.
"""
- failure_reason = ('Failed to match rule "%s" against trigger instance "%s": %s: %s' %
- (self.rule.ref, str(self.trigger_instance.id), failure_reason, str(exc)))
- rule_spec = {'ref': self.rule.ref, 'id': str(self.rule.id), 'uid': self.rule.uid}
- enforcement_db = RuleEnforcementDB(trigger_instance_id=str(self.trigger_instance.id),
- rule=rule_spec,
- failure_reason=failure_reason,
- status=RULE_ENFORCEMENT_STATUS_FAILED)
+ failure_reason = (
+ 'Failed to match rule "%s" against trigger instance "%s": %s: %s'
+ % (
+ self.rule.ref,
+ str(self.trigger_instance.id),
+ failure_reason,
+ str(exc),
+ )
+ )
+ rule_spec = {
+ "ref": self.rule.ref,
+ "id": str(self.rule.id),
+ "uid": self.rule.uid,
+ }
+ enforcement_db = RuleEnforcementDB(
+ trigger_instance_id=str(self.trigger_instance.id),
+ rule=rule_spec,
+ failure_reason=failure_reason,
+ status=RULE_ENFORCEMENT_STATUS_FAILED,
+ )
try:
RuleEnforcement.add_or_update(enforcement_db)
except:
- extra = {'enforcement_db': enforcement_db}
- LOG.exception('Failed writing enforcement model to db.', extra=extra)
+ extra = {"enforcement_db": enforcement_db}
+ LOG.exception("Failed writing enforcement model to db.", extra=extra)
return enforcement_db
@@ -253,6 +283,7 @@ class SecondPassRuleFilter(RuleFilter):
    Special filter that handles all second pass rules. For now these are only
    backstop rules, i.e. those that can match when no other rule has matched.
"""
+
def __init__(self, trigger_instance, trigger, rule, first_pass_matched):
"""
:param trigger_instance: TriggerInstance DB object.
@@ -277,4 +308,4 @@ def filter(self):
return super(SecondPassRuleFilter, self).filter()
def _is_backstop_rule(self):
- return self.rule.type['ref'] == RULE_TYPE_BACKSTOP
+ return self.rule.type["ref"] == RULE_TYPE_BACKSTOP
diff --git a/st2reactor/st2reactor/rules/matcher.py b/st2reactor/st2reactor/rules/matcher.py
index b2ed198945..4b3a8a2483 100644
--- a/st2reactor/st2reactor/rules/matcher.py
+++ b/st2reactor/st2reactor/rules/matcher.py
@@ -18,7 +18,7 @@
from st2common.constants.rules import RULE_TYPE_BACKSTOP
from st2reactor.rules.filter import RuleFilter, SecondPassRuleFilter
-LOG = logging.getLogger('st2reactor.rules.RulesMatcher')
+LOG = logging.getLogger("st2reactor.rules.RulesMatcher")
class RulesMatcher(object):
@@ -31,25 +31,44 @@ def __init__(self, trigger_instance, trigger, rules, extra_info=False):
def get_matching_rules(self):
first_pass, second_pass = self._split_rules_into_passes()
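+        # Backstop rules are split into a second pass so that they only match when
+        # no rule matched in the first pass.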
# first pass
- rule_filters = [RuleFilter(trigger_instance=self.trigger_instance,
- trigger=self.trigger,
- rule=rule,
- extra_info=self.extra_info)
- for rule in first_pass]
- matched_rules = [rule_filter.rule for rule_filter in rule_filters if rule_filter.filter()]
- LOG.debug('[1st_pass] %d rule(s) found to enforce for %s.', len(matched_rules),
- self.trigger['name'])
+ rule_filters = [
+ RuleFilter(
+ trigger_instance=self.trigger_instance,
+ trigger=self.trigger,
+ rule=rule,
+ extra_info=self.extra_info,
+ )
+ for rule in first_pass
+ ]
+ matched_rules = [
+ rule_filter.rule for rule_filter in rule_filters if rule_filter.filter()
+ ]
+ LOG.debug(
+ "[1st_pass] %d rule(s) found to enforce for %s.",
+ len(matched_rules),
+ self.trigger["name"],
+ )
# second pass
- rule_filters = [SecondPassRuleFilter(self.trigger_instance, self.trigger, rule,
- matched_rules)
- for rule in second_pass]
- matched_in_second_pass = [rule_filter.rule for rule_filter in rule_filters
- if rule_filter.filter()]
- LOG.debug('[2nd_pass] %d rule(s) found to enforce for %s.', len(matched_in_second_pass),
- self.trigger['name'])
+ rule_filters = [
+ SecondPassRuleFilter(
+ self.trigger_instance, self.trigger, rule, matched_rules
+ )
+ for rule in second_pass
+ ]
+ matched_in_second_pass = [
+ rule_filter.rule for rule_filter in rule_filters if rule_filter.filter()
+ ]
+ LOG.debug(
+ "[2nd_pass] %d rule(s) found to enforce for %s.",
+ len(matched_in_second_pass),
+ self.trigger["name"],
+ )
matched_rules.extend(matched_in_second_pass)
- LOG.info('%d rule(s) found to enforce for %s.', len(matched_rules),
- self.trigger['name'])
+ LOG.info(
+ "%d rule(s) found to enforce for %s.",
+ len(matched_rules),
+ self.trigger["name"],
+ )
return matched_rules
def _split_rules_into_passes(self):
@@ -68,4 +87,4 @@ def _split_rules_into_passes(self):
return first_pass, second_pass
def _is_first_pass_rule(self, rule):
- return rule.type['ref'] != RULE_TYPE_BACKSTOP
+ return rule.type["ref"] != RULE_TYPE_BACKSTOP
diff --git a/st2reactor/st2reactor/rules/tester.py b/st2reactor/st2reactor/rules/tester.py
index 790148d82d..da4e3572c5 100644
--- a/st2reactor/st2reactor/rules/tester.py
+++ b/st2reactor/st2reactor/rules/tester.py
@@ -32,16 +32,19 @@
from st2reactor.rules.enforcer import RuleEnforcer
from st2reactor.rules.matcher import RulesMatcher
-__all__ = [
- 'RuleTester'
-]
+__all__ = ["RuleTester"]
LOG = logging.getLogger(__name__)
class RuleTester(object):
- def __init__(self, rule_file_path=None, rule_ref=None, trigger_instance_file_path=None,
- trigger_instance_id=None):
+ def __init__(
+ self,
+ rule_file_path=None,
+ rule_ref=None,
+ trigger_instance_file_path=None,
+ trigger_instance_id=None,
+ ):
"""
:param rule_file_path: Path to the file containing rule definition.
:type rule_file_path: ``str``
@@ -69,13 +72,20 @@ def evaluate(self):
# The trigger check needs to be performed here as that is not performed
# by RulesMatcher.
if rule_db.trigger != trigger_db.ref:
- LOG.info('rule.trigger "%s" and trigger.ref "%s" do not match.',
- rule_db.trigger, trigger_db.ref)
+ LOG.info(
+ 'rule.trigger "%s" and trigger.ref "%s" do not match.',
+ rule_db.trigger,
+ trigger_db.ref,
+ )
return False
# Check if rule matches criteria.
- matcher = RulesMatcher(trigger_instance=trigger_instance_db, trigger=trigger_db,
- rules=[rule_db], extra_info=True)
+ matcher = RulesMatcher(
+ trigger_instance=trigger_instance_db,
+ trigger=trigger_db,
+ rules=[rule_db],
+ extra_info=True,
+ )
matching_rules = matcher.get_matching_rules()
# Rule does not match so early exit.
@@ -91,69 +101,86 @@ def evaluate(self):
action_db.parameters = {}
params = rule_db.action.parameters # pylint: disable=no-member
- context, additional_contexts = enforcer.get_action_execution_context(action_db=action_db,
- trace_context=None)
+ context, additional_contexts = enforcer.get_action_execution_context(
+ action_db=action_db, trace_context=None
+ )
# Note: We only return partially resolved parameters.
        # To be able to return all parameters we would need access to the corresponding ActionDB,
        # RunnerTypeDB and ConfigDB objects, but this would add a dependency on the database, and
        # the tool is meant to be used standalone.
try:
- params = enforcer.get_resolved_parameters(action_db=action_db,
- runnertype_db=runner_type_db,
- params=params,
- context=context,
- additional_contexts=additional_contexts)
-
- LOG.info('Action parameters resolved to:')
+ params = enforcer.get_resolved_parameters(
+ action_db=action_db,
+ runnertype_db=runner_type_db,
+ params=params,
+ context=context,
+ additional_contexts=additional_contexts,
+ )
+
+ LOG.info("Action parameters resolved to:")
for param in six.iteritems(params):
- LOG.info('\t%s: %s', param[0], param[1])
+ LOG.info("\t%s: %s", param[0], param[1])
return True
except (UndefinedError, ValueError) as e:
- LOG.error('Failed to resolve parameters\n\tOriginal error : %s', six.text_type(e))
+ LOG.error(
+ "Failed to resolve parameters\n\tOriginal error : %s", six.text_type(e)
+ )
return False
except:
- LOG.exception('Failed to resolve parameters.')
+ LOG.exception("Failed to resolve parameters.")
return False
def _get_rule_db(self):
if self._rule_file_path:
return self._get_rule_db_from_file(
- file_path=os.path.realpath(self._rule_file_path))
+ file_path=os.path.realpath(self._rule_file_path)
+ )
elif self._rule_ref:
return Rule.get_by_ref(self._rule_ref)
- raise ValueError('One of _rule_file_path or _rule_ref should be specified.')
+ raise ValueError("One of _rule_file_path or _rule_ref should be specified.")
def _get_trigger_instance_db(self):
if self._trigger_instance_file_path:
return self._get_trigger_instance_db_from_file(
- file_path=os.path.realpath(self._trigger_instance_file_path))
+ file_path=os.path.realpath(self._trigger_instance_file_path)
+ )
elif self._trigger_instance_id:
trigger_instance_db = TriggerInstance.get_by_id(self._trigger_instance_id)
trigger_db = Trigger.get_by_ref(trigger_instance_db.trigger)
return trigger_instance_db, trigger_db
- raise ValueError('One of _trigger_instance_file_path or'
- '_trigger_instance_id should be specified.')
+ raise ValueError(
+            "One of _trigger_instance_file_path or "
+ "_trigger_instance_id should be specified."
+ )
def _get_rule_db_from_file(self, file_path):
data = self._meta_loader.load(file_path=file_path)
- pack = data.get('pack', 'unknown')
- name = data.get('name', 'unknown')
- trigger = data['trigger']['type']
- criteria = data.get('criteria', None)
- action = data.get('action', {})
-
- rule_db = RuleDB(pack=pack, name=name, trigger=trigger, criteria=criteria, action=action,
- enabled=True)
- rule_db.id = 'rule_tester_rule'
+ pack = data.get("pack", "unknown")
+ name = data.get("name", "unknown")
+ trigger = data["trigger"]["type"]
+ criteria = data.get("criteria", None)
+ action = data.get("action", {})
+
+ rule_db = RuleDB(
+ pack=pack,
+ name=name,
+ trigger=trigger,
+ criteria=criteria,
+ action=action,
+ enabled=True,
+ )
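+        # Use a synthetic id since this rule never comes from (or goes to) the database.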
+ rule_db.id = "rule_tester_rule"
return rule_db
def _get_trigger_instance_db_from_file(self, file_path):
data = self._meta_loader.load(file_path=file_path)
instance = TriggerInstanceDB(**data)
- instance.id = 'rule_tester_instance'
+ instance.id = "rule_tester_instance"
- trigger_ref = ResourceReference.from_string_reference(instance['trigger'])
- trigger_db = TriggerDB(pack=trigger_ref.pack, name=trigger_ref.name, type=trigger_ref.ref)
+ trigger_ref = ResourceReference.from_string_reference(instance["trigger"])
+ trigger_db = TriggerDB(
+ pack=trigger_ref.pack, name=trigger_ref.name, type=trigger_ref.ref
+ )
return instance, trigger_db
diff --git a/st2reactor/st2reactor/rules/worker.py b/st2reactor/st2reactor/rules/worker.py
index 7dbe4a59e1..53e636a346 100644
--- a/st2reactor/st2reactor/rules/worker.py
+++ b/st2reactor/st2reactor/rules/worker.py
@@ -41,12 +41,12 @@ def __init__(self, connection, queues):
self.rules_engine = RulesEngine()
def pre_ack_process(self, message):
- '''
+ """
        TriggerInstance from the message is created prior to acknowledging the message.
        This gives us a way to not acknowledge messages.
- '''
- trigger = message['trigger']
- payload = message['payload']
+ """
+ trigger = message["trigger"]
+ payload = message["payload"]
        # Account for not being able to create a TriggerInstance if a TriggerDB
        # is not found.
@@ -54,16 +54,19 @@ def pre_ack_process(self, message):
trigger,
payload or {},
date_utils.get_datetime_utc_now(),
- raise_on_no_trigger=True)
+ raise_on_no_trigger=True,
+ )
return self._compose_pre_ack_process_response(trigger_instance, message)
def process(self, pre_ack_response):
- trigger_instance, message = self._decompose_pre_ack_process_response(pre_ack_response)
+ trigger_instance, message = self._decompose_pre_ack_process_response(
+ pre_ack_response
+ )
if not trigger_instance:
- raise ValueError('No trigger_instance provided for processing.')
+ raise ValueError("No trigger_instance provided for processing.")
- get_driver().inc_counter('trigger.%s.processed' % (trigger_instance.trigger))
+ get_driver().inc_counter("trigger.%s.processed" % (trigger_instance.trigger))
try:
# Use trace_context from the message and if not found create a new context
@@ -71,34 +74,39 @@ def process(self, pre_ack_response):
trace_context = message.get(TRACE_CONTEXT, None)
if not trace_context:
trace_context = {
- TRACE_ID: 'trigger_instance-%s' % str(trigger_instance.id)
+ TRACE_ID: "trigger_instance-%s" % str(trigger_instance.id)
}
# add a trace or update an existing trace with trigger_instance
trace_service.add_or_update_given_trace_context(
trace_context=trace_context,
trigger_instances=[
- trace_service.get_trace_component_for_trigger_instance(trigger_instance)
- ]
+ trace_service.get_trace_component_for_trigger_instance(
+ trigger_instance
+ )
+ ],
)
container_utils.update_trigger_instance_status(
- trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING)
+ trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING
+ )
- with CounterWithTimer(key='rule.processed'):
- with Timer(key='trigger.%s.processed' % (trigger_instance.trigger)):
+ with CounterWithTimer(key="rule.processed"):
+ with Timer(key="trigger.%s.processed" % (trigger_instance.trigger)):
self.rules_engine.handle_trigger_instance(trigger_instance)
container_utils.update_trigger_instance_status(
- trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSED)
+ trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSED
+ )
except:
# TODO : Capture the reason for failure.
container_utils.update_trigger_instance_status(
- trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING_FAILED)
+ trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING_FAILED
+ )
# This could be a large message but at least in case of an exception
# we get to see more context.
            # Beyond this point the code cannot really handle the exception anyway, so
            # we swallow it.
- LOG.exception('Failed to handle trigger_instance %s.', trigger_instance)
+ LOG.exception("Failed to handle trigger_instance %s.", trigger_instance)
return
@staticmethod
@@ -106,14 +114,14 @@ def _compose_pre_ack_process_response(trigger_instance, message):
"""
Codify response of the pre_ack_process method.
"""
- return {'trigger_instance': trigger_instance, 'message': message}
+ return {"trigger_instance": trigger_instance, "message": message}
@staticmethod
def _decompose_pre_ack_process_response(response):
"""
        Break down the response of pre_ack_process into its constituents for simpler consumption.
"""
- return response.get('trigger_instance', None), response.get('message', None)
+ return response.get("trigger_instance", None), response.get("message", None)
def get_worker():
diff --git a/st2reactor/st2reactor/sensor/base.py b/st2reactor/st2reactor/sensor/base.py
index f7fce2460b..a8309ba292 100644
--- a/st2reactor/st2reactor/sensor/base.py
+++ b/st2reactor/st2reactor/sensor/base.py
@@ -21,10 +21,7 @@
from st2common.util import concurrency
-__all__ = [
- 'Sensor',
- 'PollingSensor'
-]
+__all__ = ["Sensor", "PollingSensor"]
@six.add_metaclass(abc.ABCMeta)
@@ -107,7 +104,9 @@ class PollingSensor(BaseSensor):
"""
def __init__(self, sensor_service, config=None, poll_interval=5):
- super(PollingSensor, self).__init__(sensor_service=sensor_service, config=config)
+ super(PollingSensor, self).__init__(
+ sensor_service=sensor_service, config=config
+ )
self._poll_interval = poll_interval
@abc.abstractmethod
diff --git a/st2reactor/st2reactor/sensor/config.py b/st2reactor/st2reactor/sensor/config.py
index 981ddd9b8f..8126bdbc9f 100644
--- a/st2reactor/st2reactor/sensor/config.py
+++ b/st2reactor/st2reactor/sensor/config.py
@@ -26,8 +26,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts(ignore_errors=False):
@@ -46,48 +49,62 @@ def _register_common_opts(ignore_errors=False):
def _register_sensor_container_opts(ignore_errors=False):
logging_opts = [
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.sensorcontainer.conf',
- help='location of the logging.conf file')
+ "logging",
+ default="/etc/st2/logging.sensorcontainer.conf",
+ help="location of the logging.conf file",
+ )
]
- st2cfg.do_register_opts(logging_opts, group='sensorcontainer', ignore_errors=ignore_errors)
+ st2cfg.do_register_opts(
+ logging_opts, group="sensorcontainer", ignore_errors=ignore_errors
+ )
# Partitioning options
partition_opts = [
cfg.StrOpt(
- 'sensor_node_name', default='sensornode1',
- help='name of the sensor node.'),
+ "sensor_node_name", default="sensornode1", help="name of the sensor node."
+ ),
cfg.Opt(
- 'partition_provider',
+ "partition_provider",
type=types.Dict(value_type=types.String()),
- default={'name': DEFAULT_PARTITION_LOADER},
- help='Provider of sensor node partition config.')
+ default={"name": DEFAULT_PARTITION_LOADER},
+ help="Provider of sensor node partition config.",
+ ),
]
- st2cfg.do_register_opts(partition_opts, group='sensorcontainer', ignore_errors=ignore_errors)
+ st2cfg.do_register_opts(
+ partition_opts, group="sensorcontainer", ignore_errors=ignore_errors
+ )
# Other options
other_opts = [
cfg.BoolOpt(
- 'single_sensor_mode', default=False,
- help='Run in a single sensor mode where parent process exits when a sensor crashes / '
- 'dies. This is useful in environments where partitioning, sensor process life '
- 'cycle and failover is handled by a 3rd party service such as kubernetes.')
+ "single_sensor_mode",
+ default=False,
+ help="Run in a single sensor mode where parent process exits when a sensor crashes / "
+ "dies. This is useful in environments where partitioning, sensor process life "
+ "cycle and failover is handled by a 3rd party service such as kubernetes.",
+ )
]
- st2cfg.do_register_opts(other_opts, group='sensorcontainer', ignore_errors=ignore_errors)
+ st2cfg.do_register_opts(
+ other_opts, group="sensorcontainer", ignore_errors=ignore_errors
+ )
# CLI options
cli_opts = [
cfg.StrOpt(
- 'sensor-ref',
- help='Only run sensor with the provided reference. Value is of the form '
-                 '<pack>.<sensor-class-name> (e.g. linux.FileWatchSensor).'),
+ "sensor-ref",
+ help="Only run sensor with the provided reference. Value is of the form "
+            "<pack>.<sensor-class-name> (e.g. linux.FileWatchSensor).",
+ ),
cfg.BoolOpt(
- 'single-sensor-mode', default=False,
- help='Run in a single sensor mode where parent process exits when a sensor crashes / '
- 'dies. This is useful in environments where partitioning, sensor process life '
- 'cycle and failover is handled by a 3rd party service such as kubernetes.')
+ "single-sensor-mode",
+ default=False,
+ help="Run in a single sensor mode where parent process exits when a sensor crashes / "
+ "dies. This is useful in environments where partitioning, sensor process life "
+ "cycle and failover is handled by a 3rd party service such as kubernetes.",
+ ),
]
st2cfg.do_register_cli_opts(cli_opts, ignore_errors=ignore_errors)
diff --git a/st2reactor/st2reactor/timer/base.py b/st2reactor/st2reactor/timer/base.py
index ed99d90e77..723d362066 100644
--- a/st2reactor/st2reactor/timer/base.py
+++ b/st2reactor/st2reactor/timer/base.py
@@ -41,17 +41,20 @@ class St2Timer(object):
"""
A timer interface that uses APScheduler 3.0.
"""
+
def __init__(self, local_timezone=None):
self._timezone = local_timezone
self._scheduler = BlockingScheduler(timezone=self._timezone)
self._jobs = {}
self._trigger_types = list(TIMER_TRIGGER_TYPES.keys())
- self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger,
- update_handler=self._handle_update_trigger,
- delete_handler=self._handle_delete_trigger,
- trigger_types=self._trigger_types,
- queue_suffix=self.__class__.__name__,
- exclusive=True)
+ self._trigger_watcher = TriggerWatcher(
+ create_handler=self._handle_create_trigger,
+ update_handler=self._handle_update_trigger,
+ delete_handler=self._handle_delete_trigger,
+ trigger_types=self._trigger_types,
+ queue_suffix=self.__class__.__name__,
+ exclusive=True,
+ )
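+        # Keep the scheduler's jobs in sync with timer trigger create / update /
+        # delete events.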
self._trigger_dispatcher = TriggerDispatcher(LOG)
def start(self):
@@ -70,89 +73,109 @@ def update_trigger(self, trigger):
self.add_trigger(trigger)
def remove_trigger(self, trigger):
- trigger_id = trigger['id']
+ trigger_id = trigger["id"]
try:
job_id = self._jobs[trigger_id]
except KeyError:
- LOG.info('Job not found: %s', trigger_id)
+ LOG.info("Job not found: %s", trigger_id)
return
self._scheduler.remove_job(job_id)
del self._jobs[trigger_id]
def _add_job_to_scheduler(self, trigger):
- trigger_type_ref = trigger['type']
+ trigger_type_ref = trigger["type"]
trigger_type = TIMER_TRIGGER_TYPES[trigger_type_ref]
try:
- util_schema.validate(instance=trigger['parameters'],
- schema=trigger_type['parameters_schema'],
- cls=util_schema.CustomValidator,
- use_default=True,
- allow_default_none=True)
+ util_schema.validate(
+ instance=trigger["parameters"],
+ schema=trigger_type["parameters_schema"],
+ cls=util_schema.CustomValidator,
+ use_default=True,
+ allow_default_none=True,
+ )
except jsonschema.ValidationError as e:
- LOG.error('Exception scheduling timer: %s, %s',
- trigger['parameters'], e, exc_info=True)
+ LOG.error(
+ "Exception scheduling timer: %s, %s",
+ trigger["parameters"],
+ e,
+ exc_info=True,
+ )
raise # Or should we just return?
- time_spec = trigger['parameters']
- time_zone = aps_utils.astimezone(trigger['parameters'].get('timezone'))
+ time_spec = trigger["parameters"]
+ time_zone = aps_utils.astimezone(trigger["parameters"].get("timezone"))
time_type = None
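+        # Map the st2 timer trigger type onto the matching APScheduler trigger
+        # (interval, one-off date or cron schedule).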
- if trigger_type['name'] == 'st2.IntervalTimer':
- unit = time_spec.get('unit', None)
- value = time_spec.get('delta', None)
- time_type = IntervalTrigger(**{unit: value, 'timezone': time_zone})
- elif trigger_type['name'] == 'st2.DateTimer':
+ if trigger_type["name"] == "st2.IntervalTimer":
+ unit = time_spec.get("unit", None)
+ value = time_spec.get("delta", None)
+ time_type = IntervalTrigger(**{unit: value, "timezone": time_zone})
+ elif trigger_type["name"] == "st2.DateTimer":
# Raises an exception if date string isn't a valid one.
- dat = date_parser.parse(time_spec.get('date', None))
+ dat = date_parser.parse(time_spec.get("date", None))
time_type = DateTrigger(dat, timezone=time_zone)
- elif trigger_type['name'] == 'st2.CronTimer':
+ elif trigger_type["name"] == "st2.CronTimer":
cron = time_spec.copy()
- cron['timezone'] = time_zone
+ cron["timezone"] = time_zone
time_type = CronTrigger(**cron)
utc_now = date_utils.get_datetime_utc_now()
- if hasattr(time_type, 'run_date') and utc_now > time_type.run_date:
- LOG.warning('Not scheduling expired timer: %s : %s',
- trigger['parameters'], time_type.run_date)
+ if hasattr(time_type, "run_date") and utc_now > time_type.run_date:
+ LOG.warning(
+ "Not scheduling expired timer: %s : %s",
+ trigger["parameters"],
+ time_type.run_date,
+ )
else:
self._add_job(trigger, time_type)
return time_type
def _add_job(self, trigger, time_type, replace=True):
try:
- job = self._scheduler.add_job(self._emit_trigger_instance,
- trigger=time_type,
- args=[trigger],
- replace_existing=replace)
- LOG.info('Job %s scheduled.', job.id)
- self._jobs[trigger['id']] = job.id
+ job = self._scheduler.add_job(
+ self._emit_trigger_instance,
+ trigger=time_type,
+ args=[trigger],
+ replace_existing=replace,
+ )
+ LOG.info("Job %s scheduled.", job.id)
+ self._jobs[trigger["id"]] = job.id
except Exception as e:
- LOG.error('Exception scheduling timer: %s, %s',
- trigger['parameters'], e, exc_info=True)
+ LOG.error(
+ "Exception scheduling timer: %s, %s",
+ trigger["parameters"],
+ e,
+ exc_info=True,
+ )
def _emit_trigger_instance(self, trigger):
utc_now = date_utils.get_datetime_utc_now()
# debug logging is reasonable for this one. A high resolution timer will end up
# trashing standard logs.
- LOG.debug('Timer fired at: %s. Trigger: %s', str(utc_now), trigger)
+ LOG.debug("Timer fired at: %s. Trigger: %s", str(utc_now), trigger)
payload = {
- 'executed_at': str(utc_now),
- 'schedule': trigger['parameters'].get('time')
+ "executed_at": str(utc_now),
+ "schedule": trigger["parameters"].get("time"),
}
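+        # Tag the dispatch with a trace context so the emitted trigger instance
+        # can be correlated back to this timer.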
- trace_context = TraceContext(trace_tag='%s-%s' % (self._get_trigger_type_name(trigger),
- trigger.get('name', uuid.uuid4().hex)))
+ trace_context = TraceContext(
+ trace_tag="%s-%s"
+ % (
+ self._get_trigger_type_name(trigger),
+ trigger.get("name", uuid.uuid4().hex),
+ )
+ )
self._trigger_dispatcher.dispatch(trigger, payload, trace_context=trace_context)
def _get_trigger_type_name(self, trigger):
- trigger_type_ref = trigger['type']
+ trigger_type_ref = trigger["type"]
trigger_type = TIMER_TRIGGER_TYPES[trigger_type_ref]
- return trigger_type['name']
+ return trigger_type["name"]
def _register_timer_trigger_types(self):
return trigger_services.add_trigger_models(list(TIMER_TRIGGER_TYPES.values()))
diff --git a/st2reactor/st2reactor/timer/config.py b/st2reactor/st2reactor/timer/config.py
index db180f85dd..bbc1020cb9 100644
--- a/st2reactor/st2reactor/timer/config.py
+++ b/st2reactor/st2reactor/timer/config.py
@@ -25,8 +25,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
diff --git a/st2reactor/tests/integration/test_garbage_collector.py b/st2reactor/tests/integration/test_garbage_collector.py
index 1de1e9c529..5b0f890ac3 100644
--- a/st2reactor/tests/integration/test_garbage_collector.py
+++ b/st2reactor/tests/integration/test_garbage_collector.py
@@ -37,33 +37,28 @@
from st2tests.fixturesloader import FixturesLoader
from six.moves import range
-__all__ = [
- 'GarbageCollectorServiceTestCase'
-]
+__all__ = ["GarbageCollectorServiceTestCase"]
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf')
+ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf")
ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH)
-INQUIRY_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests2.conf')
+INQUIRY_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests2.conf")
INQUIRY_CONFIG_PATH = os.path.abspath(INQUIRY_CONFIG_PATH)
PYTHON_BINARY = sys.executable
-BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2garbagecollector')
+BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2garbagecollector")
BINARY = os.path.abspath(BINARY)
-CMD = [PYTHON_BINARY, BINARY, '--config-file', ST2_CONFIG_PATH]
-CMD_INQUIRY = [PYTHON_BINARY, BINARY, '--config-file', INQUIRY_CONFIG_PATH]
+CMD = [PYTHON_BINARY, BINARY, "--config-file", ST2_CONFIG_PATH]
+CMD_INQUIRY = [PYTHON_BINARY, BINARY, "--config-file", INQUIRY_CONFIG_PATH]
-TEST_FIXTURES = {
- 'runners': ['inquirer.yaml'],
- 'actions': ['ask.yaml']
-}
+TEST_FIXTURES = {"runners": ["inquirer.yaml"], "actions": ["ask.yaml"]}
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
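+# The "ask" action (inquirer runner) gives the inquiry garbage collection test
+# something to create Inquiries from.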
class GarbageCollectorServiceTestCase(IntegrationTestCase, CleanDbTestCase):
@@ -75,7 +70,8 @@ def setUp(self):
super(GarbageCollectorServiceTestCase, self).setUp()
self.models = FixturesLoader().save_fixtures_to_db(
- fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES)
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES
+ )
def test_garbage_collection(self):
now = date_utils.get_datetime_utc_now()
@@ -85,102 +81,125 @@ def test_garbage_collection(self):
# config
old_executions_count = 15
ttl_days = 30 # > 20
- timestamp = (now - datetime.timedelta(days=ttl_days))
+ timestamp = now - datetime.timedelta(days=ttl_days)
for index in range(0, old_executions_count):
- action_execution_db = ActionExecutionDB(start_timestamp=timestamp,
- end_timestamp=timestamp,
- status=status,
- action={'ref': 'core.local'},
- runner={'name': 'local-shell-cmd'},
- liveaction={'ref': 'foo'})
+ action_execution_db = ActionExecutionDB(
+ start_timestamp=timestamp,
+ end_timestamp=timestamp,
+ status=status,
+ action={"ref": "core.local"},
+ runner={"name": "local-shell-cmd"},
+ liveaction={"ref": "foo"},
+ )
ActionExecution.add_or_update(action_execution_db)
- stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stdout',
- data='stdout')
+ stdout_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stdout",
+ data="stdout",
+ )
ActionExecutionOutput.add_or_update(stdout_db)
- stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stderr',
- data='stderr')
+ stderr_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stderr",
+ data="stderr",
+ )
ActionExecutionOutput.add_or_update(stderr_db)
        # Insert some mock ActionExecutionDB objects with start_timestamp > TTL defined in the
# config
new_executions_count = 5
ttl_days = 2 # < 20
- timestamp = (now - datetime.timedelta(days=ttl_days))
+ timestamp = now - datetime.timedelta(days=ttl_days)
for index in range(0, new_executions_count):
- action_execution_db = ActionExecutionDB(start_timestamp=timestamp,
- end_timestamp=timestamp,
- status=status,
- action={'ref': 'core.local'},
- runner={'name': 'local-shell-cmd'},
- liveaction={'ref': 'foo'})
+ action_execution_db = ActionExecutionDB(
+ start_timestamp=timestamp,
+ end_timestamp=timestamp,
+ status=status,
+ action={"ref": "core.local"},
+ runner={"name": "local-shell-cmd"},
+ liveaction={"ref": "foo"},
+ )
ActionExecution.add_or_update(action_execution_db)
- stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stdout',
- data='stdout')
+ stdout_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stdout",
+ data="stdout",
+ )
ActionExecutionOutput.add_or_update(stdout_db)
- stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stderr',
- data='stderr')
+ stderr_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stderr",
+ data="stderr",
+ )
ActionExecutionOutput.add_or_update(stderr_db)
# Insert some mock output objects where start_timestamp > action_executions_output_ttl
new_output_count = 5
ttl_days = 15 # > 10 and < 20
- timestamp = (now - datetime.timedelta(days=ttl_days))
+ timestamp = now - datetime.timedelta(days=ttl_days)
for index in range(0, new_output_count):
- action_execution_db = ActionExecutionDB(start_timestamp=timestamp,
- end_timestamp=timestamp,
- status=status,
- action={'ref': 'core.local'},
- runner={'name': 'local-shell-cmd'},
- liveaction={'ref': 'foo'})
+ action_execution_db = ActionExecutionDB(
+ start_timestamp=timestamp,
+ end_timestamp=timestamp,
+ status=status,
+ action={"ref": "core.local"},
+ runner={"name": "local-shell-cmd"},
+ liveaction={"ref": "foo"},
+ )
ActionExecution.add_or_update(action_execution_db)
- stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stdout',
- data='stdout')
+ stdout_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stdout",
+ data="stdout",
+ )
ActionExecutionOutput.add_or_update(stdout_db)
- stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stderr',
- data='stderr')
+ stderr_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stderr",
+ data="stderr",
+ )
ActionExecutionOutput.add_or_update(stderr_db)
execs = ActionExecution.get_all()
- self.assertEqual(len(execs),
- (old_executions_count + new_executions_count + new_output_count))
-
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
- self.assertEqual(len(stdout_dbs),
- (old_executions_count + new_executions_count + new_output_count))
-
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
- self.assertEqual(len(stderr_dbs),
- (old_executions_count + new_executions_count + new_output_count))
+ self.assertEqual(
+ len(execs), (old_executions_count + new_executions_count + new_output_count)
+ )
+
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
+ self.assertEqual(
+ len(stdout_dbs),
+ (old_executions_count + new_executions_count + new_output_count),
+ )
+
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
+ self.assertEqual(
+ len(stderr_dbs),
+ (old_executions_count + new_executions_count + new_output_count),
+ )
# Start garbage collector
process = self._start_garbage_collector()
@@ -196,10 +215,10 @@ def test_garbage_collection(self):
# Collection for output objects older than 10 days is also enabled, so those objects
# should be deleted as well
- stdout_dbs = ActionExecutionOutput.query(output_type='stdout')
+ stdout_dbs = ActionExecutionOutput.query(output_type="stdout")
self.assertEqual(len(stdout_dbs), (new_executions_count))
- stderr_dbs = ActionExecutionOutput.query(output_type='stderr')
+ stderr_dbs = ActionExecutionOutput.query(output_type="stderr")
self.assertEqual(len(stderr_dbs), (new_executions_count))
def test_inquiry_garbage_collection(self):
@@ -207,28 +226,28 @@ def test_inquiry_garbage_collection(self):
# Insert some mock Inquiries with start_timestamp > TTL
old_inquiry_count = 15
- timestamp = (now - datetime.timedelta(minutes=3))
+ timestamp = now - datetime.timedelta(minutes=3)
for index in range(0, old_inquiry_count):
self._create_inquiry(ttl=2, timestamp=timestamp)
# Insert some mock Inquiries with TTL set to a "disabled" value
disabled_inquiry_count = 3
- timestamp = (now - datetime.timedelta(minutes=3))
+ timestamp = now - datetime.timedelta(minutes=3)
for index in range(0, disabled_inquiry_count):
self._create_inquiry(ttl=0, timestamp=timestamp)
# Insert some mock Inquiries with start_timestamp < TTL
new_inquiry_count = 5
- timestamp = (now - datetime.timedelta(minutes=3))
+ timestamp = now - datetime.timedelta(minutes=3)
for index in range(0, new_inquiry_count):
self._create_inquiry(ttl=15, timestamp=timestamp)
- filters = {
- 'status': action_constants.LIVEACTION_STATUS_PENDING
- }
+ filters = {"status": action_constants.LIVEACTION_STATUS_PENDING}
inquiries = list(ActionExecution.query(**filters))
- self.assertEqual(len(inquiries),
- (old_inquiry_count + new_inquiry_count + disabled_inquiry_count))
+ self.assertEqual(
+ len(inquiries),
+ (old_inquiry_count + new_inquiry_count + disabled_inquiry_count),
+ )
# Start garbage collector
process = self._start_garbage_collector()
@@ -243,18 +262,25 @@ def test_inquiry_garbage_collection(self):
self.assertEqual(len(inquiries), new_inquiry_count + disabled_inquiry_count)
def _create_inquiry(self, ttl, timestamp):
- action_db = self.models['actions']['ask.yaml']
+ action_db = self.models["actions"]["ask.yaml"]
liveaction_db = LiveActionDB()
liveaction_db.status = action_constants.LIVEACTION_STATUS_PENDING
liveaction_db.start_timestamp = timestamp
- liveaction_db.action = ResourceReference(name=action_db.name, pack=action_db.pack).ref
- liveaction_db.result = {'ttl': ttl}
+ liveaction_db.action = ResourceReference(
+ name=action_db.name, pack=action_db.pack
+ ).ref
+ liveaction_db.result = {"ttl": ttl}
liveaction_db = LiveAction.add_or_update(liveaction_db)
executions.create_execution_object(liveaction_db)
def _start_garbage_collector(self):
subprocess = concurrency.get_subprocess_module()
- process = subprocess.Popen(CMD_INQUIRY, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- shell=False, preexec_fn=os.setsid)
+ process = subprocess.Popen(
+ CMD_INQUIRY,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=False,
+ preexec_fn=os.setsid,
+ )
self.add_process(process=process)
return process
diff --git a/st2reactor/tests/integration/test_rules_engine.py b/st2reactor/tests/integration/test_rules_engine.py
index 669a88797f..05ebce5e9e 100644
--- a/st2reactor/tests/integration/test_rules_engine.py
+++ b/st2reactor/tests/integration/test_rules_engine.py
@@ -26,18 +26,16 @@
from st2tests.base import IntegrationTestCase
from st2tests.base import CleanDbTestCase
-__all__ = [
- 'TimersEngineServiceEnableDisableTestCase'
-]
+__all__ = ["TimersEngineServiceEnableDisableTestCase"]
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf')
+ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf")
ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH)
PYTHON_BINARY = sys.executable
-BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2timersengine')
+BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2timersengine")
BINARY = os.path.abspath(BINARY)
-CMD = [PYTHON_BINARY, BINARY, '--config-file']
+CMD = [PYTHON_BINARY, BINARY, "--config-file"]
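+# Tests run the engine against a temporary copy of st2.tests.conf created in
+# setUp.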
class TimersEngineServiceEnableDisableTestCase(IntegrationTestCase, CleanDbTestCase):
@@ -46,7 +44,7 @@ def setUp(self):
config_text = open(ST2_CONFIG_PATH).read()
self.cfg_fd, self.cfg_path = tempfile.mkstemp()
- with open(self.cfg_path, 'w') as f:
+ with open(self.cfg_path, "w") as f:
f.write(config_text)
self.cmd = []
self.cmd.extend(CMD)
@@ -65,7 +63,7 @@ def test_timer_enable_implicit(self):
process = self._start_times_engine(cmd=self.cmd)
lines = 0
while lines < 100:
- line = process.stdout.readline().decode('utf-8')
+ line = process.stdout.readline().decode("utf-8")
lines += 1
sys.stdout.write(line)
@@ -78,12 +76,15 @@ def test_timer_enable_implicit(self):
self.remove_process(process=process)
if not seen_line:
- raise AssertionError('Didn\'t see "%s" log line in timer output' %
- (TIMER_ENABLED_LOG_LINE))
+ raise AssertionError(
+ 'Didn\'t see "%s" log line in timer output' % (TIMER_ENABLED_LOG_LINE)
+ )
def test_timer_enable_explicit(self):
- self._append_to_cfg_file(cfg_path=self.cfg_path,
- content='\n[timersengine]\nenable = True\n[timer]\nenable = True')
+ self._append_to_cfg_file(
+ cfg_path=self.cfg_path,
+ content="\n[timersengine]\nenable = True\n[timer]\nenable = True",
+ )
process = None
seen_line = False
@@ -91,7 +92,7 @@ def test_timer_enable_explicit(self):
process = self._start_times_engine(cmd=self.cmd)
lines = 0
while lines < 100:
- line = process.stdout.readline().decode('utf-8')
+ line = process.stdout.readline().decode("utf-8")
lines += 1
sys.stdout.write(line)
@@ -104,12 +105,15 @@ def test_timer_enable_explicit(self):
self.remove_process(process=process)
if not seen_line:
- raise AssertionError('Didn\'t see "%s" log line in timer output' %
- (TIMER_ENABLED_LOG_LINE))
+ raise AssertionError(
+ 'Didn\'t see "%s" log line in timer output' % (TIMER_ENABLED_LOG_LINE)
+ )
def test_timer_disable_explicit(self):
- self._append_to_cfg_file(cfg_path=self.cfg_path,
- content='\n[timersengine]\nenable = False\n[timer]\nenable = False')
+ self._append_to_cfg_file(
+ cfg_path=self.cfg_path,
+ content="\n[timersengine]\nenable = False\n[timer]\nenable = False",
+ )
process = None
seen_line = False
@@ -117,7 +121,7 @@ def test_timer_disable_explicit(self):
process = self._start_times_engine(cmd=self.cmd)
lines = 0
while lines < 100:
- line = process.stdout.readline().decode('utf-8')
+ line = process.stdout.readline().decode("utf-8")
lines += 1
sys.stdout.write(line)
@@ -130,18 +134,24 @@ def test_timer_disable_explicit(self):
self.remove_process(process=process)
if not seen_line:
- raise AssertionError('Didn\'t see "%s" log line in timer output' %
- (TIMER_DISABLED_LOG_LINE))
+ raise AssertionError(
+ 'Didn\'t see "%s" log line in timer output' % (TIMER_DISABLED_LOG_LINE)
+ )
def _start_times_engine(self, cmd):
subprocess = concurrency.get_subprocess_module()
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- shell=False, preexec_fn=os.setsid)
+ process = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=False,
+ preexec_fn=os.setsid,
+ )
self.add_process(process=process)
return process
def _append_to_cfg_file(self, cfg_path, content):
- with open(cfg_path, 'a') as f:
+ with open(cfg_path, "a") as f:
f.write(content)
def _remove_tempfile(self, fd, path):
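
The enable/disable tests above exercise the [timersengine] flag by copying st2.tests.conf into a temp file and appending INI overrides. A hedged sketch of that setup step, using only the standard library:

    import os
    import tempfile

    def make_cfg_copy(base_path, overrides=""):
        # Append e.g. "\n[timersengine]\nenable = False\n" to a private
        # copy so the shipped test config is never modified in place.
        with open(base_path) as f:
            text = f.read()
        fd, path = tempfile.mkstemp(suffix=".conf")
        with os.fdopen(fd, "w") as f:
            f.write(text)
            f.write(overrides)
        return path
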
diff --git a/st2reactor/tests/integration/test_sensor_container.py b/st2reactor/tests/integration/test_sensor_container.py
index 7971e36106..41eb3307bc 100644
--- a/st2reactor/tests/integration/test_sensor_container.py
+++ b/st2reactor/tests/integration/test_sensor_container.py
@@ -30,28 +30,26 @@
from st2common.bootstrap.sensorsregistrar import register_sensors
from st2tests.base import IntegrationTestCase
-__all__ = [
- 'SensorContainerTestCase'
-]
+__all__ = ["SensorContainerTestCase"]
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf')
+ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf")
ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH)
PYTHON_BINARY = sys.executable
-BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2sensorcontainer')
+BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2sensorcontainer")
BINARY = os.path.abspath(BINARY)
-PACKS_BASE_PATH = os.path.abspath(os.path.join(BASE_DIR, '../../../contrib'))
+PACKS_BASE_PATH = os.path.abspath(os.path.join(BASE_DIR, "../../../contrib"))
DEFAULT_CMD = [
PYTHON_BINARY,
BINARY,
- '--config-file',
+ "--config-file",
ST2_CONFIG_PATH,
- '--sensor-ref=examples.SamplePollingSensor'
+ "--sensor-ref=examples.SamplePollingSensor",
]
@@ -69,11 +67,24 @@ def setUpClass(cls):
st2tests.config.parse_args()
- username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None
- password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None
+ username = (
+ cfg.CONF.database.username
+ if hasattr(cfg.CONF.database, "username")
+ else None
+ )
+ password = (
+ cfg.CONF.database.password
+ if hasattr(cfg.CONF.database, "password")
+ else None
+ )
cls.db_connection = db_setup(
- cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port,
- username=username, password=password, ensure_indexes=False)
+ cfg.CONF.database.db_name,
+ cfg.CONF.database.host,
+ cfg.CONF.database.port,
+ username=username,
+ password=password,
+ ensure_indexes=False,
+ )
# NOTE: We need to perform this patching because test fixtures are located outside of the
# packs base paths directory. This will never happen outside the context of test fixtures.
@@ -83,11 +94,17 @@ def setUpClass(cls):
register_sensors(packs_base_paths=[PACKS_BASE_PATH], use_pack_cache=False)
# Create virtualenv for examples pack
- virtualenv_path = '/tmp/virtualenvs/examples'
+ virtualenv_path = "/tmp/virtualenvs/examples"
- run_command(cmd=['rm', '-rf', virtualenv_path])
+ run_command(cmd=["rm", "-rf", virtualenv_path])
- cmd = ['virtualenv', '--system-site-packages', '--python', PYTHON_BINARY, virtualenv_path]
+ cmd = [
+ "virtualenv",
+ "--system-site-packages",
+ "--python",
+ PYTHON_BINARY,
+ virtualenv_path,
+ ]
run_command(cmd=cmd)
def test_child_processes_are_killed_on_sigint(self):
@@ -169,7 +186,13 @@ def test_child_processes_are_killed_on_sigkill(self):
def test_single_sensor_mode(self):
# 1. --sensor-ref not provided
- cmd = [PYTHON_BINARY, BINARY, '--config-file', ST2_CONFIG_PATH, '--single-sensor-mode']
+ cmd = [
+ PYTHON_BINARY,
+ BINARY,
+ "--config-file",
+ ST2_CONFIG_PATH,
+ "--single-sensor-mode",
+ ]
process = self._start_sensor_container(cmd=cmd)
pp = psutil.Process(process.pid)
@@ -178,14 +201,24 @@ def test_single_sensor_mode(self):
concurrency.sleep(4)
stdout = process.stdout.read()
- self.assertTrue((b'--sensor-ref argument must be provided when running in single sensor '
- b'mode') in stdout)
+ self.assertTrue(
+ (
+ b"--sensor-ref argument must be provided when running in single sensor "
+ b"mode"
+ )
+ in stdout
+ )
self.assertProcessExited(proc=pp)
self.remove_process(process=process)
# 2. sensor ref provided
- cmd = [BINARY, '--config-file', ST2_CONFIG_PATH, '--single-sensor-mode',
- '--sensor-ref=examples.SampleSensorExit']
+ cmd = [
+ BINARY,
+ "--config-file",
+ ST2_CONFIG_PATH,
+ "--single-sensor-mode",
+ "--sensor-ref=examples.SampleSensorExit",
+ ]
process = self._start_sensor_container(cmd=cmd)
pp = psutil.Process(process.pid)
@@ -196,9 +229,11 @@ def test_single_sensor_mode(self):
# Container should exit and not respawn a sensor in single sensor mode
stdout = process.stdout.read()
- self.assertTrue(b'Process for sensor examples.SampleSensorExit has exited with code 110')
- self.assertTrue(b'Not respawning a sensor since running in single sensor mode')
- self.assertTrue(b'Process container quit with exit_code 110.')
+ self.assertTrue(
+ b"Process for sensor examples.SampleSensorExit has exited with code 110" in stdout
+ )
+ self.assertTrue(b"Not respawning a sensor since running in single sensor mode" in stdout)
+ self.assertTrue(b"Process container quit with exit_code 110." in stdout)
concurrency.sleep(2)
self.assertProcessExited(proc=pp)
@@ -207,7 +242,12 @@ def test_single_sensor_mode(self):
def _start_sensor_container(self, cmd=DEFAULT_CMD):
subprocess = concurrency.get_subprocess_module()
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
- shell=False, preexec_fn=os.setsid)
+ process = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=False,
+ preexec_fn=os.setsid,
+ )
self.add_process(process=process)
return process
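
Most hunks in these test files are the same mechanical change: black splits any call or literal that either exceeds the 88-column limit or carries a trailing comma (the "magic trailing comma") into one element per line, and collapses everything else. A small before/after illustration of that rule, assuming black's default settings:

    # Fits within 88 columns, no trailing comma: black keeps one line.
    cmd = [PYTHON_BINARY, BINARY, "--config-file", ST2_CONFIG_PATH]

    # Over the limit, or written with a trailing comma: black explodes it,
    # one element per line, preserving the trailing comma.
    cmd = [
        PYTHON_BINARY,
        BINARY,
        "--config-file",
        ST2_CONFIG_PATH,
        "--single-sensor-mode",
    ]
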
diff --git a/st2reactor/tests/integration/test_sensor_watcher.py b/st2reactor/tests/integration/test_sensor_watcher.py
index 9727da92a6..6caee09c7f 100644
--- a/st2reactor/tests/integration/test_sensor_watcher.py
+++ b/st2reactor/tests/integration/test_sensor_watcher.py
@@ -22,19 +22,15 @@
from st2common.services.sensor_watcher import SensorWatcher
from st2tests.base import IntegrationTestCase
-__all__ = [
- 'SensorWatcherTestCase'
-]
+__all__ = ["SensorWatcherTestCase"]
class SensorWatcherTestCase(IntegrationTestCase):
-
@classmethod
def setUpClass(cls):
super(SensorWatcherTestCase, cls).setUpClass()
def test_sensor_watch_queue_gets_deleted_on_stop(self):
-
def create_handler(sensor_db):
pass
@@ -44,25 +40,32 @@ def update_handler(sensor_db):
def delete_handler(sensor_db):
pass
- sensor_watcher = SensorWatcher(create_handler, update_handler, delete_handler,
- queue_suffix='covfefe')
+ sensor_watcher = SensorWatcher(
+ create_handler, update_handler, delete_handler, queue_suffix="covfefe"
+ )
sensor_watcher.start()
- sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe')
+ sw_queues = self._get_sensor_watcher_amqp_queues(
+ queue_name="st2.sensor.watch.covfefe"
+ )
start = monotonic()
done = False
while not done:
concurrency.sleep(0.01)
- sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe')
+ sw_queues = self._get_sensor_watcher_amqp_queues(
+ queue_name="st2.sensor.watch.covfefe"
+ )
done = len(sw_queues) > 0 or ((monotonic() - start) < 5)
sensor_watcher.stop()
- sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe')
+ sw_queues = self._get_sensor_watcher_amqp_queues(
+ queue_name="st2.sensor.watch.covfefe"
+ )
self.assertTrue(len(sw_queues) == 0)
def _list_amqp_queues(self):
- rabbit_client = Client('localhost:15672', 'guest', 'guest')
- queues = [q['name'] for q in rabbit_client.get_queues()]
+ rabbit_client = Client("localhost:15672", "guest", "guest")
+ queues = [q["name"] for q in rabbit_client.get_queues()]
return queues
def _get_sensor_watcher_amqp_queues(self, queue_name):
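
The queue assertions above poll RabbitMQ in a loop bounded by monotonic(). The conventional shape of such a wait helper, written here as an illustrative sketch rather than code from the diff, returns False only once the deadline has actually passed:

    import time

    def wait_until(condition, timeout=5.0, interval=0.01):
        # Poll `condition` until it holds or `timeout` seconds elapse.
        start = time.monotonic()
        while not condition():
            if time.monotonic() - start >= timeout:
                return False
            time.sleep(interval)
        return True
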
diff --git a/st2reactor/tests/unit/test_container_utils.py b/st2reactor/tests/unit/test_container_utils.py
index d8c14bf1d5..24d297ba7d 100644
--- a/st2reactor/tests/unit/test_container_utils.py
+++ b/st2reactor/tests/unit/test_container_utils.py
@@ -23,20 +23,25 @@
from st2tests.base import CleanDbTestCase
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class ContainerUtilsTest(CleanDbTestCase):
def setUp(self):
super(ContainerUtilsTest, self).setUp()
# Insert mock TriggerDB
- trigger_db = TriggerDB(name='name1', pack='pack1', type='type1',
- parameters={'a': 1, 'b': '2', 'c': 'foo'})
+ trigger_db = TriggerDB(
+ name="name1",
+ pack="pack1",
+ type="type1",
+ parameters={"a": 1, "b": "2", "c": "foo"},
+ )
self.trigger_db = Trigger.add_or_update(trigger_db)
def test_create_trigger_instance_invalid_trigger(self):
- trigger_instance = 'dummy_pack.footrigger'
- instance = create_trigger_instance(trigger=trigger_instance, payload={},
- occurrence_time=None)
+ trigger_instance = "dummy_pack.footrigger"
+ instance = create_trigger_instance(
+ trigger=trigger_instance, payload={}, occurrence_time=None
+ )
self.assertIsNone(instance)
def test_create_trigger_instance_success(self):
@@ -46,34 +51,40 @@ def test_create_trigger_instance_success(self):
occurrence_time = None
# TriggerDB look up by id
- trigger = {'id': self.trigger_db.id}
- trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload,
- occurrence_time=occurrence_time)
- self.assertEqual(trigger_instance_db.trigger, 'pack1.name1')
+ trigger = {"id": self.trigger_db.id}
+ trigger_instance_db = create_trigger_instance(
+ trigger=trigger, payload=payload, occurrence_time=occurrence_time
+ )
+ self.assertEqual(trigger_instance_db.trigger, "pack1.name1")
# Object doesn't exist (invalid id)
- trigger = {'id': '5776aa2b0640fd2991b15987'}
- trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload,
- occurrence_time=occurrence_time)
+ trigger = {"id": "5776aa2b0640fd2991b15987"}
+ trigger_instance_db = create_trigger_instance(
+ trigger=trigger, payload=payload, occurrence_time=occurrence_time
+ )
self.assertEqual(trigger_instance_db, None)
# TriggerDB look up by uid
- trigger = {'uid': self.trigger_db.uid}
- trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload,
- occurrence_time=occurrence_time)
- self.assertEqual(trigger_instance_db.trigger, 'pack1.name1')
+ trigger = {"uid": self.trigger_db.uid}
+ trigger_instance_db = create_trigger_instance(
+ trigger=trigger, payload=payload, occurrence_time=occurrence_time
+ )
+ self.assertEqual(trigger_instance_db.trigger, "pack1.name1")
- trigger = {'uid': 'invaliduid'}
- trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload,
- occurrence_time=occurrence_time)
+ trigger = {"uid": "invaliduid"}
+ trigger_instance_db = create_trigger_instance(
+ trigger=trigger, payload=payload, occurrence_time=occurrence_time
+ )
self.assertEqual(trigger_instance_db, None)
# TriggerDB look up by type and parameters (last resort)
- trigger = {'type': 'pack1.name1', 'parameters': self.trigger_db.parameters}
- trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload,
- occurrence_time=occurrence_time)
+ trigger = {"type": "pack1.name1", "parameters": self.trigger_db.parameters}
+ trigger_instance_db = create_trigger_instance(
+ trigger=trigger, payload=payload, occurrence_time=occurrence_time
+ )
- trigger = {'type': 'pack1.name1', 'parameters': {}}
- trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload,
- occurrence_time=occurrence_time)
+ trigger = {"type": "pack1.name1", "parameters": {}}
+ trigger_instance_db = create_trigger_instance(
+ trigger=trigger, payload=payload, occurrence_time=occurrence_time
+ )
self.assertEqual(trigger_instance_db, None)
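
test_create_trigger_instance_success walks three lookup paths in order: by id, then by uid, then by type plus parameters as a last resort. A hedged sketch of that precedence (the query helpers here are illustrative, not st2common's real API):

    def resolve_trigger_db(trigger, store):
        # `store` is assumed to expose by_id / by_uid / by_type_and_params
        # lookups that return None when nothing matches.
        if "id" in trigger:
            return store.by_id(trigger["id"])
        if "uid" in trigger:
            return store.by_uid(trigger["uid"])
        return store.by_type_and_params(
            trigger.get("type"), trigger.get("parameters", {})
        )
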
diff --git a/st2reactor/tests/unit/test_enforce.py b/st2reactor/tests/unit/test_enforce.py
index 174216dbb4..4b282305bd 100644
--- a/st2reactor/tests/unit/test_enforce.py
+++ b/st2reactor/tests/unit/test_enforce.py
@@ -38,62 +38,68 @@
from st2tests import DbTestCase
from st2tests.fixturesloader import FixturesLoader
-__all__ = [
- 'RuleEnforcerTestCase',
- 'RuleEnforcerDataTransformationTestCase'
-]
+__all__ = ["RuleEnforcerTestCase", "RuleEnforcerDataTransformationTestCase"]
-PACK = 'generic'
+PACK = "generic"
FIXTURES_1 = {
- 'runners': ['testrunner1.yaml', 'testrunner2.yaml'],
- 'actions': ['action1.yaml', 'a2.yaml', 'a2_default_value.yaml'],
- 'triggertypes': ['triggertype1.yaml'],
- 'triggers': ['trigger1.yaml'],
- 'traces': ['trace_for_test_enforce.yaml', 'trace_for_test_enforce_2.yaml',
- 'trace_for_test_enforce_3.yaml']
+ "runners": ["testrunner1.yaml", "testrunner2.yaml"],
+ "actions": ["action1.yaml", "a2.yaml", "a2_default_value.yaml"],
+ "triggertypes": ["triggertype1.yaml"],
+ "triggers": ["trigger1.yaml"],
+ "traces": [
+ "trace_for_test_enforce.yaml",
+ "trace_for_test_enforce_2.yaml",
+ "trace_for_test_enforce_3.yaml",
+ ],
}
FIXTURES_2 = {
- 'rules': [
- 'rule1.yaml',
- 'rule2.yaml',
- 'rule_use_none_filter.yaml',
- 'rule_none_no_use_none_filter.yaml',
- 'rule_action_default_value.yaml',
- 'rule_action_default_value_overridden.yaml',
- 'rule_action_default_value_render_fail.yaml'
+ "rules": [
+ "rule1.yaml",
+ "rule2.yaml",
+ "rule_use_none_filter.yaml",
+ "rule_none_no_use_none_filter.yaml",
+ "rule_action_default_value.yaml",
+ "rule_action_default_value_overridden.yaml",
+ "rule_action_default_value_render_fail.yaml",
]
}
MOCK_TRIGGER_INSTANCE = TriggerInstanceDB()
-MOCK_TRIGGER_INSTANCE.id = 'triggerinstance-test'
-MOCK_TRIGGER_INSTANCE.payload = {'t1_p': 't1_p_v'}
+MOCK_TRIGGER_INSTANCE.id = "triggerinstance-test"
+MOCK_TRIGGER_INSTANCE.payload = {"t1_p": "t1_p_v"}
MOCK_TRIGGER_INSTANCE.occurrence_time = date_utils.get_datetime_utc_now()
MOCK_TRIGGER_INSTANCE_2 = TriggerInstanceDB()
-MOCK_TRIGGER_INSTANCE_2.id = 'triggerinstance-test2'
-MOCK_TRIGGER_INSTANCE_2.payload = {'t1_p': None}
+MOCK_TRIGGER_INSTANCE_2.id = "triggerinstance-test2"
+MOCK_TRIGGER_INSTANCE_2.payload = {"t1_p": None}
MOCK_TRIGGER_INSTANCE_2.occurrence_time = date_utils.get_datetime_utc_now()
MOCK_TRIGGER_INSTANCE_3 = TriggerInstanceDB()
-MOCK_TRIGGER_INSTANCE_3.id = 'triggerinstance-test3'
-MOCK_TRIGGER_INSTANCE_3.payload = {'t1_p': None, 't2_p': 'value2'}
+MOCK_TRIGGER_INSTANCE_3.id = "triggerinstance-test3"
+MOCK_TRIGGER_INSTANCE_3.payload = {"t1_p": None, "t2_p": "value2"}
MOCK_TRIGGER_INSTANCE_3.occurrence_time = date_utils.get_datetime_utc_now()
-MOCK_TRIGGER_INSTANCE_PAYLOAD = {'k1': 'v1', 'k2': 'v2', 'k3': 3, 'k4': True,
- 'k5': {'foo': 'bar'}, 'k6': [1, 3]}
+MOCK_TRIGGER_INSTANCE_PAYLOAD = {
+ "k1": "v1",
+ "k2": "v2",
+ "k3": 3,
+ "k4": True,
+ "k5": {"foo": "bar"},
+ "k6": [1, 3],
+}
MOCK_TRIGGER_INSTANCE_4 = TriggerInstanceDB()
-MOCK_TRIGGER_INSTANCE_4.id = 'triggerinstance-test4'
+MOCK_TRIGGER_INSTANCE_4.id = "triggerinstance-test4"
MOCK_TRIGGER_INSTANCE_4.payload = MOCK_TRIGGER_INSTANCE_PAYLOAD
MOCK_TRIGGER_INSTANCE_4.occurrence_time = date_utils.get_datetime_utc_now()
MOCK_LIVEACTION = LiveActionDB()
-MOCK_LIVEACTION.id = 'liveaction-test-1.id'
-MOCK_LIVEACTION.status = 'requested'
+MOCK_LIVEACTION.id = "liveaction-test-1.id"
+MOCK_LIVEACTION.status = "requested"
MOCK_EXECUTION = ActionExecutionDB()
-MOCK_EXECUTION.id = 'exec-test-1.id'
-MOCK_EXECUTION.status = 'requested'
+MOCK_EXECUTION.id = "exec-test-1.id"
+MOCK_EXECUTION.status = "requested"
FAILURE_REASON = "fail!"
@@ -111,11 +117,16 @@ def setUpClass(cls):
# Create TriggerTypes before creation of Rule to avoid failure. Rule requires the
# Trigger and therefore TriggerType to be created prior to rule creation.
cls.models = FixturesLoader().save_fixtures_to_db(
- fixtures_pack=PACK, fixtures_dict=FIXTURES_1)
- cls.models.update(FixturesLoader().save_fixtures_to_db(
- fixtures_pack=PACK, fixtures_dict=FIXTURES_2))
+ fixtures_pack=PACK, fixtures_dict=FIXTURES_1
+ )
+ cls.models.update(
+ FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=PACK, fixtures_dict=FIXTURES_2
+ )
+ )
MOCK_TRIGGER_INSTANCE.trigger = reference.get_ref_from_model(
- cls.models['triggers']['trigger1.yaml'])
+ cls.models["triggers"]["trigger1.yaml"]
+ )
def setUp(self):
super(BaseRuleEnforcerTestCase, self).setUp()
@@ -124,335 +135,445 @@ def setUp(self):
class RuleEnforcerTestCase(BaseRuleEnforcerTestCase):
-
- @mock.patch.object(action_service, 'request', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
+ @mock.patch.object(
+ action_service,
+ "request",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
def test_ruleenforcement_occurs(self):
- enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule1.yaml'])
+ enforcer = RuleEnforcer(
+ MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule1.yaml"]
+ )
execution_db = enforcer.enforce()
self.assertIsNotNone(execution_db)
- @mock.patch.object(action_service, 'request', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
+ @mock.patch.object(
+ action_service,
+ "request",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
def test_ruleenforcement_casts(self):
- enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule2.yaml'])
+ enforcer = RuleEnforcer(
+ MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule2.yaml"]
+ )
execution_db = enforcer.enforce()
self.assertIsNotNone(execution_db)
self.assertTrue(action_service.request.called)
- self.assertIsInstance(action_service.request.call_args[0][0].parameters['objtype'], dict)
-
- @mock.patch.object(action_service, 'request', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
- @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock())
+ self.assertIsInstance(
+ action_service.request.call_args[0][0].parameters["objtype"], dict
+ )
+
+ @mock.patch.object(
+ action_service,
+ "request",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
+ @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock())
def test_ruleenforcement_create_on_success(self):
- enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule2.yaml'])
+ enforcer = RuleEnforcer(
+ MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule2.yaml"]
+ )
execution_db = enforcer.enforce()
self.assertIsNotNone(execution_db)
self.assertTrue(RuleEnforcement.add_or_update.called)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref,
- self.models['rules']['rule2.yaml'].ref)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status,
- RULE_ENFORCEMENT_STATUS_SUCCEEDED)
-
- @mock.patch.object(action_service, 'request', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
- @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock())
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].rule.ref,
+ self.models["rules"]["rule2.yaml"].ref,
+ )
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].status,
+ RULE_ENFORCEMENT_STATUS_SUCCEEDED,
+ )
+
+ @mock.patch.object(
+ action_service,
+ "request",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
+ @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock())
def test_rule_enforcement_create_rule_none_param_casting(self):
mock_trigger_instance = MOCK_TRIGGER_INSTANCE_2
# 1. Non None value, should be serialized as regular string
- mock_trigger_instance.payload = {'t1_p': 'somevalue'}
+ mock_trigger_instance.payload = {"t1_p": "somevalue"}
def mock_cast_string(x):
- assert x == 'somevalue'
+ assert x == "somevalue"
return casts._cast_string(x)
- casts.CASTS['string'] = mock_cast_string
- enforcer = RuleEnforcer(mock_trigger_instance,
- self.models['rules']['rule_use_none_filter.yaml'])
+ casts.CASTS["string"] = mock_cast_string
+
+ enforcer = RuleEnforcer(
+ mock_trigger_instance, self.models["rules"]["rule_use_none_filter.yaml"]
+ )
execution_db = enforcer.enforce()
# Verify value has been serialized correctly
call_args = action_service.request.call_args[0]
live_action_db = call_args[0]
- self.assertEqual(live_action_db.parameters['actionstr'], 'somevalue')
+ self.assertEqual(live_action_db.parameters["actionstr"], "somevalue")
self.assertIsNotNone(execution_db)
self.assertTrue(RuleEnforcement.add_or_update.called)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref,
- self.models['rules']['rule_use_none_filter.yaml'].ref)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status,
- RULE_ENFORCEMENT_STATUS_SUCCEEDED)
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].rule.ref,
+ self.models["rules"]["rule_use_none_filter.yaml"].ref,
+ )
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].status,
+ RULE_ENFORCEMENT_STATUS_SUCCEEDED,
+ )
# 2. Verify that None type from trigger instance is correctly serialized to
# None when using "use_none" Jinja filter when invoking an action
- mock_trigger_instance.payload = {'t1_p': None}
+ mock_trigger_instance.payload = {"t1_p": None}
def mock_cast_string(x):
assert x == data.NONE_MAGIC_VALUE
return casts._cast_string(x)
- casts.CASTS['string'] = mock_cast_string
- enforcer = RuleEnforcer(mock_trigger_instance,
- self.models['rules']['rule_use_none_filter.yaml'])
+ casts.CASTS["string"] = mock_cast_string
+
+ enforcer = RuleEnforcer(
+ mock_trigger_instance, self.models["rules"]["rule_use_none_filter.yaml"]
+ )
execution_db = enforcer.enforce()
# Verify None has been correctly serialized to None
call_args = action_service.request.call_args[0]
live_action_db = call_args[0]
- self.assertEqual(live_action_db.parameters['actionstr'], None)
+ self.assertEqual(live_action_db.parameters["actionstr"], None)
self.assertIsNotNone(execution_db)
self.assertTrue(RuleEnforcement.add_or_update.called)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref,
- self.models['rules']['rule_use_none_filter.yaml'].ref)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status,
- RULE_ENFORCEMENT_STATUS_SUCCEEDED)
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].rule.ref,
+ self.models["rules"]["rule_use_none_filter.yaml"].ref,
+ )
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].status,
+ RULE_ENFORCEMENT_STATUS_SUCCEEDED,
+ )
- casts.CASTS['string'] = casts._cast_string
+ casts.CASTS["string"] = casts._cast_string
# 3. Parameter value is a compound string one of which values is None, but "use_none"
# filter is not used
mock_trigger_instance = MOCK_TRIGGER_INSTANCE_3
- mock_trigger_instance.payload = {'t1_p': None, 't2_p': 'value2'}
+ mock_trigger_instance.payload = {"t1_p": None, "t2_p": "value2"}
- enforcer = RuleEnforcer(mock_trigger_instance,
- self.models['rules']['rule_none_no_use_none_filter.yaml'])
+ enforcer = RuleEnforcer(
+ mock_trigger_instance,
+ self.models["rules"]["rule_none_no_use_none_filter.yaml"],
+ )
execution_db = enforcer.enforce()
# Verify None has been correctly serialized to None
call_args = action_service.request.call_args[0]
live_action_db = call_args[0]
- self.assertEqual(live_action_db.parameters['actionstr'], 'None-value2')
+ self.assertEqual(live_action_db.parameters["actionstr"], "None-value2")
self.assertIsNotNone(execution_db)
self.assertTrue(RuleEnforcement.add_or_update.called)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref,
- self.models['rules']['rule_none_no_use_none_filter.yaml'].ref)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status,
- RULE_ENFORCEMENT_STATUS_SUCCEEDED)
-
- casts.CASTS['string'] = casts._cast_string
-
- @mock.patch.object(action_service, 'request', mock.MagicMock(
- side_effect=ValueError(FAILURE_REASON)))
- @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock())
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].rule.ref,
+ self.models["rules"]["rule_none_no_use_none_filter.yaml"].ref,
+ )
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].status,
+ RULE_ENFORCEMENT_STATUS_SUCCEEDED,
+ )
+
+ casts.CASTS["string"] = casts._cast_string
+
+ @mock.patch.object(
+ action_service,
+ "request",
+ mock.MagicMock(side_effect=ValueError(FAILURE_REASON)),
+ )
+ @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock())
def test_ruleenforcement_create_on_fail(self):
- enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule1.yaml'])
+ enforcer = RuleEnforcer(
+ MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule1.yaml"]
+ )
execution_db = enforcer.enforce()
self.assertIsNone(execution_db)
self.assertTrue(RuleEnforcement.add_or_update.called)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].failure_reason,
- FAILURE_REASON)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status,
- RULE_ENFORCEMENT_STATUS_FAILED)
-
- @mock.patch.object(action_service, 'request', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
- @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock())
- @mock.patch('st2common.util.param.get_config',
- mock.Mock(return_value={'arrtype_value': ['one 1', 'two 2', 'three 3']}))
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].failure_reason, FAILURE_REASON
+ )
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].status,
+ RULE_ENFORCEMENT_STATUS_FAILED,
+ )
+
+ @mock.patch.object(
+ action_service,
+ "request",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
+ @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock())
+ @mock.patch(
+ "st2common.util.param.get_config",
+ mock.Mock(return_value={"arrtype_value": ["one 1", "two 2", "three 3"]}),
+ )
def test_action_default_jinja_parameter_value_is_rendered(self):
# Verify that a default action parameter which is a Jinja variable is correctly rendered
- rule = self.models['rules']['rule_action_default_value.yaml']
+ rule = self.models["rules"]["rule_action_default_value.yaml"]
enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule)
execution_db = enforcer.enforce()
self.assertIsNotNone(execution_db)
self.assertTrue(RuleEnforcement.add_or_update.called)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status,
- RULE_ENFORCEMENT_STATUS_SUCCEEDED)
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref
+ )
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].status,
+ RULE_ENFORCEMENT_STATUS_SUCCEEDED,
+ )
call_parameters = action_service.request.call_args[0][0].parameters
- self.assertEqual(call_parameters['objtype'], {'t1_p': 't1_p_v'})
- self.assertEqual(call_parameters['strtype'], 't1_p_v')
- self.assertEqual(call_parameters['arrtype'], ['one 1', 'two 2', 'three 3'])
+ self.assertEqual(call_parameters["objtype"], {"t1_p": "t1_p_v"})
+ self.assertEqual(call_parameters["strtype"], "t1_p_v")
+ self.assertEqual(call_parameters["arrtype"], ["one 1", "two 2", "three 3"])
- @mock.patch.object(action_service, 'request', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
- @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock())
+ @mock.patch.object(
+ action_service,
+ "request",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
+ @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock())
def test_action_default_jinja_parameter_value_overridden_in_rule(self):
# Verify that it works correctly if default parameter value is overridden in rule
- rule = self.models['rules']['rule_action_default_value_overridden.yaml']
+ rule = self.models["rules"]["rule_action_default_value_overridden.yaml"]
enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule)
execution_db = enforcer.enforce()
self.assertIsNotNone(execution_db)
self.assertTrue(RuleEnforcement.add_or_update.called)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status,
- RULE_ENFORCEMENT_STATUS_SUCCEEDED)
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref
+ )
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].status,
+ RULE_ENFORCEMENT_STATUS_SUCCEEDED,
+ )
call_parameters = action_service.request.call_args[0][0].parameters
- self.assertEqual(call_parameters['objtype'], {'t1_p': 't1_p_v'})
- self.assertEqual(call_parameters['strtype'], 't1_p_v')
- self.assertEqual(call_parameters['arrtype'], ['override 1', 'override 2'])
-
- @mock.patch.object(action_service, 'request', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
- @mock.patch.object(action_service, 'create_request', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
- @mock.patch.object(action_service, 'update_status', mock.MagicMock(
- return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)))
- @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock())
+ self.assertEqual(call_parameters["objtype"], {"t1_p": "t1_p_v"})
+ self.assertEqual(call_parameters["strtype"], "t1_p_v")
+ self.assertEqual(call_parameters["arrtype"], ["override 1", "override 2"])
+
+ @mock.patch.object(
+ action_service,
+ "request",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
+ @mock.patch.object(
+ action_service,
+ "create_request",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
+ @mock.patch.object(
+ action_service,
+ "update_status",
+ mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)),
+ )
+ @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock())
def test_action_default_jinja_parameter_value_render_fail(self):
# Action parameter render failure should result in a failed execution
- rule = self.models['rules']['rule_action_default_value_render_fail.yaml']
+ rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"]
enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule)
execution_db = enforcer.enforce()
self.assertIsNone(execution_db)
self.assertTrue(RuleEnforcement.add_or_update.called)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status,
- RULE_ENFORCEMENT_STATUS_FAILED)
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref
+ )
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].status,
+ RULE_ENFORCEMENT_STATUS_FAILED,
+ )
self.assertFalse(action_service.request.called)
self.assertTrue(action_service.create_request.called)
- self.assertEqual(action_service.create_request.call_args[0][0].action,
- 'wolfpack.a2_default_value')
+ self.assertEqual(
+ action_service.create_request.call_args[0][0].action,
+ "wolfpack.a2_default_value",
+ )
self.assertTrue(action_service.update_status.called)
- self.assertEqual(action_service.update_status.call_args[1]['new_status'],
- action_constants.LIVEACTION_STATUS_FAILED)
+ self.assertEqual(
+ action_service.update_status.call_args[1]["new_status"],
+ action_constants.LIVEACTION_STATUS_FAILED,
+ )
- expected_msg = ('Failed to render parameter "arrtype": \'dict object\' has no '
- 'attribute \'arrtype_value\'')
+ expected_msg = (
+ "Failed to render parameter \"arrtype\": 'dict object' has no "
+ "attribute 'arrtype_value'"
+ )
- result = action_service.update_status.call_args[1]['result']
- self.assertEqual(result['error'], expected_msg)
+ result = action_service.update_status.call_args[1]["result"]
+ self.assertEqual(result["error"], expected_msg)
- self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].failure_reason,
- expected_msg)
+ self.assertEqual(
+ RuleEnforcement.add_or_update.call_args[0][0].failure_reason, expected_msg
+ )
class RuleEnforcerDataTransformationTestCase(BaseRuleEnforcerTestCase):
-
def test_payload_data_transform(self):
- rule = self.models['rules']['rule_action_default_value_render_fail.yaml']
+ rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"]
- params = {'ip1': '{{trigger.k1}}-static',
- 'ip2': '{{trigger.k2}} static'}
+ params = {"ip1": "{{trigger.k1}}-static", "ip2": "{{trigger.k2}} static"}
- expected_params = {'ip1': 'v1-static', 'ip2': 'v2 static'}
+ expected_params = {"ip1": "v1-static", "ip2": "v2 static"}
- self.assertResolvedParamsMatchExpected(rule=rule,
- trigger_instance=MOCK_TRIGGER_INSTANCE_4,
- params=params,
- expected_params=expected_params)
+ self.assertResolvedParamsMatchExpected(
+ rule=rule,
+ trigger_instance=MOCK_TRIGGER_INSTANCE_4,
+ params=params,
+ expected_params=expected_params,
+ )
def test_payload_transforms_int_type(self):
- rule = self.models['rules']['rule_action_default_value_render_fail.yaml']
+ rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"]
- params = {'int': 666}
- expected_params = {'int': 666}
+ params = {"int": 666}
+ expected_params = {"int": 666}
- self.assertResolvedParamsMatchExpected(rule=rule,
- trigger_instance=MOCK_TRIGGER_INSTANCE_4,
- params=params,
- expected_params=expected_params)
+ self.assertResolvedParamsMatchExpected(
+ rule=rule,
+ trigger_instance=MOCK_TRIGGER_INSTANCE_4,
+ params=params,
+ expected_params=expected_params,
+ )
def test_payload_transforms_bool_type(self):
- rule = self.models['rules']['rule_action_default_value_render_fail.yaml']
+ rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"]
runner_type_db = mock.Mock()
runner_type_db.runner_parameters = {}
action_db = mock.Mock()
action_db.parameters = {}
- params = {'bool': True}
- expected_params = {'bool': True}
+ params = {"bool": True}
+ expected_params = {"bool": True}
- self.assertResolvedParamsMatchExpected(rule=rule,
- trigger_instance=MOCK_TRIGGER_INSTANCE_4,
- params=params,
- expected_params=expected_params)
+ self.assertResolvedParamsMatchExpected(
+ rule=rule,
+ trigger_instance=MOCK_TRIGGER_INSTANCE_4,
+ params=params,
+ expected_params=expected_params,
+ )
def test_payload_transforms_complex_type(self):
- rule = self.models['rules']['rule_action_default_value_render_fail.yaml']
+ rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"]
runner_type_db = mock.Mock()
runner_type_db.runner_parameters = {}
action_db = mock.Mock()
action_db.parameters = {}
- params = {'complex_dict': {'bool': True, 'int': 666, 'str': '{{trigger.k1}}-string'}}
- expected_params = {'complex_dict': {'bool': True, 'int': 666, 'str': 'v1-string'}}
+ params = {
+ "complex_dict": {"bool": True, "int": 666, "str": "{{trigger.k1}}-string"}
+ }
+ expected_params = {
+ "complex_dict": {"bool": True, "int": 666, "str": "v1-string"}
+ }
- self.assertResolvedParamsMatchExpected(rule=rule,
- trigger_instance=MOCK_TRIGGER_INSTANCE_4,
- params=params,
- expected_params=expected_params)
+ self.assertResolvedParamsMatchExpected(
+ rule=rule,
+ trigger_instance=MOCK_TRIGGER_INSTANCE_4,
+ params=params,
+ expected_params=expected_params,
+ )
- params = {'simple_list': [1, 2, 3]}
- expected_params = {'simple_list': [1, 2, 3]}
+ params = {"simple_list": [1, 2, 3]}
+ expected_params = {"simple_list": [1, 2, 3]}
- self.assertResolvedParamsMatchExpected(rule=rule,
- trigger_instance=MOCK_TRIGGER_INSTANCE_4,
- params=params,
- expected_params=expected_params)
+ self.assertResolvedParamsMatchExpected(
+ rule=rule,
+ trigger_instance=MOCK_TRIGGER_INSTANCE_4,
+ params=params,
+ expected_params=expected_params,
+ )
def test_hypenated_payload_transform(self):
- rule = self.models['rules']['rule_action_default_value_render_fail.yaml']
- payload = {'headers': {'hypenated-header': 'dont-care'}, 'k2': 'v2'}
+ rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"]
+ payload = {"headers": {"hypenated-header": "dont-care"}, "k2": "v2"}
MOCK_TRIGGER_INSTANCE_4.payload = payload
- params = {'ip1': '{{trigger.headers[\'hypenated-header\']}}-static',
- 'ip2': '{{trigger.k2}} static'}
- expected_params = {'ip1': 'dont-care-static', 'ip2': 'v2 static'}
-
- self.assertResolvedParamsMatchExpected(rule=rule,
- trigger_instance=MOCK_TRIGGER_INSTANCE_4,
- params=params,
- expected_params=expected_params)
+ params = {
+ "ip1": "{{trigger.headers['hypenated-header']}}-static",
+ "ip2": "{{trigger.k2}} static",
+ }
+ expected_params = {"ip1": "dont-care-static", "ip2": "v2 static"}
+
+ self.assertResolvedParamsMatchExpected(
+ rule=rule,
+ trigger_instance=MOCK_TRIGGER_INSTANCE_4,
+ params=params,
+ expected_params=expected_params,
+ )
def test_system_transform(self):
- rule = self.models['rules']['rule_action_default_value_render_fail.yaml']
+ rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"]
runner_type_db = mock.Mock()
runner_type_db.runner_parameters = {}
action_db = mock.Mock()
action_db.parameters = {}
- k5 = KeyValuePair.add_or_update(KeyValuePairDB(name='k5', value='v5'))
- k6 = KeyValuePair.add_or_update(KeyValuePairDB(name='k6', value='v6'))
- k7 = KeyValuePair.add_or_update(KeyValuePairDB(name='k7', value='v7'))
- k8 = KeyValuePair.add_or_update(KeyValuePairDB(name='k8', value='v8',
- scope=FULL_SYSTEM_SCOPE))
+ k5 = KeyValuePair.add_or_update(KeyValuePairDB(name="k5", value="v5"))
+ k6 = KeyValuePair.add_or_update(KeyValuePairDB(name="k6", value="v6"))
+ k7 = KeyValuePair.add_or_update(KeyValuePairDB(name="k7", value="v7"))
+ k8 = KeyValuePair.add_or_update(
+ KeyValuePairDB(name="k8", value="v8", scope=FULL_SYSTEM_SCOPE)
+ )
- params = {'ip5': '{{trigger.k2}}-static',
- 'ip6': '{{st2kv.system.k6}}-static',
- 'ip7': '{{st2kv.system.k7}}-static'}
- expected_params = {'ip5': 'v2-static',
- 'ip6': 'v6-static',
- 'ip7': 'v7-static'}
+ params = {
+ "ip5": "{{trigger.k2}}-static",
+ "ip6": "{{st2kv.system.k6}}-static",
+ "ip7": "{{st2kv.system.k7}}-static",
+ }
+ expected_params = {"ip5": "v2-static", "ip6": "v6-static", "ip7": "v7-static"}
try:
- self.assertResolvedParamsMatchExpected(rule=rule,
- trigger_instance=MOCK_TRIGGER_INSTANCE_4,
- params=params,
- expected_params=expected_params)
+ self.assertResolvedParamsMatchExpected(
+ rule=rule,
+ trigger_instance=MOCK_TRIGGER_INSTANCE_4,
+ params=params,
+ expected_params=expected_params,
+ )
finally:
KeyValuePair.delete(k5)
KeyValuePair.delete(k6)
KeyValuePair.delete(k7)
KeyValuePair.delete(k8)
- def assertResolvedParamsMatchExpected(self, rule, trigger_instance, params, expected_params):
+ def assertResolvedParamsMatchExpected(
+ self, rule, trigger_instance, params, expected_params
+ ):
runner_type_db = mock.Mock()
runner_type_db.runner_parameters = {}
action_db = mock.Mock()
action_db.parameters = {}
enforcer = RuleEnforcer(trigger_instance, rule)
- context, additional_contexts = enforcer.get_action_execution_context(action_db=action_db)
+ context, additional_contexts = enforcer.get_action_execution_context(
+ action_db=action_db
+ )
- resolved_params = enforcer.get_resolved_parameters(action_db=action_db,
+ resolved_params = enforcer.get_resolved_parameters(
+ action_db=action_db,
runnertype_db=runner_type_db,
params=params,
context=context,
- additional_contexts=additional_contexts)
+ additional_contexts=additional_contexts,
+ )
self.assertEqual(resolved_params, expected_params)
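
Most of the churn above is stacked mock.patch.object decorators being reflowed, with each decorator's arguments moved onto their own lines. The reformatted form is behaviorally identical to the one-liner, as this self-contained example shows:

    import unittest
    from unittest import mock

    class Service:
        def request(self):
            return "real"

    class PatchStyleTest(unittest.TestCase):
        @mock.patch.object(
            Service,
            "request",
            mock.MagicMock(return_value="mocked"),
        )
        def test_request_is_mocked(self):
            # Identical semantics to the single-line decorator form.
            self.assertEqual(Service().request(), "mocked")
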
diff --git a/st2reactor/tests/unit/test_filter.py b/st2reactor/tests/unit/test_filter.py
index 4df7ef2360..d1e42eaece 100644
--- a/st2reactor/tests/unit/test_filter.py
+++ b/st2reactor/tests/unit/test_filter.py
@@ -27,57 +27,71 @@
from st2tests import DbTestCase
-MOCK_TRIGGER = TriggerDB(pack='dummy_pack_1', name='trigger-test.name', type='system.test')
+MOCK_TRIGGER = TriggerDB(
+ pack="dummy_pack_1", name="trigger-test.name", type="system.test"
+)
MOCK_TRIGGER_INSTANCE = TriggerInstanceDB(
trigger=MOCK_TRIGGER.get_reference().ref,
occurrence_time=date_utils.get_datetime_utc_now(),
payload={
- 'p1': 'v1',
- 'p2': 'preYYYpost',
- 'bool': True,
- 'int': 1,
- 'float': 0.8,
- 'list': ['v1', True, 1],
- 'recursive_list': [
+ "p1": "v1",
+ "p2": "preYYYpost",
+ "bool": True,
+ "int": 1,
+ "float": 0.8,
+ "list": ["v1", True, 1],
+ "recursive_list": [
{
- 'field_name': "Status",
- 'to_value': "Approved",
- }, {
- 'field_name': "Signed off by",
- 'to_value': "Stanley",
- }
+ "field_name": "Status",
+ "to_value": "Approved",
+ },
+ {
+ "field_name": "Signed off by",
+ "to_value": "Stanley",
+ },
],
- }
+ },
)
-MOCK_ACTION = ActionDB(id=bson.ObjectId(), pack='wolfpack', name='action-test-1.name')
+MOCK_ACTION = ActionDB(id=bson.ObjectId(), pack="wolfpack", name="action-test-1.name")
-MOCK_RULE_1 = RuleDB(id=bson.ObjectId(), pack='wolfpack', name='some1',
- trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER),
- criteria={}, action=ActionExecutionSpecDB(ref="somepack.someaction"))
+MOCK_RULE_1 = RuleDB(
+ id=bson.ObjectId(),
+ pack="wolfpack",
+ name="some1",
+ trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER),
+ criteria={},
+ action=ActionExecutionSpecDB(ref="somepack.someaction"),
+)
-MOCK_RULE_2 = RuleDB(id=bson.ObjectId(), pack='wolfpack', name='some2',
- trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER),
- criteria={}, action=ActionExecutionSpecDB(ref="somepack.someaction"))
+MOCK_RULE_2 = RuleDB(
+ id=bson.ObjectId(),
+ pack="wolfpack",
+ name="some2",
+ trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER),
+ criteria={},
+ action=ActionExecutionSpecDB(ref="somepack.someaction"),
+)
-@mock.patch.object(reference, 'get_model_by_resource_ref',
- mock.MagicMock(return_value=MOCK_TRIGGER))
+@mock.patch.object(
+ reference, "get_model_by_resource_ref", mock.MagicMock(return_value=MOCK_TRIGGER)
+)
class FilterTest(DbTestCase):
def test_empty_criteria(self):
rule = MOCK_RULE_1
rule.criteria = {}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have failed.')
+ self.assertTrue(f.filter(), "equals check should have failed.")
def test_empty_payload(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v1'}}
+ rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v1"}}
trigger_instance = copy.deepcopy(MOCK_TRIGGER_INSTANCE)
trigger_instance.payload = None
f = RuleFilter(trigger_instance, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), 'equals check should have failed.')
+ self.assertFalse(f.filter(), "equals check should have failed.")
def test_empty_criteria_and_empty_payload(self):
rule = MOCK_RULE_1
@@ -85,234 +99,247 @@ def test_empty_criteria_and_empty_payload(self):
trigger_instance = copy.deepcopy(MOCK_TRIGGER_INSTANCE)
trigger_instance.payload = None
f = RuleFilter(trigger_instance, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have failed.')
+ self.assertTrue(f.filter(), "equals check should have failed.")
def test_search_operator_pass_any_criteria(self):
rule = MOCK_RULE_1
rule.criteria = {
- 'trigger.recursive_list': {
- 'type': 'search',
- 'condition': 'any',
- 'pattern': {
- 'item.field_name': {
- 'type': 'equals',
- 'pattern': 'Status',
+ "trigger.recursive_list": {
+ "type": "search",
+ "condition": "any",
+ "pattern": {
+ "item.field_name": {
+ "type": "equals",
+ "pattern": "Status",
},
- 'item.to_value': {
- 'type': 'equals',
- 'pattern': 'Approved'
- }
- }
+ "item.to_value": {"type": "equals", "pattern": "Approved"},
+ },
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'Failed evaluation')
+ self.assertTrue(f.filter(), "Failed evaluation")
def test_search_operator_fail_any_criteria(self):
rule = MOCK_RULE_1
rule.criteria = {
- 'trigger.recursive_list': {
- 'type': 'search',
- 'condition': 'any',
- 'pattern': {
- 'item.field_name': {
- 'type': 'equals',
- 'pattern': 'Status',
+ "trigger.recursive_list": {
+ "type": "search",
+ "condition": "any",
+ "pattern": {
+ "item.field_name": {
+ "type": "equals",
+ "pattern": "Status",
},
- 'item.to_value': {
- 'type': 'equals',
- 'pattern': 'Denied',
- }
- }
+ "item.to_value": {
+ "type": "equals",
+ "pattern": "Denied",
+ },
+ },
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), 'Passed evaluation')
+ self.assertFalse(f.filter(), "Passed evaluation")
def test_search_operator_pass_all_criteria(self):
rule = MOCK_RULE_1
rule.criteria = {
- 'trigger.recursive_list': {
- 'type': 'search',
- 'condition': 'all',
- 'pattern': {
- 'item.field_name': {
- 'type': 'startswith',
- 'pattern': 'S',
+ "trigger.recursive_list": {
+ "type": "search",
+ "condition": "all",
+ "pattern": {
+ "item.field_name": {
+ "type": "startswith",
+ "pattern": "S",
}
- }
+ },
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'Failed evaluation')
+ self.assertTrue(f.filter(), "Failed evaluation")
def test_search_operator_fail_all_criteria(self):
rule = MOCK_RULE_1
rule.criteria = {
- 'trigger.recursive_list': {
- 'type': 'search',
- 'condition': 'all',
- 'pattern': {
- 'item.field_name': {
- 'type': 'equals',
- 'pattern': 'Status',
+ "trigger.recursive_list": {
+ "type": "search",
+ "condition": "all",
+ "pattern": {
+ "item.field_name": {
+ "type": "equals",
+ "pattern": "Status",
},
- 'item.to_value': {
- 'type': 'equals',
- 'pattern': 'Denied',
- }
- }
+ "item.to_value": {
+ "type": "equals",
+ "pattern": "Denied",
+ },
+ },
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), 'Passed evaluation')
+ self.assertFalse(f.filter(), "Passed evaluation")
def test_matchregex_operator_pass_criteria(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.p1': {'type': 'matchregex', 'pattern': 'v1$'}}
+ rule.criteria = {"trigger.p1": {"type": "matchregex", "pattern": "v1$"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'Failed to pass evaluation.')
+ self.assertTrue(f.filter(), "Failed to pass evaluation.")
def test_matchregex_operator_fail_criteria(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.p1': {'type': 'matchregex', 'pattern': 'v$'}}
+ rule.criteria = {"trigger.p1": {"type": "matchregex", "pattern": "v$"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), 'regex check should have failed.')
+ self.assertFalse(f.filter(), "regex check should have failed.")
def test_equals_operator_pass_criteria(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v1'}}
+ rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v1"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
rule = MOCK_RULE_1
- rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': '{{trigger.p1}}'}}
+ rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "{{trigger.p1}}"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
rule = MOCK_RULE_1
rule.criteria = {
- 'trigger.p1': {
- 'type': 'equals',
- 'pattern': "{{'%s' % trigger.p1 if trigger.int}}"
+ "trigger.p1": {
+ "type": "equals",
+ "pattern": "{{'%s' % trigger.p1 if trigger.int}}",
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
# Test our filter works if proper JSON is returned from user pattern
rule = MOCK_RULE_1
rule.criteria = {
- 'trigger.list': {
- 'type': 'equals',
- 'pattern': """
+ "trigger.list": {
+ "type": "equals",
+ "pattern": """
[
{% for item in trigger.list %}
{{item}}{% if not loop.last %},{% endif %}
{% endfor %}
]
- """
+ """,
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
def test_equals_operator_fail_criteria(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v'}}
+ rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), 'equals check should have failed.')
+ self.assertFalse(f.filter(), "equals check should have failed.")
rule = MOCK_RULE_1
- rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': '{{trigger.p2}}'}}
+ rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "{{trigger.p2}}"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), 'equals check should have failed.')
+ self.assertFalse(f.filter(), "equals check should have failed.")
rule = MOCK_RULE_1
rule.criteria = {
- 'trigger.list': {
- 'type': 'equals',
- 'pattern': """
+ "trigger.list": {
+ "type": "equals",
+ "pattern": """
[
{% for item in trigger.list %}
{{item}}
{% endfor %}
]
- """
+ """,
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), 'equals check should have failed.')
+ self.assertFalse(f.filter(), "equals check should have failed.")
def test_equals_bool_value(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': True}}
+ rule.criteria = {"trigger.bool": {"type": "equals", "pattern": True}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
rule = MOCK_RULE_1
- rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': '{{trigger.bool}}'}}
+ rule.criteria = {
+ "trigger.bool": {"type": "equals", "pattern": "{{trigger.bool}}"}
+ }
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
rule = MOCK_RULE_1
- rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': '{{ trigger.bool }}'}}
+ rule.criteria = {
+ "trigger.bool": {"type": "equals", "pattern": "{{ trigger.bool }}"}
+ }
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
def test_equals_int_value(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.int': {'type': 'equals', 'pattern': 1}}
+ rule.criteria = {"trigger.int": {"type": "equals", "pattern": 1}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
rule = MOCK_RULE_1
- rule.criteria = {'trigger.int': {'type': 'equals', 'pattern': '{{trigger.int}}'}}
+ rule.criteria = {
+ "trigger.int": {"type": "equals", "pattern": "{{trigger.int}}"}
+ }
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
def test_equals_float_value(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.float': {'type': 'equals', 'pattern': 0.8}}
+ rule.criteria = {"trigger.float": {"type": "equals", "pattern": 0.8}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
rule = MOCK_RULE_1
- rule.criteria = {'trigger.float': {'type': 'equals', 'pattern': '{{trigger.float}}'}}
+ rule.criteria = {
+ "trigger.float": {"type": "equals", "pattern": "{{trigger.float}}"}
+ }
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'equals check should have passed.')
+ self.assertTrue(f.filter(), "equals check should have passed.")
def test_exists(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.float': {'type': 'exists'}}
+ rule.criteria = {"trigger.float": {"type": "exists"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), '"float" key exists in trigger. Should return true.')
- rule.criteria = {'trigger.floattt': {'type': 'exists'}}
+ self.assertTrue(
+ f.filter(), '"float" key exists in trigger. Should return true.'
+ )
+ rule.criteria = {"trigger.floattt": {"type": "exists"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), '"floattt" key ain\'t exist in trigger. Should return false.')
+ self.assertFalse(
+ f.filter(), '"floattt" key ain\'t exist in trigger. Should return false.'
+ )
def test_nexists(self):
rule = MOCK_RULE_1
- rule.criteria = {'trigger.float': {'type': 'nexists'}}
+ rule.criteria = {"trigger.float": {"type": "nexists"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), '"float" key exists in trigger. Should return false.')
- rule.criteria = {'trigger.floattt': {'type': 'nexists'}}
+ self.assertFalse(
+ f.filter(), '"float" key exists in trigger. Should return false.'
+ )
+ rule.criteria = {"trigger.floattt": {"type": "nexists"}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), '"floattt" key ain\'t exist in trigger. Should return true.')
+ self.assertTrue(
+ f.filter(), '"floattt" key ain\'t exist in trigger. Should return true.'
+ )
def test_gt_lt_falsy_pattern(self):
# Make sure that the falsy value (number 0) is handled correctly
rule = MOCK_RULE_1
- rule.criteria = {'trigger.int': {'type': 'gt', 'pattern': 0}}
+ rule.criteria = {"trigger.int": {"type": "gt", "pattern": 0}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertTrue(f.filter(), 'trigger value is gt than 0 but didn\'t match')
+ self.assertTrue(f.filter(), "trigger value is gt than 0 but didn't match")
- rule.criteria = {'trigger.int': {'type': 'lt', 'pattern': 0}}
+ rule.criteria = {"trigger.int": {"type": "lt", "pattern": 0}}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
- self.assertFalse(f.filter(), 'trigger value is gt than 0 but didn\'t fail')
+ self.assertFalse(f.filter(), "trigger value is gt than 0 but didn't fail")
- @mock.patch('st2common.util.templating.KeyValueLookup')
+ @mock.patch("st2common.util.templating.KeyValueLookup")
def test_criteria_pattern_references_a_datastore_item(self, mock_KeyValueLookup):
class MockResultLookup(object):
pass
@@ -323,22 +350,24 @@ class MockSystemLookup(object):
rule = MOCK_RULE_2
# Using a variable in pattern, referencing an inexistent datastore value
- rule.criteria = {'trigger.p1': {
- 'type': 'equals',
- 'pattern': '{{ st2kv.system.inexistent_value }}'}
+ rule.criteria = {
+ "trigger.p1": {
+ "type": "equals",
+ "pattern": "{{ st2kv.system.inexistent_value }}",
+ }
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
self.assertFalse(f.filter())
# Using a variable in pattern, referencing an existing value which doesn't match
mock_result = MockSystemLookup()
- mock_result.test_value_1 = 'non matching'
+ mock_result.test_value_1 = "non matching"
mock_KeyValueLookup.return_value = mock_result
rule.criteria = {
- 'trigger.p1': {
- 'type': 'equals',
- 'pattern': '{{ st2kv.system.test_value_1 }}'
+ "trigger.p1": {
+ "type": "equals",
+ "pattern": "{{ st2kv.system.test_value_1 }}",
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
@@ -346,13 +375,13 @@ class MockSystemLookup(object):
# Using a variable in pattern, referencing an existing value which does match
mock_result = MockSystemLookup()
- mock_result.test_value_2 = 'v1'
+ mock_result.test_value_2 = "v1"
mock_KeyValueLookup.return_value = mock_result
rule.criteria = {
- 'trigger.p1': {
- 'type': 'equals',
- 'pattern': '{{ st2kv.system.test_value_2 }}'
+ "trigger.p1": {
+ "type": "equals",
+ "pattern": "{{ st2kv.system.test_value_2 }}",
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
@@ -360,13 +389,13 @@ class MockSystemLookup(object):
# Using a variable in pattern, referencing an existing value which matches partially
mock_result = MockSystemLookup()
- mock_result.test_value_3 = 'YYY'
+ mock_result.test_value_3 = "YYY"
mock_KeyValueLookup.return_value = mock_result
rule.criteria = {
- 'trigger.p2': {
- 'type': 'equals',
- 'pattern': '{{ st2kv.system.test_value_3 }}'
+ "trigger.p2": {
+ "type": "equals",
+ "pattern": "{{ st2kv.system.test_value_3 }}",
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
@@ -374,13 +403,13 @@ class MockSystemLookup(object):
# Using a variable in pattern, referencing an existing value which matches partially
mock_result = MockSystemLookup()
- mock_result.test_value_3 = 'YYY'
+ mock_result.test_value_3 = "YYY"
mock_KeyValueLookup.return_value = mock_result
rule.criteria = {
- 'trigger.p2': {
- 'type': 'equals',
- 'pattern': 'pre{{ st2kv.system.test_value_3 }}post'
+ "trigger.p2": {
+ "type": "equals",
+ "pattern": "pre{{ st2kv.system.test_value_3 }}post",
}
}
f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule)
diff --git a/st2reactor/tests/unit/test_garbage_collector.py b/st2reactor/tests/unit/test_garbage_collector.py
index 31442e8eb3..93de6b25d0 100644
--- a/st2reactor/tests/unit/test_garbage_collector.py
+++ b/st2reactor/tests/unit/test_garbage_collector.py
@@ -21,43 +21,48 @@
from oslo_config import cfg
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2reactor.garbage_collector import base as garbage_collector
class GarbageCollectorServiceTest(unittest.TestCase):
-
def tearDown(self):
# Reset gc_max_idle_sec to 1 to re-enable GC for other tests.
- cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine')
+ cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine")
super(GarbageCollectorServiceTest, self).tearDown()
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_purge_action_executions',
- mock.MagicMock(return_value=None))
+ "_purge_action_executions",
+ mock.MagicMock(return_value=None),
+ )
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_purge_action_executions_output',
- mock.MagicMock(return_value=None))
+ "_purge_action_executions_output",
+ mock.MagicMock(return_value=None),
+ )
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_purge_trigger_instances',
- mock.MagicMock(return_value=None))
+ "_purge_trigger_instances",
+ mock.MagicMock(return_value=None),
+ )
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_timeout_inquiries',
- mock.MagicMock(return_value=None))
+ "_timeout_inquiries",
+ mock.MagicMock(return_value=None),
+ )
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_purge_orphaned_workflow_executions',
- mock.MagicMock(return_value=None))
+ "_purge_orphaned_workflow_executions",
+ mock.MagicMock(return_value=None),
+ )
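+    # All purge and timeout routines above are stubbed out so only the
+    # orphaned-workflow GC toggle is exercised by this test.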
def test_orphaned_workflow_executions_gc_enabled(self):
# Override gc_max_idle_sec with a value >= 1 to enable the routine. The config
# gc_max_idle_sec is assigned to _workflow_execution_max_idle, which gc checks
# to decide whether to run the routine.
- cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine')
+ cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine")
# Run the garbage collection.
gc = garbage_collector.GarbageCollectorService(sleep_delay=0)
@@ -70,29 +75,34 @@ def test_orphaned_workflow_executions_gc_enabled(self):
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_purge_action_executions',
- mock.MagicMock(return_value=None))
+ "_purge_action_executions",
+ mock.MagicMock(return_value=None),
+ )
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_purge_action_executions_output',
- mock.MagicMock(return_value=None))
+ "_purge_action_executions_output",
+ mock.MagicMock(return_value=None),
+ )
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_purge_trigger_instances',
- mock.MagicMock(return_value=None))
+ "_purge_trigger_instances",
+ mock.MagicMock(return_value=None),
+ )
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_timeout_inquiries',
- mock.MagicMock(return_value=None))
+ "_timeout_inquiries",
+ mock.MagicMock(return_value=None),
+ )
@mock.patch.object(
garbage_collector.GarbageCollectorService,
- '_purge_orphaned_workflow_executions',
- mock.MagicMock(return_value=None))
+ "_purge_orphaned_workflow_executions",
+ mock.MagicMock(return_value=None),
+ )
def test_orphaned_workflow_executions_gc_disabled(self):
# Override gc_max_idle_sec with a value of 0 to disable the routine. The config
# gc_max_idle_sec is assigned to _workflow_execution_max_idle, which gc checks
# to decide whether to run the routine.
- cfg.CONF.set_override('gc_max_idle_sec', 0, group='workflow_engine')
+ cfg.CONF.set_override("gc_max_idle_sec", 0, group="workflow_engine")
# Run the garbage collection.
gc = garbage_collector.GarbageCollectorService(sleep_delay=0)
diff --git a/st2reactor/tests/unit/test_hash_partitioner.py b/st2reactor/tests/unit/test_hash_partitioner.py
index 4412c07b97..12e522a10c 100644
--- a/st2reactor/tests/unit/test_hash_partitioner.py
+++ b/st2reactor/tests/unit/test_hash_partitioner.py
@@ -22,10 +22,8 @@
from st2tests import DbTestCase
from st2tests.fixturesloader import FixturesLoader
-PACK = 'generic'
-FIXTURES_1 = {
- 'sensors': ['sensor1.yaml', 'sensor2.yaml', 'sensor3.yaml']
-}
+PACK = "generic"
+FIXTURES_1 = {"sensors": ["sensor1.yaml", "sensor2.yaml", "sensor3.yaml"]}
class HashPartitionerTest(DbTestCase):
@@ -38,39 +36,42 @@ def setUpClass(cls):
# Create TriggerTypes before creation of Rule to avoid failure. Rule requires the
# Trigger and therefore TriggerType to be created prior to rule creation.
cls.models = FixturesLoader().save_fixtures_to_db(
- fixtures_pack=PACK, fixtures_dict=FIXTURES_1)
+ fixtures_pack=PACK, fixtures_dict=FIXTURES_1
+ )
config.parse_args()
def test_full_range_hash_partitioner(self):
- partitioner = HashPartitioner('node1', 'MIN..MAX')
+ partitioner = HashPartitioner("node1", "MIN..MAX")
sensors = partitioner.get_sensors()
- self.assertEqual(len(sensors), 3, 'Expected all sensors')
+ self.assertEqual(len(sensors), 3, "Expected all sensors")
def test_multi_range_hash_partitioner(self):
range_third = int(Range.RANGE_MAX_VALUE / 3)
range_two_third = range_third * 2
- hash_ranges = \
- 'MIN..{range_third}|{range_third}..{range_two_third}|{range_two_third}..MAX'.format(
- range_third=range_third, range_two_third=range_two_third)
- partitioner = HashPartitioner('node1', hash_ranges)
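+        # hash_ranges is a pipe-delimited list of sub-ranges; together the three
+        # segments below cover the full MIN..MAX hash space.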
+ hash_ranges = "MIN..{range_third}|{range_third}..{range_two_third}|{range_two_third}..MAX".format(
+ range_third=range_third, range_two_third=range_two_third
+ )
+ partitioner = HashPartitioner("node1", hash_ranges)
sensors = partitioner.get_sensors()
- self.assertEqual(len(sensors), 3, 'Expected all sensors')
+ self.assertEqual(len(sensors), 3, "Expected all sensors")
def test_split_range_hash_partitioner(self):
range_mid = int(Range.RANGE_MAX_VALUE / 2)
- partitioner = HashPartitioner('node1', 'MIN..%s' % range_mid)
+ partitioner = HashPartitioner("node1", "MIN..%s" % range_mid)
sensors1 = partitioner.get_sensors()
- partitioner = HashPartitioner('node2', '%s..MAX' % range_mid)
+ partitioner = HashPartitioner("node2", "%s..MAX" % range_mid)
sensors2 = partitioner.get_sensors()
- self.assertEqual(len(sensors1) + len(sensors2), 3, 'Expected all sensors')
+ self.assertEqual(len(sensors1) + len(sensors2), 3, "Expected all sensors")
def test_hash_effectiveness(self):
range_third = int(Range.RANGE_MAX_VALUE / 3)
- partitioner1 = HashPartitioner('node1', 'MIN..%s' % range_third)
- partitioner2 = HashPartitioner('node2', '%s..%s' % (range_third, range_third + range_third))
- partitioner3 = HashPartitioner('node2', '%s..MAX' % (range_third + range_third))
+ partitioner1 = HashPartitioner("node1", "MIN..%s" % range_third)
+ partitioner2 = HashPartitioner(
+ "node2", "%s..%s" % (range_third, range_third + range_third)
+ )
+ partitioner3 = HashPartitioner("node2", "%s..MAX" % (range_third + range_third))
refs_count = 1000
@@ -89,15 +90,21 @@ def test_hash_effectiveness(self):
if partitioner3._is_in_hash_range(ref):
p3_count += 1
- self.assertEqual(p1_count + p2_count + p3_count, refs_count,
- 'Sum should equal all sensors.')
+ self.assertEqual(
+ p1_count + p2_count + p3_count, refs_count, "Sum should equal all sensors."
+ )
# Test effectiveness by checking if the sd is within 20% of mean
mean = refs_count / 3
- variance = float((p1_count - mean)**2 + (p1_count - mean)**2 + (p3_count - mean)**2) / 3
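+        # Sample variance of the three partition counts around the ideal mean.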
+        variance = (
+            float(
+                (p1_count - mean) ** 2 + (p2_count - mean) ** 2 + (p3_count - mean) ** 2
+            )
+            / 3
+        )
sd = math.sqrt(variance)
- self.assertTrue(sd / mean <= 0.2, 'Some values deviate too much from the mean.')
+ self.assertTrue(sd / mean <= 0.2, "Some values deviate too much from the mean.")
def _generate_refs(self, count=10):
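+        # Yield pseudo sensor refs of the form "<word>.<word>" so the
+        # partitioners hash a realistic spread of keys.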
random_word_count = int(math.sqrt(count)) + 1
@@ -105,7 +112,7 @@ def _generate_refs(self, count=10):
x_index = 0
y_index = 0
while count > 0:
- yield '%s.%s' % (words[x_index], words[y_index])
+ yield "%s.%s" % (words[x_index], words[y_index])
if y_index < len(words) - 1:
y_index += 1
else:
diff --git a/st2reactor/tests/unit/test_partitioners.py b/st2reactor/tests/unit/test_partitioners.py
index 00e7681cc9..8c4213ec5b 100644
--- a/st2reactor/tests/unit/test_partitioners.py
+++ b/st2reactor/tests/unit/test_partitioners.py
@@ -16,8 +16,11 @@
from __future__ import absolute_import
from oslo_config import cfg
-from st2common.constants.sensors import KVSTORE_PARTITION_LOADER, FILE_PARTITION_LOADER, \
- HASH_PARTITION_LOADER
+from st2common.constants.sensors import (
+ KVSTORE_PARTITION_LOADER,
+ FILE_PARTITION_LOADER,
+ HASH_PARTITION_LOADER,
+)
from st2common.models.db.keyvalue import KeyValuePairDB
from st2common.persistence.keyvalue import KeyValuePair
from st2reactor.container.partitioner_lookup import get_sensors_partitioner
@@ -26,10 +29,8 @@
from st2tests import DbTestCase
from st2tests.fixturesloader import FixturesLoader
-PACK = 'generic'
-FIXTURES_1 = {
- 'sensors': ['sensor1.yaml', 'sensor2.yaml', 'sensor3.yaml']
-}
+PACK = "generic"
+FIXTURES_1 = {"sensors": ["sensor1.yaml", "sensor2.yaml", "sensor3.yaml"]}
class PartitionerTest(DbTestCase):
@@ -42,76 +43,91 @@ def setUpClass(cls):
# Create TriggerTypes before creation of Rule to avoid failure. Rule requires the
# Trigger and therefore TriggerType to be created prior to rule creation.
cls.models = FixturesLoader().save_fixtures_to_db(
- fixtures_pack=PACK, fixtures_dict=FIXTURES_1)
+ fixtures_pack=PACK, fixtures_dict=FIXTURES_1
+ )
config.parse_args()
def test_default_partitioner(self):
provider = get_sensors_partitioner()
sensors = provider.get_sensors()
- self.assertEqual(len(sensors), len(FIXTURES_1['sensors']),
- 'Failed to provider all sensors')
+ self.assertEqual(
+ len(sensors), len(FIXTURES_1["sensors"]), "Failed to provider all sensors"
+ )
- sensor1 = self.models['sensors']['sensor1.yaml']
+ sensor1 = self.models["sensors"]["sensor1.yaml"]
self.assertTrue(provider.is_sensor_owner(sensor1))
def test_kvstore_partitioner(self):
- cfg.CONF.set_override(name='partition_provider',
- override={'name': KVSTORE_PARTITION_LOADER},
- group='sensorcontainer')
- kvp = KeyValuePairDB(**{'name': 'sensornode1.sensor_partition',
- 'value': 'generic.Sensor1, generic.Sensor2'})
+ cfg.CONF.set_override(
+ name="partition_provider",
+ override={"name": KVSTORE_PARTITION_LOADER},
+ group="sensorcontainer",
+ )
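+        # The kvstore partitioner looks up a key named after the sensor node
+        # ("sensornode1.sensor_partition" here) whose value is a comma-separated
+        # list of sensor refs owned by that node.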
+ kvp = KeyValuePairDB(
+ **{
+ "name": "sensornode1.sensor_partition",
+ "value": "generic.Sensor1, generic.Sensor2",
+ }
+ )
KeyValuePair.add_or_update(kvp, publish=False, dispatch_trigger=False)
provider = get_sensors_partitioner()
sensors = provider.get_sensors()
- self.assertEqual(len(sensors), len(kvp.value.split(',')))
+ self.assertEqual(len(sensors), len(kvp.value.split(",")))
- sensor1 = self.models['sensors']['sensor1.yaml']
+ sensor1 = self.models["sensors"]["sensor1.yaml"]
self.assertTrue(provider.is_sensor_owner(sensor1))
- sensor3 = self.models['sensors']['sensor3.yaml']
+ sensor3 = self.models["sensors"]["sensor3.yaml"]
self.assertFalse(provider.is_sensor_owner(sensor3))
def test_file_partitioner(self):
partition_file = FixturesLoader().get_fixture_file_path_abs(
- fixtures_pack=PACK, fixtures_type='sensors', fixture_name='partition_file.yaml')
- cfg.CONF.set_override(name='partition_provider',
- override={'name': FILE_PARTITION_LOADER,
- 'partition_file': partition_file},
- group='sensorcontainer')
+ fixtures_pack=PACK,
+ fixtures_type="sensors",
+ fixture_name="partition_file.yaml",
+ )
+ cfg.CONF.set_override(
+ name="partition_provider",
+ override={"name": FILE_PARTITION_LOADER, "partition_file": partition_file},
+ group="sensorcontainer",
+ )
provider = get_sensors_partitioner()
sensors = provider.get_sensors()
self.assertEqual(len(sensors), 2)
- sensor1 = self.models['sensors']['sensor1.yaml']
+ sensor1 = self.models["sensors"]["sensor1.yaml"]
self.assertTrue(provider.is_sensor_owner(sensor1))
- sensor3 = self.models['sensors']['sensor3.yaml']
+ sensor3 = self.models["sensors"]["sensor3.yaml"]
self.assertFalse(provider.is_sensor_owner(sensor3))
def test_hash_partitioner(self):
# No specific partitioner testing here; for that, see test_hash_partitioner.py.
# This test just makes sure the wiring and the basics work.
- cfg.CONF.set_override(name='partition_provider',
- override={'name': HASH_PARTITION_LOADER,
- 'hash_ranges': '%s..%s' % (Range.RANGE_MIN_ENUM,
- Range.RANGE_MAX_ENUM)},
- group='sensorcontainer')
+ cfg.CONF.set_override(
+ name="partition_provider",
+ override={
+ "name": HASH_PARTITION_LOADER,
+ "hash_ranges": "%s..%s" % (Range.RANGE_MIN_ENUM, Range.RANGE_MAX_ENUM),
+ },
+ group="sensorcontainer",
+ )
provider = get_sensors_partitioner()
sensors = provider.get_sensors()
self.assertEqual(len(sensors), 3)
- sensor1 = self.models['sensors']['sensor1.yaml']
+ sensor1 = self.models["sensors"]["sensor1.yaml"]
self.assertTrue(provider.is_sensor_owner(sensor1))
- sensor2 = self.models['sensors']['sensor2.yaml']
+ sensor2 = self.models["sensors"]["sensor2.yaml"]
self.assertTrue(provider.is_sensor_owner(sensor2))
- sensor3 = self.models['sensors']['sensor3.yaml']
+ sensor3 = self.models["sensors"]["sensor3.yaml"]
self.assertTrue(provider.is_sensor_owner(sensor3))
diff --git a/st2reactor/tests/unit/test_process_container.py b/st2reactor/tests/unit/test_process_container.py
index 10ad700b8a..d1bcfdfe64 100644
--- a/st2reactor/tests/unit/test_process_container.py
+++ b/st2reactor/tests/unit/test_process_container.py
@@ -17,7 +17,7 @@
import os
import time
-from mock import (MagicMock, Mock, patch)
+from mock import MagicMock, Mock, patch
import unittest2
from st2reactor.container.process_container import ProcessSensorContainer
@@ -26,14 +26,18 @@
from st2common.persistence.pack import Pack
import st2tests.config as tests_config
+
tests_config.parse_args()
-MOCK_PACK_DB = PackDB(ref='wolfpack', name='wolf pack', description='',
- path='/opt/stackstorm/packs/wolfpack/')
+MOCK_PACK_DB = PackDB(
+ ref="wolfpack",
+ name="wolf pack",
+ description="",
+ path="/opt/stackstorm/packs/wolfpack/",
+)
class ProcessContainerTests(unittest2.TestCase):
-
def test_no_sensors_dont_quit(self):
process_container = ProcessSensorContainer(None, poll_interval=0.1)
process_container_thread = concurrency.spawn(process_container.run)
@@ -43,113 +47,133 @@ def test_no_sensors_dont_quit(self):
process_container.shutdown()
process_container_thread.kill()
- @patch.object(ProcessSensorContainer, '_get_sensor_id',
- MagicMock(return_value='wolfpack.StupidSensor'))
- @patch.object(ProcessSensorContainer, '_dispatch_trigger_for_sensor_spawn',
- MagicMock(return_value=None))
- @patch.object(Pack, 'get_by_ref', MagicMock(return_value=MOCK_PACK_DB))
- @patch.object(os.path, 'isdir', MagicMock(return_value=True))
- @patch('subprocess.Popen')
- @patch('st2reactor.container.process_container.create_token')
- def test_common_lib_path_in_pythonpath_env_var(self, mock_create_token, mock_subproc_popen):
+ @patch.object(
+ ProcessSensorContainer,
+ "_get_sensor_id",
+ MagicMock(return_value="wolfpack.StupidSensor"),
+ )
+ @patch.object(
+ ProcessSensorContainer,
+ "_dispatch_trigger_for_sensor_spawn",
+ MagicMock(return_value=None),
+ )
+ @patch.object(Pack, "get_by_ref", MagicMock(return_value=MOCK_PACK_DB))
+ @patch.object(os.path, "isdir", MagicMock(return_value=True))
+ @patch("subprocess.Popen")
+ @patch("st2reactor.container.process_container.create_token")
+ def test_common_lib_path_in_pythonpath_env_var(
+ self, mock_create_token, mock_subproc_popen
+ ):
process_mock = Mock()
- attrs = {'communicate.return_value': ('output', 'error')}
+ attrs = {"communicate.return_value": ("output", "error")}
process_mock.configure_mock(**attrs)
mock_subproc_popen.return_value = process_mock
mock_create_token = Mock()
- mock_create_token.return_value = 'WHOLETTHEDOGSOUT'
+ mock_create_token.return_value = "WHOLETTHEDOGSOUT"
mock_dispatcher = Mock()
- process_container = ProcessSensorContainer(None, poll_interval=0.1,
- dispatcher=mock_dispatcher)
+ process_container = ProcessSensorContainer(
+ None, poll_interval=0.1, dispatcher=mock_dispatcher
+ )
sensor = {
- 'class_name': 'wolfpack.StupidSensor',
- 'ref': 'wolfpack.StupidSensor',
- 'id': '567890',
- 'trigger_types': ['some_trigga'],
- 'pack': 'wolfpack',
- 'file_path': '/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py',
- 'poll_interval': 5
+ "class_name": "wolfpack.StupidSensor",
+ "ref": "wolfpack.StupidSensor",
+ "id": "567890",
+ "trigger_types": ["some_trigga"],
+ "pack": "wolfpack",
+ "file_path": "/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py",
+ "poll_interval": 5,
}
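+        # With common pack libs enabled, the sensor subprocess PYTHONPATH should
+        # include the pack-local lib/ directory.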
process_container._enable_common_pack_libs = True
- process_container._sensors = {'pack.StupidSensor': sensor}
+ process_container._sensors = {"pack.StupidSensor": sensor}
process_container._spawn_sensor_process(sensor)
_, call_kwargs = mock_subproc_popen.call_args
- actual_env = call_kwargs['env']
- self.assertIn('PYTHONPATH', actual_env)
- pack_common_lib_path = '/opt/stackstorm/packs/wolfpack/lib'
- self.assertIn(pack_common_lib_path, actual_env['PYTHONPATH'])
-
- @patch.object(ProcessSensorContainer, '_get_sensor_id',
- MagicMock(return_value='wolfpack.StupidSensor'))
- @patch.object(ProcessSensorContainer, '_dispatch_trigger_for_sensor_spawn',
- MagicMock(return_value=None))
- @patch.object(Pack, 'get_by_ref', MagicMock(return_value=MOCK_PACK_DB))
- @patch.object(os.path, 'isdir', MagicMock(return_value=True))
- @patch('subprocess.Popen')
- @patch('st2reactor.container.process_container.create_token')
- def test_common_lib_path_not_in_pythonpath_env_var(self, mock_create_token, mock_subproc_popen):
+ actual_env = call_kwargs["env"]
+ self.assertIn("PYTHONPATH", actual_env)
+ pack_common_lib_path = "/opt/stackstorm/packs/wolfpack/lib"
+ self.assertIn(pack_common_lib_path, actual_env["PYTHONPATH"])
+
+ @patch.object(
+ ProcessSensorContainer,
+ "_get_sensor_id",
+ MagicMock(return_value="wolfpack.StupidSensor"),
+ )
+ @patch.object(
+ ProcessSensorContainer,
+ "_dispatch_trigger_for_sensor_spawn",
+ MagicMock(return_value=None),
+ )
+ @patch.object(Pack, "get_by_ref", MagicMock(return_value=MOCK_PACK_DB))
+ @patch.object(os.path, "isdir", MagicMock(return_value=True))
+ @patch("subprocess.Popen")
+ @patch("st2reactor.container.process_container.create_token")
+ def test_common_lib_path_not_in_pythonpath_env_var(
+ self, mock_create_token, mock_subproc_popen
+ ):
process_mock = Mock()
- attrs = {'communicate.return_value': ('output', 'error')}
+ attrs = {"communicate.return_value": ("output", "error")}
process_mock.configure_mock(**attrs)
mock_subproc_popen.return_value = process_mock
mock_create_token = Mock()
- mock_create_token.return_value = 'WHOLETTHEDOGSOUT'
+ mock_create_token.return_value = "WHOLETTHEDOGSOUT"
mock_dispatcher = Mock()
- process_container = ProcessSensorContainer(None, poll_interval=0.1,
- dispatcher=mock_dispatcher)
+ process_container = ProcessSensorContainer(
+ None, poll_interval=0.1, dispatcher=mock_dispatcher
+ )
sensor = {
- 'class_name': 'wolfpack.StupidSensor',
- 'ref': 'wolfpack.StupidSensor',
- 'id': '567890',
- 'trigger_types': ['some_trigga'],
- 'pack': 'wolfpack',
- 'file_path': '/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py',
- 'poll_interval': 5
+ "class_name": "wolfpack.StupidSensor",
+ "ref": "wolfpack.StupidSensor",
+ "id": "567890",
+ "trigger_types": ["some_trigga"],
+ "pack": "wolfpack",
+ "file_path": "/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py",
+ "poll_interval": 5,
}
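+        # With common pack libs disabled, the pack-local lib/ directory must not
+        # leak into the sensor subprocess PYTHONPATH.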
process_container._enable_common_pack_libs = False
- process_container._sensors = {'pack.StupidSensor': sensor}
+ process_container._sensors = {"pack.StupidSensor": sensor}
process_container._spawn_sensor_process(sensor)
_, call_kwargs = mock_subproc_popen.call_args
- actual_env = call_kwargs['env']
- self.assertIn('PYTHONPATH', actual_env)
- pack_common_lib_path = '/opt/stackstorm/packs/wolfpack/lib'
- self.assertNotIn(pack_common_lib_path, actual_env['PYTHONPATH'])
+ actual_env = call_kwargs["env"]
+ self.assertIn("PYTHONPATH", actual_env)
+ pack_common_lib_path = "/opt/stackstorm/packs/wolfpack/lib"
+ self.assertNotIn(pack_common_lib_path, actual_env["PYTHONPATH"])
- @patch.object(time, 'time', MagicMock(return_value=1439441533))
+ @patch.object(time, "time", MagicMock(return_value=1439441533))
def test_dispatch_triggers_on_spawn_exit(self):
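+        # time.time is patched to a fixed epoch value so the dispatched payload
+        # timestamps can be asserted exactly.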
mock_dispatcher = Mock()
- process_container = ProcessSensorContainer(None, poll_interval=0.1,
- dispatcher=mock_dispatcher)
- sensor = {
- 'class_name': 'pack.StupidSensor'
- }
+ process_container = ProcessSensorContainer(
+ None, poll_interval=0.1, dispatcher=mock_dispatcher
+ )
+ sensor = {"class_name": "pack.StupidSensor"}
process = Mock()
- process_attrs = {'pid': 1234}
+ process_attrs = {"pid": 1234}
process.configure_mock(**process_attrs)
- cmd = 'sensor_wrapper.py --class-name pack.StupidSensor'
+ cmd = "sensor_wrapper.py --class-name pack.StupidSensor"
process_container._dispatch_trigger_for_sensor_spawn(sensor, process, cmd)
mock_dispatcher.dispatch.assert_called_with(
- 'core.st2.sensor.process_spawn',
+ "core.st2.sensor.process_spawn",
payload={
- 'timestamp': 1439441533,
- 'cmd': 'sensor_wrapper.py --class-name pack.StupidSensor',
- 'pid': 1234,
- 'id': 'pack.StupidSensor'})
+ "timestamp": 1439441533,
+ "cmd": "sensor_wrapper.py --class-name pack.StupidSensor",
+ "pid": 1234,
+ "id": "pack.StupidSensor",
+ },
+ )
process_container._dispatch_trigger_for_sensor_exit(sensor, 1)
mock_dispatcher.dispatch.assert_called_with(
- 'core.st2.sensor.process_exit',
+ "core.st2.sensor.process_exit",
payload={
- 'id': 'pack.StupidSensor',
- 'timestamp': 1439441533,
- 'exit_code': 1
- })
+ "id": "pack.StupidSensor",
+ "timestamp": 1439441533,
+ "exit_code": 1,
+ },
+ )
diff --git a/st2reactor/tests/unit/test_rule_engine.py b/st2reactor/tests/unit/test_rule_engine.py
index 39b1627268..2f70a2a9d7 100644
--- a/st2reactor/tests/unit/test_rule_engine.py
+++ b/st2reactor/tests/unit/test_rule_engine.py
@@ -18,9 +18,9 @@
from mongoengine import NotUniqueError
from st2common.models.api.rule import RuleAPI
-from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB)
+from st2common.models.db.trigger import TriggerDB, TriggerTypeDB
from st2common.persistence.rule import Rule
-from st2common.persistence.trigger import (TriggerType, Trigger)
+from st2common.persistence.trigger import TriggerType, Trigger
from st2common.util import date as date_utils
import st2reactor.container.utils as container_utils
from st2reactor.rules.enforcer import RuleEnforcer
@@ -29,30 +29,29 @@
class RuleEngineTest(DbTestCase):
-
@classmethod
def setUpClass(cls):
super(RuleEngineTest, cls).setUpClass()
RuleEngineTest._setup_test_models()
- @mock.patch.object(RuleEnforcer, 'enforce', mock.MagicMock(return_value=True))
+ @mock.patch.object(RuleEnforcer, "enforce", mock.MagicMock(return_value=True))
def test_handle_trigger_instances(self):
trigger_instance_1 = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger1',
- {'k1': 't1_p_v', 'k2': 'v2'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger1",
+ {"k1": "t1_p_v", "k2": "v2"},
+ date_utils.get_datetime_utc_now(),
)
trigger_instance_2 = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger1',
- {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger1",
+ {"k1": "t1_p_v", "k2": "v2", "k3": "v3"},
+ date_utils.get_datetime_utc_now(),
)
trigger_instance_3 = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger2',
- {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger2",
+ {"k1": "t1_p_v", "k2": "v2", "k3": "v3"},
+ date_utils.get_datetime_utc_now(),
)
instances = [trigger_instance_1, trigger_instance_2, trigger_instance_3]
rules_engine = RulesEngine()
@@ -60,32 +59,36 @@ def test_handle_trigger_instances(self):
rules_engine.handle_trigger_instance(instance)
def test_create_trigger_instance_for_trigger_with_params(self):
- trigger = {'type': 'dummy_pack_1.st2.test.trigger4', 'parameters': {'url': 'sample'}}
- payload = {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'}
+ trigger = {
+ "type": "dummy_pack_1.st2.test.trigger4",
+ "parameters": {"url": "sample"},
+ }
+ payload = {"k1": "t1_p_v", "k2": "v2", "k3": "v3"}
occurrence_time = date_utils.get_datetime_utc_now()
- trigger_instance = container_utils.create_trigger_instance(trigger=trigger,
- payload=payload,
- occurrence_time=occurrence_time)
+ trigger_instance = container_utils.create_trigger_instance(
+ trigger=trigger, payload=payload, occurrence_time=occurrence_time
+ )
self.assertTrue(trigger_instance)
- self.assertEqual(trigger_instance.trigger, trigger['type'])
+ self.assertEqual(trigger_instance.trigger, trigger["type"])
self.assertEqual(trigger_instance.payload, payload)
def test_get_matching_rules_filters_disabled_rules(self):
trigger_instance = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger1',
- {'k1': 't1_p_v', 'k2': 'v2'}, date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger1",
+ {"k1": "t1_p_v", "k2": "v2"},
+ date_utils.get_datetime_utc_now(),
)
rules_engine = RulesEngine()
matching_rules = rules_engine.get_matching_rules_for_trigger(trigger_instance)
- expected_rules = ['st2.test.rule2']
+ expected_rules = ["st2.test.rule2"]
for rule in matching_rules:
self.assertIn(rule.name, expected_rules)
def test_handle_trigger_instance_no_rules(self):
trigger_instance = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger3',
- {'k1': 't1_p_v', 'k2': 'v2'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger3",
+ {"k1": "t1_p_v", "k2": "v2"},
+ date_utils.get_datetime_utc_now(),
)
rules_engine = RulesEngine()
rules_engine.handle_trigger_instance(trigger_instance) # should not throw.
@@ -96,14 +99,26 @@ def _setup_test_models(cls):
RuleEngineTest._setup_sample_rules()
@classmethod
- def _setup_sample_triggers(self, names=['st2.test.trigger1', 'st2.test.trigger2',
- 'st2.test.trigger3', 'st2.test.trigger4']):
+ def _setup_sample_triggers(
+ self,
+ names=[
+ "st2.test.trigger1",
+ "st2.test.trigger2",
+ "st2.test.trigger3",
+ "st2.test.trigger4",
+ ],
+ ):
trigger_dbs = []
for name in names:
trigtype = None
try:
- trigtype = TriggerTypeDB(pack='dummy_pack_1', name=name, description='',
- payload_schema={}, parameters_schema={})
+ trigtype = TriggerTypeDB(
+ pack="dummy_pack_1",
+ name=name,
+ description="",
+ payload_schema={},
+ parameters_schema={},
+ )
try:
trigtype = TriggerType.get_by_name(name)
except:
@@ -111,11 +126,15 @@ def _setup_sample_triggers(self, names=['st2.test.trigger1', 'st2.test.trigger2'
except NotUniqueError:
pass
- created = TriggerDB(pack='dummy_pack_1', name=name, description='',
- type=trigtype.get_reference().ref)
+ created = TriggerDB(
+ pack="dummy_pack_1",
+ name=name,
+ description="",
+ type=trigtype.get_reference().ref,
+ )
- if name in ['st2.test.trigger4']:
- created.parameters = {'url': 'sample'}
+ if name in ["st2.test.trigger4"]:
+ created.parameters = {"url": "sample"}
else:
created.parameters = {}
@@ -130,55 +149,40 @@ def _setup_sample_rules(self):
# Rules for st2.test.trigger1
RULE_1 = {
- 'enabled': True,
- 'name': 'st2.test.rule1',
- 'pack': 'sixpack',
- 'trigger': {
- 'type': 'dummy_pack_1.st2.test.trigger1'
- },
- 'criteria': {
- 'k1': { # Missing prefix 'trigger'. This rule won't match.
- 'pattern': 't1_p_v',
- 'type': 'equals'
+ "enabled": True,
+ "name": "st2.test.rule1",
+ "pack": "sixpack",
+ "trigger": {"type": "dummy_pack_1.st2.test.trigger1"},
+ "criteria": {
+ "k1": { # Missing prefix 'trigger'. This rule won't match.
+ "pattern": "t1_p_v",
+ "type": "equals",
}
},
- 'action': {
- 'ref': 'sixpack.st2.test.action',
- 'parameters': {
- 'ip2': '{{rule.k1}}',
- 'ip1': '{{trigger.t1_p}}'
- }
+ "action": {
+ "ref": "sixpack.st2.test.action",
+ "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"},
},
- 'id': '23',
- 'description': ''
+ "id": "23",
+ "description": "",
}
rule_api = RuleAPI(**RULE_1)
rule_db = RuleAPI.to_model(rule_api)
rule_db = Rule.add_or_update(rule_db)
rules.append(rule_db)
- RULE_2 = { # Rule should match.
- 'enabled': True,
- 'name': 'st2.test.rule2',
- 'pack': 'sixpack',
- 'trigger': {
- 'type': 'dummy_pack_1.st2.test.trigger1'
- },
- 'criteria': {
- 'trigger.k1': {
- 'pattern': 't1_p_v',
- 'type': 'equals'
- }
- },
- 'action': {
- 'ref': 'sixpack.st2.test.action',
- 'parameters': {
- 'ip2': '{{rule.k1}}',
- 'ip1': '{{trigger.t1_p}}'
- }
+ RULE_2 = { # Rule should match.
+ "enabled": True,
+ "name": "st2.test.rule2",
+ "pack": "sixpack",
+ "trigger": {"type": "dummy_pack_1.st2.test.trigger1"},
+ "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}},
+ "action": {
+ "ref": "sixpack.st2.test.action",
+ "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"},
},
- 'id': '23',
- 'description': ''
+ "id": "23",
+ "description": "",
}
rule_api = RuleAPI(**RULE_2)
rule_db = RuleAPI.to_model(rule_api)
@@ -186,27 +190,17 @@ def _setup_sample_rules(self):
rules.append(rule_db)
RULE_3 = {
- 'enabled': False, # Disabled rule shouldn't match.
- 'name': 'st2.test.rule3',
- 'pack': 'sixpack',
- 'trigger': {
- 'type': 'dummy_pack_1.st2.test.trigger1'
- },
- 'criteria': {
- 'trigger.k1': {
- 'pattern': 't1_p_v',
- 'type': 'equals'
- }
- },
- 'action': {
- 'ref': 'sixpack.st2.test.action',
- 'parameters': {
- 'ip2': '{{rule.k1}}',
- 'ip1': '{{trigger.t1_p}}'
- }
+ "enabled": False, # Disabled rule shouldn't match.
+ "name": "st2.test.rule3",
+ "pack": "sixpack",
+ "trigger": {"type": "dummy_pack_1.st2.test.trigger1"},
+ "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}},
+ "action": {
+ "ref": "sixpack.st2.test.action",
+ "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"},
},
- 'id': '23',
- 'description': ''
+ "id": "23",
+ "description": "",
}
rule_api = RuleAPI(**RULE_3)
rule_db = RuleAPI.to_model(rule_api)
@@ -215,27 +209,17 @@ def _setup_sample_rules(self):
# Rules for st2.test.trigger2
RULE_4 = {
- 'enabled': True,
- 'name': 'st2.test.rule4',
- 'pack': 'sixpack',
- 'trigger': {
- 'type': 'dummy_pack_1.st2.test.trigger2'
- },
- 'criteria': {
- 'trigger.k1': {
- 'pattern': 't1_p_v',
- 'type': 'equals'
- }
- },
- 'action': {
- 'ref': 'sixpack.st2.test.action',
- 'parameters': {
- 'ip2': '{{rule.k1}}',
- 'ip1': '{{trigger.t1_p}}'
- }
+ "enabled": True,
+ "name": "st2.test.rule4",
+ "pack": "sixpack",
+ "trigger": {"type": "dummy_pack_1.st2.test.trigger2"},
+ "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}},
+ "action": {
+ "ref": "sixpack.st2.test.action",
+ "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"},
},
- 'id': '23',
- 'description': ''
+ "id": "23",
+ "description": "",
}
rule_api = RuleAPI(**RULE_4)
rule_db = RuleAPI.to_model(rule_api)
diff --git a/st2reactor/tests/unit/test_rule_matcher.py b/st2reactor/tests/unit/test_rule_matcher.py
index a5680fa094..46cc084662 100644
--- a/st2reactor/tests/unit/test_rule_matcher.py
+++ b/st2reactor/tests/unit/test_rule_matcher.py
@@ -19,9 +19,9 @@
import mock
from st2common.models.api.rule import RuleAPI
-from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB)
+from st2common.models.db.trigger import TriggerDB, TriggerTypeDB
from st2common.persistence.rule import Rule
-from st2common.persistence.trigger import (TriggerType, Trigger)
+from st2common.persistence.trigger import TriggerType, Trigger
from st2common.services.triggers import get_trigger_db_by_ref
from st2common.util import date as date_utils
import st2reactor.container.utils as container_utils
@@ -33,106 +33,68 @@
from st2tests.base import CleanDbTestCase
from st2tests.fixturesloader import FixturesLoader
-__all__ = [
- 'RuleMatcherTestCase',
- 'BackstopRuleMatcherTestCase'
-]
+__all__ = ["RuleMatcherTestCase", "BackstopRuleMatcherTestCase"]
# Mock rules
RULE_1 = {
- 'enabled': True,
- 'name': 'st2.test.rule1',
- 'pack': 'yoyohoneysingh',
- 'trigger': {
- 'type': 'dummy_pack_1.st2.test.trigger1'
- },
- 'criteria': {
- 'k1': { # Missing prefix 'trigger'. This rule won't match.
- 'pattern': 't1_p_v',
- 'type': 'equals'
+ "enabled": True,
+ "name": "st2.test.rule1",
+ "pack": "yoyohoneysingh",
+ "trigger": {"type": "dummy_pack_1.st2.test.trigger1"},
+ "criteria": {
+ "k1": { # Missing prefix 'trigger'. This rule won't match.
+ "pattern": "t1_p_v",
+ "type": "equals",
}
},
- 'action': {
- 'ref': 'sixpack.st2.test.action',
- 'parameters': {
- 'ip2': '{{rule.k1}}',
- 'ip1': '{{trigger.t1_p}}'
- }
+ "action": {
+ "ref": "sixpack.st2.test.action",
+ "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"},
},
- 'id': '23',
- 'description': ''
+ "id": "23",
+ "description": "",
}
-RULE_2 = { # Rule should match.
- 'enabled': True,
- 'name': 'st2.test.rule2',
- 'pack': 'yoyohoneysingh',
- 'trigger': {
- 'type': 'dummy_pack_1.st2.test.trigger1'
- },
- 'criteria': {
- 'trigger.k1': {
- 'pattern': 't1_p_v',
- 'type': 'equals'
- }
+RULE_2 = { # Rule should match.
+ "enabled": True,
+ "name": "st2.test.rule2",
+ "pack": "yoyohoneysingh",
+ "trigger": {"type": "dummy_pack_1.st2.test.trigger1"},
+ "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}},
+ "action": {
+ "ref": "sixpack.st2.test.action",
+ "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"},
},
- 'action': {
- 'ref': 'sixpack.st2.test.action',
- 'parameters': {
- 'ip2': '{{rule.k1}}',
- 'ip1': '{{trigger.t1_p}}'
- }
- },
- 'id': '23',
- 'description': ''
+ "id": "23",
+ "description": "",
}
RULE_3 = {
- 'enabled': False, # Disabled rule shouldn't match.
- 'name': 'st2.test.rule3',
- 'pack': 'yoyohoneysingh',
- 'trigger': {
- 'type': 'dummy_pack_1.st2.test.trigger1'
+ "enabled": False, # Disabled rule shouldn't match.
+ "name": "st2.test.rule3",
+ "pack": "yoyohoneysingh",
+ "trigger": {"type": "dummy_pack_1.st2.test.trigger1"},
+ "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}},
+ "action": {
+ "ref": "sixpack.st2.test.action",
+ "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"},
},
- 'criteria': {
- 'trigger.k1': {
- 'pattern': 't1_p_v',
- 'type': 'equals'
- }
- },
- 'action': {
- 'ref': 'sixpack.st2.test.action',
- 'parameters': {
- 'ip2': '{{rule.k1}}',
- 'ip1': '{{trigger.t1_p}}'
- }
- },
- 'id': '23',
- 'description': ''
+ "id": "23",
+ "description": "",
}
-RULE_4 = { # Rule should match.
- 'enabled': True,
- 'name': 'st2.test.rule4',
- 'pack': 'yoyohoneysingh',
- 'trigger': {
- 'type': 'dummy_pack_1.st2.test.trigger4'
+RULE_4 = { # Rule should match.
+ "enabled": True,
+ "name": "st2.test.rule4",
+ "pack": "yoyohoneysingh",
+ "trigger": {"type": "dummy_pack_1.st2.test.trigger4"},
+ "criteria": {"trigger.k1": {"pattern": "t2_p_v", "type": "equals"}},
+ "action": {
+ "ref": "sixpack.st2.test.action",
+ "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"},
},
- 'criteria': {
- 'trigger.k1': {
- 'pattern': 't2_p_v',
- 'type': 'equals'
- }
- },
- 'action': {
- 'ref': 'sixpack.st2.test.action',
- 'parameters': {
- 'ip2': '{{rule.k1}}',
- 'ip1': '{{trigger.t1_p}}'
- }
- },
- 'id': '23',
- 'description': ''
+ "id": "23",
+ "description": "",
}
@@ -140,15 +102,15 @@ class RuleMatcherTestCase(CleanDbTestCase):
rules = []
def test_get_matching_rules(self):
- self._setup_sample_trigger('st2.test.trigger1')
+ self._setup_sample_trigger("st2.test.trigger1")
rule_db_1 = self._setup_sample_rule(RULE_1)
rule_db_2 = self._setup_sample_rule(RULE_2)
rule_db_3 = self._setup_sample_rule(RULE_3)
rules = [rule_db_1, rule_db_2, rule_db_3]
trigger_instance = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger1',
- {'k1': 't1_p_v', 'k2': 'v2'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger1",
+ {"k1": "t1_p_v", "k2": "v2"},
+ date_utils.get_datetime_utc_now(),
)
trigger = get_trigger_db_by_ref(trigger_instance.trigger)
@@ -159,17 +121,22 @@ def test_get_matching_rules(self):
def test_trigger_instance_payload_with_special_values(self):
# Test a rule where TriggerInstance payload contains a dot (".") and $
- self._setup_sample_trigger('st2.test.trigger1')
- self._setup_sample_trigger('st2.test.trigger2')
+ self._setup_sample_trigger("st2.test.trigger1")
+ self._setup_sample_trigger("st2.test.trigger2")
rule_db_1 = self._setup_sample_rule(RULE_1)
rule_db_2 = self._setup_sample_rule(RULE_2)
rule_db_3 = self._setup_sample_rule(RULE_3)
rules = [rule_db_1, rule_db_2, rule_db_3]
trigger_instance = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger2',
- {'k1': 't1_p_v', 'k2.k2': 'v2', 'k3.more.nested.deep': 'some.value',
- 'k4.even.more.nested$': 'foo', 'yep$aaa': 'b'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger2",
+ {
+ "k1": "t1_p_v",
+ "k2.k2": "v2",
+ "k3.more.nested.deep": "some.value",
+ "k4.even.more.nested$": "foo",
+ "yep$aaa": "b",
+ },
+ date_utils.get_datetime_utc_now(),
)
trigger = get_trigger_db_by_ref(trigger_instance.trigger)
@@ -178,20 +145,22 @@ def test_trigger_instance_payload_with_special_values(self):
self.assertIsNotNone(matching_rules)
self.assertEqual(len(matching_rules), 1)
- @mock.patch('st2reactor.rules.matcher.RuleFilter._render_criteria_pattern',
- mock.Mock(side_effect=Exception('exception in _render_criteria_pattern')))
+ @mock.patch(
+ "st2reactor.rules.matcher.RuleFilter._render_criteria_pattern",
+ mock.Mock(side_effect=Exception("exception in _render_criteria_pattern")),
+ )
def test_rule_enforcement_is_created_on_exception_1(self):
# 1. Exception in _render_criteria_pattern
rule_enforcement_dbs = list(RuleEnforcement.get_all())
self.assertEqual(rule_enforcement_dbs, [])
- self._setup_sample_trigger('st2.test.trigger4')
+ self._setup_sample_trigger("st2.test.trigger4")
rule_4_db = self._setup_sample_rule(RULE_4)
rules = [rule_4_db]
trigger_instance = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger4',
- {'k1': 't2_p_v', 'k2': 'v2'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger4",
+ {"k1": "t2_p_v", "k2": "v2"},
+ date_utils.get_datetime_utc_now(),
)
trigger = get_trigger_db_by_ref(trigger_instance.trigger)
@@ -203,29 +172,35 @@ def test_rule_enforcement_is_created_on_exception_1(self):
rule_enforcement_dbs = list(RuleEnforcement.get_all())
self.assertEqual(len(rule_enforcement_dbs), 1)
- expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger '
- 'instance "%s": Failed to render pattern value "t2_p_v" for key '
- '"trigger.k1": exception in _render_criteria_pattern' %
- (str(trigger_instance.id)))
+ expected_failure = (
+ 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger '
+ 'instance "%s": Failed to render pattern value "t2_p_v" for key '
+ '"trigger.k1": exception in _render_criteria_pattern'
+ % (str(trigger_instance.id))
+ )
self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure)
- self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id))
- self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id))
+ self.assertEqual(
+ rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)
+ )
+ self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id))
self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED)
- @mock.patch('st2reactor.rules.filter.PayloadLookup.get_value',
- mock.Mock(side_effect=Exception('exception in get_value')))
+ @mock.patch(
+ "st2reactor.rules.filter.PayloadLookup.get_value",
+ mock.Mock(side_effect=Exception("exception in get_value")),
+ )
def test_rule_enforcement_is_created_on_exception_2(self):
# 2. Exception in payload_lookup.get_value
rule_enforcement_dbs = list(RuleEnforcement.get_all())
self.assertEqual(rule_enforcement_dbs, [])
- self._setup_sample_trigger('st2.test.trigger4')
+ self._setup_sample_trigger("st2.test.trigger4")
rule_4_db = self._setup_sample_rule(RULE_4)
rules = [rule_4_db]
trigger_instance = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger4',
- {'k1': 't2_p_v', 'k2': 'v2'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger4",
+ {"k1": "t2_p_v", "k2": "v2"},
+ date_utils.get_datetime_utc_now(),
)
trigger = get_trigger_db_by_ref(trigger_instance.trigger)
@@ -237,28 +212,34 @@ def test_rule_enforcement_is_created_on_exception_2(self):
rule_enforcement_dbs = list(RuleEnforcement.get_all())
self.assertEqual(len(rule_enforcement_dbs), 1)
- expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger '
- 'instance "%s": Failed transforming criteria key trigger.k1: '
- 'exception in get_value' % (str(trigger_instance.id)))
+ expected_failure = (
+ 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger '
+ 'instance "%s": Failed transforming criteria key trigger.k1: '
+ "exception in get_value" % (str(trigger_instance.id))
+ )
self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure)
- self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id))
- self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id))
+ self.assertEqual(
+ rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)
+ )
+ self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id))
self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED)
- @mock.patch('st2common.operators.get_operator',
- mock.Mock(return_value=mock.Mock(side_effect=Exception('exception in equals'))))
+ @mock.patch(
+ "st2common.operators.get_operator",
+ mock.Mock(return_value=mock.Mock(side_effect=Exception("exception in equals"))),
+ )
def test_rule_enforcement_is_created_on_exception_3(self):
# 3. Exception raised by the mocked operator function (equals)
rule_enforcement_dbs = list(RuleEnforcement.get_all())
self.assertEqual(rule_enforcement_dbs, [])
- self._setup_sample_trigger('st2.test.trigger4')
+ self._setup_sample_trigger("st2.test.trigger4")
rule_4_db = self._setup_sample_rule(RULE_4)
rules = [rule_4_db]
trigger_instance = container_utils.create_trigger_instance(
- 'dummy_pack_1.st2.test.trigger4',
- {'k1': 't2_p_v', 'k2': 'v2'},
- date_utils.get_datetime_utc_now()
+ "dummy_pack_1.st2.test.trigger4",
+ {"k1": "t2_p_v", "k2": "v2"},
+ date_utils.get_datetime_utc_now(),
)
trigger = get_trigger_db_by_ref(trigger_instance.trigger)
@@ -270,22 +251,31 @@ def test_rule_enforcement_is_created_on_exception_3(self):
rule_enforcement_dbs = list(RuleEnforcement.get_all())
self.assertEqual(len(rule_enforcement_dbs), 1)
- expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger '
- 'instance "%s": There might be a problem with the criteria in rule '
- 'yoyohoneysingh.st2.test.rule4: exception in equals' %
- (str(trigger_instance.id)))
+ expected_failure = (
+ 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger '
+ 'instance "%s": There might be a problem with the criteria in rule '
+ "yoyohoneysingh.st2.test.rule4: exception in equals"
+ % (str(trigger_instance.id))
+ )
self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure)
- self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id))
- self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id))
+ self.assertEqual(
+ rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)
+ )
+ self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id))
self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED)
def _setup_sample_trigger(self, name):
- trigtype = TriggerTypeDB(name=name, pack='dummy_pack_1', payload_schema={},
- parameters_schema={})
+ trigtype = TriggerTypeDB(
+ name=name, pack="dummy_pack_1", payload_schema={}, parameters_schema={}
+ )
TriggerType.add_or_update(trigtype)
- created = TriggerDB(name=name, pack='dummy_pack_1', type=trigtype.get_reference().ref,
- parameters={})
+ created = TriggerDB(
+ name=name,
+ pack="dummy_pack_1",
+ type=trigtype.get_reference().ref,
+ parameters={},
+ )
Trigger.add_or_update(created)
def _setup_sample_rule(self, rule):
@@ -295,14 +285,12 @@ def _setup_sample_rule(self, rule):
return rule_db
-PACK = 'backstop'
+PACK = "backstop"
FIXTURES_TRIGGERS = {
- 'triggertypes': ['triggertype1.yaml'],
- 'triggers': ['trigger1.yaml']
-}
-FIXTURES_RULES = {
- 'rules': ['backstop.yaml', 'success.yaml', 'fail.yaml']
+ "triggertypes": ["triggertype1.yaml"],
+ "triggers": ["trigger1.yaml"],
}
+FIXTURES_RULES = {"rules": ["backstop.yaml", "success.yaml", "fail.yaml"]}
class BackstopRuleMatcherTestCase(DbTestCase):
@@ -315,33 +303,41 @@ def setUpClass(cls):
# Create TriggerTypes before creation of Rule to avoid failure. Rule requires the
# Trigger and therefore TriggerType to be created prior to rule creation.
cls.models = fixturesloader.save_fixtures_to_db(
- fixtures_pack=PACK, fixtures_dict=FIXTURES_TRIGGERS)
- cls.models.update(fixturesloader.save_fixtures_to_db(
- fixtures_pack=PACK, fixtures_dict=FIXTURES_RULES))
+ fixtures_pack=PACK, fixtures_dict=FIXTURES_TRIGGERS
+ )
+ cls.models.update(
+ fixturesloader.save_fixtures_to_db(
+ fixtures_pack=PACK, fixtures_dict=FIXTURES_RULES
+ )
+ )
def test_backstop_ignore(self):
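+        # With a matching non-backstop rule present, the backstop rule should be
+        # skipped in favour of success.yaml.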
trigger_instance = container_utils.create_trigger_instance(
- self.models['triggers']['trigger1.yaml'].ref,
- {'k1': 'v1'},
- date_utils.get_datetime_utc_now()
+ self.models["triggers"]["trigger1.yaml"].ref,
+ {"k1": "v1"},
+ date_utils.get_datetime_utc_now(),
)
- trigger = self.models['triggers']['trigger1.yaml']
- rules = [rule for rule in six.itervalues(self.models['rules'])]
+ trigger = self.models["triggers"]["trigger1.yaml"]
+ rules = [rule for rule in six.itervalues(self.models["rules"])]
rules_matcher = RulesMatcher(trigger_instance, trigger, rules)
matching_rules = rules_matcher.get_matching_rules()
self.assertEqual(len(matching_rules), 1)
- self.assertEqual(matching_rules[0].id, self.models['rules']['success.yaml'].id)
+ self.assertEqual(matching_rules[0].id, self.models["rules"]["success.yaml"].id)
def test_backstop_apply(self):
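+        # With success.yaml excluded, the backstop rule becomes the only match.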
trigger_instance = container_utils.create_trigger_instance(
- self.models['triggers']['trigger1.yaml'].ref,
- {'k1': 'v1'},
- date_utils.get_datetime_utc_now()
+ self.models["triggers"]["trigger1.yaml"].ref,
+ {"k1": "v1"},
+ date_utils.get_datetime_utc_now(),
)
- trigger = self.models['triggers']['trigger1.yaml']
- success_rule = self.models['rules']['success.yaml']
- rules = [rule for rule in six.itervalues(self.models['rules']) if rule != success_rule]
+ trigger = self.models["triggers"]["trigger1.yaml"]
+ success_rule = self.models["rules"]["success.yaml"]
+ rules = [
+ rule
+ for rule in six.itervalues(self.models["rules"])
+ if rule != success_rule
+ ]
rules_matcher = RulesMatcher(trigger_instance, trigger, rules)
matching_rules = rules_matcher.get_matching_rules()
self.assertEqual(len(matching_rules), 1)
- self.assertEqual(matching_rules[0].id, self.models['rules']['backstop.yaml'].id)
+ self.assertEqual(matching_rules[0].id, self.models["rules"]["backstop.yaml"].id)
diff --git a/st2reactor/tests/unit/test_sensor_and_rule_registration.py b/st2reactor/tests/unit/test_sensor_and_rule_registration.py
index 3f54e97c73..50075690e9 100644
--- a/st2reactor/tests/unit/test_sensor_and_rule_registration.py
+++ b/st2reactor/tests/unit/test_sensor_and_rule_registration.py
@@ -27,22 +27,20 @@
from st2common.bootstrap.sensorsregistrar import SensorsRegistrar
from st2common.bootstrap.rulesregistrar import RulesRegistrar
-__all__ = [
- 'SensorRegistrationTestCase',
- 'RuleRegistrationTestCase'
-]
+__all__ = ["SensorRegistrationTestCase", "RuleRegistrationTestCase"]
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
-PACKS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../fixtures/packs'))
+PACKS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../fixtures/packs"))
# NOTE: We need to perform this patching because test fixtures are located outside of the packs
# base paths directory. This will never happen outside the context of test fixtures.
-@mock.patch('st2common.content.utils.get_pack_base_path',
- mock.Mock(return_value=os.path.join(PACKS_DIR, 'pack_with_sensor')))
+@mock.patch(
+ "st2common.content.utils.get_pack_base_path",
+ mock.Mock(return_value=os.path.join(PACKS_DIR, "pack_with_sensor")),
+)
class SensorRegistrationTestCase(DbTestCase):
-
- @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+ @mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
def test_register_sensors(self):
# Verify DB is empty at the beginning
self.assertEqual(len(SensorType.get_all()), 0)
@@ -61,29 +59,33 @@ def test_register_sensors(self):
self.assertEqual(len(trigger_type_dbs), 2)
self.assertEqual(len(trigger_dbs), 2)
- self.assertEqual(sensor_dbs[0].name, 'TestSensor')
+ self.assertEqual(sensor_dbs[0].name, "TestSensor")
self.assertEqual(sensor_dbs[0].poll_interval, 10)
self.assertTrue(sensor_dbs[0].enabled)
- self.assertEqual(sensor_dbs[0].metadata_file, 'sensors/test_sensor_1.yaml')
+ self.assertEqual(sensor_dbs[0].metadata_file, "sensors/test_sensor_1.yaml")
- self.assertEqual(sensor_dbs[1].name, 'TestSensorDisabled')
+ self.assertEqual(sensor_dbs[1].name, "TestSensorDisabled")
self.assertEqual(sensor_dbs[1].poll_interval, 10)
self.assertFalse(sensor_dbs[1].enabled)
- self.assertEqual(sensor_dbs[1].metadata_file, 'sensors/test_sensor_2.yaml')
+ self.assertEqual(sensor_dbs[1].metadata_file, "sensors/test_sensor_2.yaml")
- self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1')
- self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor')
+ self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1")
+ self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor")
self.assertEqual(len(trigger_type_dbs[0].tags), 0)
- self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2')
- self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor')
+ self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2")
+ self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor")
self.assertEqual(len(trigger_type_dbs[1].tags), 2)
- self.assertEqual(trigger_type_dbs[1].tags[0].name, 'tag1name')
- self.assertEqual(trigger_type_dbs[1].tags[0].value, 'tag1 value')
+ self.assertEqual(trigger_type_dbs[1].tags[0].name, "tag1name")
+ self.assertEqual(trigger_type_dbs[1].tags[0].value, "tag1 value")
# Triggers which are registered via sensors have metadata_file pointing to the sensor
# definition file
- self.assertEqual(trigger_type_dbs[0].metadata_file, 'sensors/test_sensor_1.yaml')
- self.assertEqual(trigger_type_dbs[1].metadata_file, 'sensors/test_sensor_1.yaml')
+ self.assertEqual(
+ trigger_type_dbs[0].metadata_file, "sensors/test_sensor_1.yaml"
+ )
+ self.assertEqual(
+ trigger_type_dbs[1].metadata_file, "sensors/test_sensor_1.yaml"
+ )
# Verify second call to registration doesn't create a duplicate objects
registrar.register_from_packs(base_dirs=[PACKS_DIR])
@@ -96,13 +98,13 @@ def test_register_sensors(self):
self.assertEqual(len(trigger_type_dbs), 2)
self.assertEqual(len(trigger_dbs), 2)
- self.assertEqual(sensor_dbs[0].name, 'TestSensor')
+ self.assertEqual(sensor_dbs[0].name, "TestSensor")
self.assertEqual(sensor_dbs[0].poll_interval, 10)
- self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1')
- self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor')
- self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2')
- self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor')
+ self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1")
+ self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor")
+ self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2")
+ self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor")
# Verify sensor and trigger data is updated on registration
original_load = registrar._meta_loader.load
@@ -110,9 +112,10 @@ def test_register_sensors(self):
def mock_load(*args, **kwargs):
# Update poll_interval and trigger_type_2 description
data = original_load(*args, **kwargs)
- data['poll_interval'] = 50
- data['trigger_types'][1]['description'] = 'test 2'
+ data["poll_interval"] = 50
+ data["trigger_types"][1]["description"] = "test 2"
return data
+
registrar._meta_loader.load = mock_load
registrar.register_from_packs(base_dirs=[PACKS_DIR])
@@ -125,20 +128,22 @@ def mock_load(*args, **kwargs):
self.assertEqual(len(trigger_type_dbs), 2)
self.assertEqual(len(trigger_dbs), 2)
- self.assertEqual(sensor_dbs[0].name, 'TestSensor')
+ self.assertEqual(sensor_dbs[0].name, "TestSensor")
self.assertEqual(sensor_dbs[0].poll_interval, 50)
- self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1')
- self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor')
- self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2')
- self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor')
- self.assertEqual(trigger_type_dbs[1].description, 'test 2')
+ self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1")
+ self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor")
+ self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2")
+ self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor")
+ self.assertEqual(trigger_type_dbs[1].description, "test 2")
# NOTE: We need to perform this patching because test fixtures are located outside of the packs
# base paths directory. This will never happen outside the context of test fixtures.
-@mock.patch('st2common.content.utils.get_pack_base_path',
- mock.Mock(return_value=os.path.join(PACKS_DIR, 'pack_with_rules')))
+@mock.patch(
+ "st2common.content.utils.get_pack_base_path",
+ mock.Mock(return_value=os.path.join(PACKS_DIR, "pack_with_rules")),
+)
class RuleRegistrationTestCase(DbTestCase):
def test_register_rules(self):
# Verify DB is empty at the beginning
@@ -154,8 +159,8 @@ def test_register_rules(self):
self.assertEqual(len(rule_dbs), 2)
self.assertEqual(len(trigger_dbs), 1)
- self.assertEqual(rule_dbs[0].name, 'sample.with_the_same_timer')
- self.assertEqual(rule_dbs[1].name, 'sample.with_timer')
+ self.assertEqual(rule_dbs[0].name, "sample.with_the_same_timer")
+ self.assertEqual(rule_dbs[1].name, "sample.with_timer")
self.assertIsNotNone(trigger_dbs[0].name)
# Verify second register call updates existing models
diff --git a/st2reactor/tests/unit/test_sensor_service.py b/st2reactor/tests/unit/test_sensor_service.py
index 2064c25ee3..9d1e245e10 100644
--- a/st2reactor/tests/unit/test_sensor_service.py
+++ b/st2reactor/tests/unit/test_sensor_service.py
@@ -23,22 +23,20 @@
from st2common.constants.keyvalue import SYSTEM_SCOPE
from st2common.constants.keyvalue import USER_SCOPE
-__all__ = [
- 'SensorServiceTestCase'
-]
+__all__ = ["SensorServiceTestCase"]
# This trigger has schema that uses all property types
TEST_SCHEMA = {
- 'type': 'object',
- 'additionalProperties': False,
- 'properties': {
- 'age': {'type': 'integer'},
- 'name': {'type': 'string', 'required': True},
- 'address': {'type': 'string', 'default': '-'},
- 'career': {'type': 'array'},
- 'married': {'type': 'boolean'},
- 'awards': {'type': 'object'},
- 'income': {'anyOf': [{'type': 'integer'}, {'type': 'string'}]},
+ "type": "object",
+ "additionalProperties": False,
+ "properties": {
+ "age": {"type": "integer"},
+ "name": {"type": "string", "required": True},
+ "address": {"type": "string", "default": "-"},
+ "career": {"type": "array"},
+ "married": {"type": "boolean"},
+ "awards": {"type": "object"},
+ "income": {"anyOf": [{"type": "integer"}, {"type": "string"}]},
},
}
@@ -60,8 +58,9 @@ def side_effect(trigger, payload, trace_context):
self.sensor_service = SensorService(mock.MagicMock())
self.sensor_service._trigger_dispatcher_service._dispatcher = mock.Mock()
- self.sensor_service._trigger_dispatcher_service._dispatcher.dispatch = \
+ self.sensor_service._trigger_dispatcher_service._dispatcher.dispatch = (
mock.MagicMock(side_effect=side_effect)
+ )
self._dispatched_count = 0
# Previously, cfg.CONF.system.validate_trigger_payload was set to False explicitly
@@ -73,55 +72,65 @@ def tearDown(self):
# Replace original configured value for payload validation
cfg.CONF.system.validate_trigger_payload = self.validate_trigger_payload
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)),
+ )
def test_dispatch_success_valid_payload_validation_enabled(self):
cfg.CONF.system.validate_trigger_payload = True
# define a valid payload
payload = {
- 'name': 'John Doe',
- 'age': 25,
- 'career': ['foo, Inc.', 'bar, Inc.'],
- 'married': True,
- 'awards': {'2016': ['hoge prize', 'fuga prize']},
- 'income': 50000
+ "name": "John Doe",
+ "age": 25,
+ "career": ["foo, Inc.", "bar, Inc."],
+ "married": True,
+ "awards": {"2016": ["hoge prize", "fuga prize"]},
+ "income": 50000,
}
# dispatching a trigger
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
        # This assumes that the target trigger was dispatched
self.assertEqual(self._dispatched_count, 1)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)))
- @mock.patch('st2common.services.triggers.get_trigger_db_by_ref',
- mock.MagicMock(return_value=TriggerDBMock(type='trigger-type-ref')))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)),
+ )
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_db_by_ref",
+ mock.MagicMock(return_value=TriggerDBMock(type="trigger-type-ref")),
+ )
def test_dispatch_success_with_validation_enabled_trigger_reference(self):
# Test a scenario where a Trigger ref and not TriggerType ref is provided
cfg.CONF.system.validate_trigger_payload = True
# define a valid payload
payload = {
- 'name': 'John Doe',
- 'age': 25,
- 'career': ['foo, Inc.', 'bar, Inc.'],
- 'married': True,
- 'awards': {'2016': ['hoge prize', 'fuga prize']},
- 'income': 50000
+ "name": "John Doe",
+ "age": 25,
+ "career": ["foo, Inc.", "bar, Inc."],
+ "married": True,
+ "awards": {"2016": ["hoge prize", "fuga prize"]},
+ "income": 50000,
}
self.assertEqual(self._dispatched_count, 0)
# dispatching a trigger
- self.sensor_service.dispatch('pack.86582f21-1fbc-44ea-88cb-0cd2b610e93b', payload)
+ self.sensor_service.dispatch(
+ "pack.86582f21-1fbc-44ea-88cb-0cd2b610e93b", payload
+ )
        # This assumes that the target trigger was dispatched
self.assertEqual(self._dispatched_count, 1)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)),
+ )
def test_dispatch_success_with_validation_disabled_and_invalid_payload(self):
"""
Tests that an invalid payload still results in dispatch success with default config
@@ -143,29 +152,31 @@ def test_dispatch_success_with_validation_disabled_and_invalid_payload(self):
        # define an invalid payload (the type of 'age' is incorrect)
payload = {
- 'name': 'John Doe',
- 'age': '25',
+ "name": "John Doe",
+ "age": "25",
}
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
# The default config is to disable validation. So, we want to make sure
# the dispatch actually went through.
self.assertEqual(self._dispatched_count, 1)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)),
+ )
def test_dispatch_failure_caused_by_incorrect_type(self):
        # define an invalid payload (the type of 'age' is incorrect)
payload = {
- 'name': 'John Doe',
- 'age': '25',
+ "name": "John Doe",
+ "age": "25",
}
        # enable validation so dispatching stops when the payload doesn't comply with the trigger_type
cfg.CONF.system.validate_trigger_payload = True
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
        # This assumes that the target trigger isn't dispatched
self.assertEqual(self._dispatched_count, 0)
@@ -173,120 +184,130 @@ def test_dispatch_failure_caused_by_incorrect_type(self):
        # reset config to permit dispatching without validation
cfg.CONF.system.validate_trigger_payload = False
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
self.assertEqual(self._dispatched_count, 1)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)),
+ )
def test_dispatch_failure_caused_by_lack_of_required_parameter(self):
        # define an invalid payload (missing a required property)
payload = {
- 'age': 25,
+ "age": 25,
}
cfg.CONF.system.validate_trigger_payload = True
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
self.assertEqual(self._dispatched_count, 0)
        # reset config to permit dispatching without validation
cfg.CONF.system.validate_trigger_payload = False
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
self.assertEqual(self._dispatched_count, 1)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)),
+ )
def test_dispatch_failure_caused_by_extra_parameter(self):
        # define an invalid payload ('hobby' is an extra property)
payload = {
- 'name': 'John Doe',
- 'hobby': 'programming',
+ "name": "John Doe",
+ "hobby": "programming",
}
cfg.CONF.system.validate_trigger_payload = True
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
self.assertEqual(self._dispatched_count, 0)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)),
+ )
def test_dispatch_success_with_multiple_type_value(self):
payload = {
- 'name': 'John Doe',
- 'income': 1234,
+ "name": "John Doe",
+ "income": 1234,
}
cfg.CONF.system.validate_trigger_payload = True
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
        # switch the payload value to the other allowed type
- payload['income'] = 'secret'
+ payload["income"] = "secret"
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
self.assertEqual(self._dispatched_count, 2)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)),
+ )
def test_dispatch_success_with_null(self):
payload = {
- 'name': 'John Doe',
- 'age': None,
+ "name": "John Doe",
+ "age": None,
}
cfg.CONF.system.validate_trigger_payload = True
- self.sensor_service.dispatch('trigger-name', payload)
+ self.sensor_service.dispatch("trigger-name", payload)
self.assertEqual(self._dispatched_count, 1)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=TriggerTypeDBMock()))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=TriggerTypeDBMock()),
+ )
def test_dispatch_success_without_payload_schema(self):
        # the case where the trigger type has no payload schema
- self.sensor_service.dispatch('trigger-name', {})
+ self.sensor_service.dispatch("trigger-name", {})
self.assertEqual(self._dispatched_count, 1)
- @mock.patch('st2common.services.triggers.get_trigger_type_db',
- mock.MagicMock(return_value=None))
+ @mock.patch(
+ "st2common.services.triggers.get_trigger_type_db",
+ mock.MagicMock(return_value=None),
+ )
def test_dispatch_trigger_type_not_in_db_should_not_dispatch(self):
cfg.CONF.system.validate_trigger_payload = True
- self.sensor_service.dispatch('not-in-database-ref', {})
+ self.sensor_service.dispatch("not-in-database-ref", {})
self.assertEqual(self._dispatched_count, 0)
def test_datastore_methods(self):
self.sensor_service._datastore_service = mock.Mock()
# Verify methods take encrypt, decrypt and scope arguments
- self.sensor_service.get_value(name='foo1', scope=SYSTEM_SCOPE, decrypt=True)
+ self.sensor_service.get_value(name="foo1", scope=SYSTEM_SCOPE, decrypt=True)
call_kwargs = self.sensor_service.datastore_service.get_value.call_args[1]
expected_kwargs = {
- 'name': 'foo1',
- 'local': True,
- 'scope': SYSTEM_SCOPE,
- 'decrypt': True
+ "name": "foo1",
+ "local": True,
+ "scope": SYSTEM_SCOPE,
+ "decrypt": True,
}
self.assertEqual(call_kwargs, expected_kwargs)
- self.sensor_service.set_value(name='foo2', value='bar', scope=USER_SCOPE, encrypt=True)
+ self.sensor_service.set_value(
+ name="foo2", value="bar", scope=USER_SCOPE, encrypt=True
+ )
call_kwargs = self.sensor_service.datastore_service.set_value.call_args[1]
expected_kwargs = {
- 'name': 'foo2',
- 'value': 'bar',
- 'ttl': None,
- 'local': True,
- 'scope': USER_SCOPE,
- 'encrypt': True
+ "name": "foo2",
+ "value": "bar",
+ "ttl": None,
+ "local": True,
+ "scope": USER_SCOPE,
+ "encrypt": True,
}
self.assertEqual(call_kwargs, expected_kwargs)
- self.sensor_service.delete_value(name='foo3', scope=USER_SCOPE)
+ self.sensor_service.delete_value(name="foo3", scope=USER_SCOPE)
call_kwargs = self.sensor_service.datastore_service.delete_value.call_args[1]
- expected_kwargs = {
- 'name': 'foo3',
- 'local': True,
- 'scope': USER_SCOPE
- }
+ expected_kwargs = {"name": "foo3", "local": True, "scope": USER_SCOPE}
self.assertEqual(call_kwargs, expected_kwargs)
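
These tests toggle cfg.CONF.system.validate_trigger_payload, which gates whether a dispatched payload is validated against the trigger type's schema. A minimal sketch of that check, assuming the jsonschema library is available and draft-3 semantics (TEST_SCHEMA puts "required": True on the property itself, which is the draft-3 form):

    import jsonschema  # assumption: st2common wraps this library for validation

    schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "age": {"type": "integer"},
            "name": {"type": "string", "required": True},
        },
    }

    validator = jsonschema.Draft3Validator(schema)
    validator.validate({"name": "John Doe", "age": 25})  # valid payload passes
    try:
        validator.validate({"age": "25"})  # wrong type and missing "name"
    except jsonschema.ValidationError as exc:
        print("would not be dispatched:", exc.message)
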
diff --git a/st2reactor/tests/unit/test_sensor_wrapper.py b/st2reactor/tests/unit/test_sensor_wrapper.py
index 735e0e545b..b2d637812d 100644
--- a/st2reactor/tests/unit/test_sensor_wrapper.py
+++ b/st2reactor/tests/unit/test_sensor_wrapper.py
@@ -16,6 +16,7 @@
from __future__ import absolute_import
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -33,11 +34,9 @@
from st2reactor.sensor.base import Sensor, PollingSensor
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
-RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources'))
+RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources"))
-__all__ = [
- 'SensorWrapperTestCase'
-]
+__all__ = ["SensorWrapperTestCase"]
class SensorWrapperTestCase(unittest2.TestCase):
@@ -47,27 +46,33 @@ def setUpClass(cls):
tests_config.parse_args()
def test_sensor_instance_has_sensor_service(self):
- file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py')
- trigger_types = ['trigger1', 'trigger2']
- parent_args = ['--config-file', TESTS_CONFIG_PATH]
-
- wrapper = SensorWrapper(pack='core', file_path=file_path,
- class_name='TestSensor',
- trigger_types=trigger_types,
- parent_args=parent_args)
- self.assertIsNotNone(getattr(wrapper._sensor_instance, 'sensor_service', None))
- self.assertIsNotNone(getattr(wrapper._sensor_instance, 'config', None))
+ file_path = os.path.join(RESOURCES_DIR, "test_sensor.py")
+ trigger_types = ["trigger1", "trigger2"]
+ parent_args = ["--config-file", TESTS_CONFIG_PATH]
+
+ wrapper = SensorWrapper(
+ pack="core",
+ file_path=file_path,
+ class_name="TestSensor",
+ trigger_types=trigger_types,
+ parent_args=parent_args,
+ )
+ self.assertIsNotNone(getattr(wrapper._sensor_instance, "sensor_service", None))
+ self.assertIsNotNone(getattr(wrapper._sensor_instance, "config", None))
def test_trigger_cud_event_handlers(self):
- trigger_id = '57861fcb0640fd1524e577c0'
- file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py')
- trigger_types = ['trigger1', 'trigger2']
- parent_args = ['--config-file', TESTS_CONFIG_PATH]
-
- wrapper = SensorWrapper(pack='core', file_path=file_path,
- class_name='TestSensor',
- trigger_types=trigger_types,
- parent_args=parent_args)
+ trigger_id = "57861fcb0640fd1524e577c0"
+ file_path = os.path.join(RESOURCES_DIR, "test_sensor.py")
+ trigger_types = ["trigger1", "trigger2"]
+ parent_args = ["--config-file", TESTS_CONFIG_PATH]
+
+ wrapper = SensorWrapper(
+ pack="core",
+ file_path=file_path,
+ class_name="TestSensor",
+ trigger_types=trigger_types,
+ parent_args=parent_args,
+ )
self.assertEqual(wrapper._trigger_names, {})
@@ -78,7 +83,9 @@ def test_trigger_cud_event_handlers(self):
# Call create handler with a trigger which refers to this sensor
self.assertEqual(wrapper._sensor_instance.add_trigger.call_count, 0)
- trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0])
+ trigger = TriggerDB(
+ id=trigger_id, name="test", pack="dummy", type=trigger_types[0]
+ )
wrapper._handle_create_trigger(trigger=trigger)
self.assertEqual(wrapper._trigger_names, {trigger_id: trigger})
self.assertEqual(wrapper._sensor_instance.add_trigger.call_count, 1)
@@ -86,7 +93,9 @@ def test_trigger_cud_event_handlers(self):
# Validate that update handler updates the trigger_names
self.assertEqual(wrapper._sensor_instance.update_trigger.call_count, 0)
- trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0])
+ trigger = TriggerDB(
+ id=trigger_id, name="test", pack="dummy", type=trigger_types[0]
+ )
wrapper._handle_update_trigger(trigger=trigger)
self.assertEqual(wrapper._trigger_names, {trigger_id: trigger})
self.assertEqual(wrapper._sensor_instance.update_trigger.call_count, 1)
@@ -94,70 +103,97 @@ def test_trigger_cud_event_handlers(self):
# Validate that delete handler deletes the trigger from trigger_names
self.assertEqual(wrapper._sensor_instance.remove_trigger.call_count, 0)
- trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0])
+ trigger = TriggerDB(
+ id=trigger_id, name="test", pack="dummy", type=trigger_types[0]
+ )
wrapper._handle_delete_trigger(trigger=trigger)
self.assertEqual(wrapper._trigger_names, {})
self.assertEqual(wrapper._sensor_instance.remove_trigger.call_count, 1)
def test_sensor_creation_passive(self):
- file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py')
- trigger_types = ['trigger1', 'trigger2']
- parent_args = ['--config-file', TESTS_CONFIG_PATH]
-
- wrapper = SensorWrapper(pack='core', file_path=file_path,
- class_name='TestSensor',
- trigger_types=trigger_types,
- parent_args=parent_args)
+ file_path = os.path.join(RESOURCES_DIR, "test_sensor.py")
+ trigger_types = ["trigger1", "trigger2"]
+ parent_args = ["--config-file", TESTS_CONFIG_PATH]
+
+ wrapper = SensorWrapper(
+ pack="core",
+ file_path=file_path,
+ class_name="TestSensor",
+ trigger_types=trigger_types,
+ parent_args=parent_args,
+ )
self.assertIsInstance(wrapper._sensor_instance, Sensor)
self.assertIsNotNone(wrapper._sensor_instance)
def test_sensor_creation_active(self):
- file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py')
- trigger_types = ['trigger1', 'trigger2']
- parent_args = ['--config-file', TESTS_CONFIG_PATH]
+ file_path = os.path.join(RESOURCES_DIR, "test_sensor.py")
+ trigger_types = ["trigger1", "trigger2"]
+ parent_args = ["--config-file", TESTS_CONFIG_PATH]
poll_interval = 10
- wrapper = SensorWrapper(pack='core', file_path=file_path,
- class_name='TestPollingSensor',
- trigger_types=trigger_types,
- parent_args=parent_args,
- poll_interval=poll_interval)
+ wrapper = SensorWrapper(
+ pack="core",
+ file_path=file_path,
+ class_name="TestPollingSensor",
+ trigger_types=trigger_types,
+ parent_args=parent_args,
+ poll_interval=poll_interval,
+ )
self.assertIsNotNone(wrapper._sensor_instance)
self.assertIsInstance(wrapper._sensor_instance, PollingSensor)
self.assertEqual(wrapper._sensor_instance._poll_interval, poll_interval)
def test_sensor_init_fails_file_doesnt_exist(self):
- file_path = os.path.join(RESOURCES_DIR, 'test_sensor_doesnt_exist.py')
- trigger_types = ['trigger1', 'trigger2']
- parent_args = ['--config-file', TESTS_CONFIG_PATH]
-
- expected_msg = 'Failed to load sensor class from file.*? No such file or directory'
- self.assertRaisesRegexp(IOError, expected_msg, SensorWrapper,
- pack='core', file_path=file_path,
- class_name='TestSensor',
- trigger_types=trigger_types,
- parent_args=parent_args)
+ file_path = os.path.join(RESOURCES_DIR, "test_sensor_doesnt_exist.py")
+ trigger_types = ["trigger1", "trigger2"]
+ parent_args = ["--config-file", TESTS_CONFIG_PATH]
+
+ expected_msg = (
+ "Failed to load sensor class from file.*? No such file or directory"
+ )
+ self.assertRaisesRegexp(
+ IOError,
+ expected_msg,
+ SensorWrapper,
+ pack="core",
+ file_path=file_path,
+ class_name="TestSensor",
+ trigger_types=trigger_types,
+ parent_args=parent_args,
+ )
def test_sensor_init_fails_sensor_code_contains_typo(self):
- file_path = os.path.join(RESOURCES_DIR, 'test_sensor_with_typo.py')
- trigger_types = ['trigger1', 'trigger2']
- parent_args = ['--config-file', TESTS_CONFIG_PATH]
-
- expected_msg = 'Failed to load sensor class from file.*? \'typobar\' is not defined'
- self.assertRaisesRegexp(NameError, expected_msg, SensorWrapper,
- pack='core', file_path=file_path,
- class_name='TestSensor',
- trigger_types=trigger_types,
- parent_args=parent_args)
+ file_path = os.path.join(RESOURCES_DIR, "test_sensor_with_typo.py")
+ trigger_types = ["trigger1", "trigger2"]
+ parent_args = ["--config-file", TESTS_CONFIG_PATH]
+
+ expected_msg = (
+ "Failed to load sensor class from file.*? 'typobar' is not defined"
+ )
+ self.assertRaisesRegexp(
+ NameError,
+ expected_msg,
+ SensorWrapper,
+ pack="core",
+ file_path=file_path,
+ class_name="TestSensor",
+ trigger_types=trigger_types,
+ parent_args=parent_args,
+ )
# Verify error message also contains traceback
try:
- SensorWrapper(pack='core', file_path=file_path, class_name='TestSensor',
- trigger_types=trigger_types, parent_args=parent_args)
+ SensorWrapper(
+ pack="core",
+ file_path=file_path,
+ class_name="TestSensor",
+ trigger_types=trigger_types,
+ parent_args=parent_args,
+ )
except NameError as e:
- self.assertIn('Traceback (most recent call last)', six.text_type(e))
- self.assertIn('line 20, in ', six.text_type(e))
+ self.assertIn("Traceback (most recent call last)", six.text_type(e))
+ self.assertIn("line 20, in ", six.text_type(e))
else:
- self.fail('NameError not thrown')
+ self.fail("NameError not thrown")
def test_sensor_wrapper_poll_method_still_works(self):
        # Verify that the sensor wrapper correctly applied the select.poll() eventlet workaround so code
@@ -167,5 +203,5 @@ def test_sensor_wrapper_poll_method_still_works(self):
import select
self.assertTrue(eventlet.patcher.is_monkey_patched(select))
- self.assertTrue(select != eventlet.patcher.original('select'))
+ self.assertTrue(select != eventlet.patcher.original("select"))
self.assertTrue(select.poll())
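
The last test above protects the select.poll() eventlet workaround applied by the sensor wrapper. Its assertions boil down to this standalone check (assuming eventlet is installed):

    import eventlet

    eventlet.monkey_patch(select=True)

    import select  # now resolves to eventlet's green select module

    assert eventlet.patcher.is_monkey_patched(select)
    assert select is not eventlet.patcher.original("select")
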
diff --git a/st2reactor/tests/unit/test_tester.py b/st2reactor/tests/unit/test_tester.py
index f1f1b01886..60cd6919b8 100644
--- a/st2reactor/tests/unit/test_tester.py
+++ b/st2reactor/tests/unit/test_tester.py
@@ -25,65 +25,77 @@
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
-FIXTURES_PACK = 'generic'
+FIXTURES_PACK = "generic"
TEST_MODELS_TRIGGERS = {
- 'triggertypes': ['triggertype1.yaml', 'triggertype2.yaml'],
- 'triggers': ['trigger1.yaml', 'trigger2.yaml'],
- 'triggerinstances': ['trigger_instance_1.yaml', 'trigger_instance_2.yaml']
+ "triggertypes": ["triggertype1.yaml", "triggertype2.yaml"],
+ "triggers": ["trigger1.yaml", "trigger2.yaml"],
+ "triggerinstances": ["trigger_instance_1.yaml", "trigger_instance_2.yaml"],
}
-TEST_MODELS_RULES = {
- 'rules': ['rule1.yaml']
-}
+TEST_MODELS_RULES = {"rules": ["rule1.yaml"]}
-TEST_MODELS_ACTIONS = {
- 'actions': ['action1.yaml']
-}
+TEST_MODELS_ACTIONS = {"actions": ["action1.yaml"]}
-@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock())
+@mock.patch.object(PoolPublisher, "publish", mock.MagicMock())
class RuleTesterTestCase(CleanDbTestCase):
def test_matching_trigger_from_file(self):
- FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS_ACTIONS)
- rule_file_path = os.path.join(BASE_PATH, '../fixtures/rule.yaml')
- trigger_instance_file_path = os.path.join(BASE_PATH, '../fixtures/trigger_instance_1.yaml')
- tester = RuleTester(rule_file_path=rule_file_path,
- trigger_instance_file_path=trigger_instance_file_path)
+ FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_ACTIONS
+ )
+ rule_file_path = os.path.join(BASE_PATH, "../fixtures/rule.yaml")
+ trigger_instance_file_path = os.path.join(
+ BASE_PATH, "../fixtures/trigger_instance_1.yaml"
+ )
+ tester = RuleTester(
+ rule_file_path=rule_file_path,
+ trigger_instance_file_path=trigger_instance_file_path,
+ )
matching = tester.evaluate()
self.assertTrue(matching)
def test_non_matching_trigger_from_file(self):
- rule_file_path = os.path.join(BASE_PATH, '../fixtures/rule.yaml')
- trigger_instance_file_path = os.path.join(BASE_PATH, '../fixtures/trigger_instance_2.yaml')
- tester = RuleTester(rule_file_path=rule_file_path,
- trigger_instance_file_path=trigger_instance_file_path)
+ rule_file_path = os.path.join(BASE_PATH, "../fixtures/rule.yaml")
+ trigger_instance_file_path = os.path.join(
+ BASE_PATH, "../fixtures/trigger_instance_2.yaml"
+ )
+ tester = RuleTester(
+ rule_file_path=rule_file_path,
+ trigger_instance_file_path=trigger_instance_file_path,
+ )
matching = tester.evaluate()
self.assertFalse(matching)
def test_matching_trigger_from_db(self):
- FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS_ACTIONS)
- models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS_TRIGGERS)
- trigger_instance_db = models['triggerinstances']['trigger_instance_2.yaml']
- models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS_RULES)
- rule_db = models['rules']['rule1.yaml']
- tester = RuleTester(rule_ref=rule_db.ref,
- trigger_instance_id=str(trigger_instance_db.id))
+ FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_ACTIONS
+ )
+ models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_TRIGGERS
+ )
+ trigger_instance_db = models["triggerinstances"]["trigger_instance_2.yaml"]
+ models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_RULES
+ )
+ rule_db = models["rules"]["rule1.yaml"]
+ tester = RuleTester(
+ rule_ref=rule_db.ref, trigger_instance_id=str(trigger_instance_db.id)
+ )
matching = tester.evaluate()
self.assertTrue(matching)
def test_non_matching_trigger_from_db(self):
- models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS_TRIGGERS)
- trigger_instance_db = models['triggerinstances']['trigger_instance_1.yaml']
- models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK,
- fixtures_dict=TEST_MODELS_RULES)
- rule_db = models['rules']['rule1.yaml']
- tester = RuleTester(rule_ref=rule_db.ref,
- trigger_instance_id=str(trigger_instance_db.id))
+ models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_TRIGGERS
+ )
+ trigger_instance_db = models["triggerinstances"]["trigger_instance_1.yaml"]
+ models = FixturesLoader().save_fixtures_to_db(
+ fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_RULES
+ )
+ rule_db = models["rules"]["rule1.yaml"]
+ tester = RuleTester(
+ rule_ref=rule_db.ref, trigger_instance_id=str(trigger_instance_db.id)
+ )
matching = tester.evaluate()
self.assertFalse(matching)
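
For reference, the flow these tests exercise can be driven directly. A hedged sketch, assuming the RuleTester import path used by this test module and hypothetical fixture file names (config parsing and a database connection, as provided by CleanDbTestCase, are also assumed):

    from st2reactor.rules.tester import RuleTester  # import path is an assumption

    # Evaluate a rule definition against a recorded trigger instance.
    tester = RuleTester(
        rule_file_path="rule.yaml",
        trigger_instance_file_path="trigger_instance_1.yaml",
    )
    print(tester.evaluate())  # True when the rule criteria match the instance
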
diff --git a/st2reactor/tests/unit/test_timer.py b/st2reactor/tests/unit/test_timer.py
index 861d74349e..f4311d18d8 100644
--- a/st2reactor/tests/unit/test_timer.py
+++ b/st2reactor/tests/unit/test_timer.py
@@ -60,9 +60,14 @@ def test_existing_rules_are_loaded_on_start(self):
# Add a dummy timer Trigger object
type_ = list(TIMER_TRIGGER_TYPES.keys())[0]
- parameters = {'unit': 'seconds', 'delta': 1000}
- trigger_db = TriggerDB(id=bson.ObjectId(), name='test_trigger_1', pack='dummy',
- type=type_, parameters=parameters)
+ parameters = {"unit": "seconds", "delta": 1000}
+ trigger_db = TriggerDB(
+ id=bson.ObjectId(),
+ name="test_trigger_1",
+ pack="dummy",
+ type=type_,
+ parameters=parameters,
+ )
trigger_db = Trigger.add_or_update(trigger_db)
# Verify object has been added
@@ -74,7 +79,7 @@ def test_existing_rules_are_loaded_on_start(self):
# Verify handlers are called
timer._handle_create_trigger.assert_called_with(trigger_db)
- @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch')
+ @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch")
def test_timer_trace_tag_creation(self, dispatch_mock):
timer = St2Timer()
timer._scheduler = mock.Mock()
@@ -82,11 +87,14 @@ def test_timer_trace_tag_creation(self, dispatch_mock):
# Add a dummy timer Trigger object
type_ = list(TIMER_TRIGGER_TYPES.keys())[0]
- parameters = {'unit': 'seconds', 'delta': 1}
- trigger_db = TriggerDB(name='test_trigger_1', pack='dummy', type=type_,
- parameters=parameters)
+ parameters = {"unit": "seconds", "delta": 1}
+ trigger_db = TriggerDB(
+ name="test_trigger_1", pack="dummy", type=type_, parameters=parameters
+ )
timer.add_trigger(trigger_db)
timer._emit_trigger_instance(trigger=trigger_db.to_serializable_dict())
- self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag,
- '%s-%s' % (TIMER_TRIGGER_TYPES[type_]['name'], trigger_db.name))
+ self.assertEqual(
+ dispatch_mock.call_args[1]["trace_context"].trace_tag,
+ "%s-%s" % (TIMER_TRIGGER_TYPES[type_]["name"], trigger_db.name),
+ )
diff --git a/st2stream/dist_utils.py b/st2stream/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2stream/dist_utils.py
+++ b/st2stream/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
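
The _get_link() helper above pulls a requirement name out of VCS-style lines. A worked example of the two regexes (the URL is illustrative):

    import re

    line = "git+https://github.com/org/repo.git#egg=mypkg"
    # First pattern matches "#egg=name&..." or "#egg=name@..." forms; not this line.
    name = re.findall(".*#egg=(.+)([&|@]).*$", line)
    if not name:
        # Fallback: "#egg=" is the final fragment of the line.
        name = re.findall(".*#egg=(.+?)$", line)
    print(name[0])  # mypkg
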
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
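
get_version_string() above is a plain regex scan over the init file; a worked example:

    import re

    content = '__version__ = "3.5dev"\n'
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
    print(match.group(1))  # 3.5dev
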
diff --git a/st2stream/setup.py b/st2stream/setup.py
index af6b302f5d..f34692affc 100644
--- a/st2stream/setup.py
+++ b/st2stream/setup.py
@@ -22,9 +22,9 @@
from dist_utils import apply_vagrant_workaround
from st2stream import __version__
-ST2_COMPONENT = 'st2stream'
+ST2_COMPONENT = "st2stream"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
@@ -32,18 +32,18 @@
setup(
name=ST2_COMPONENT,
version=__version__,
- description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="{} StackStorm event-driven automation platform component".format(
+ ST2_COMPONENT
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests']),
- scripts=[
- 'bin/st2stream'
- ]
+ packages=find_packages(exclude=["setuptools", "tests"]),
+ scripts=["bin/st2stream"],
)
diff --git a/st2stream/st2stream/__init__.py b/st2stream/st2stream/__init__.py
index ae0bd695f5..b3e513c2bc 100644
--- a/st2stream/st2stream/__init__.py
+++ b/st2stream/st2stream/__init__.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/st2stream/st2stream/app.py b/st2stream/st2stream/app.py
index 73d32eb4cf..0932dfcc27 100644
--- a/st2stream/st2stream/app.py
+++ b/st2stream/st2stream/app.py
@@ -43,9 +43,9 @@
def setup_app(config={}):
- LOG.info('Creating st2stream: %s as OpenAPI app.', VERSION_STRING)
+ LOG.info("Creating st2stream: %s as OpenAPI app.", VERSION_STRING)
- is_gunicorn = config.get('is_gunicorn', False)
+ is_gunicorn = config.get("is_gunicorn", False)
if is_gunicorn:
# Note: We need to perform monkey patching in the worker. If we do it in
# the master process (gunicorn_config.py), it breaks tons of things
@@ -54,30 +54,33 @@ def setup_app(config={}):
st2stream_config.register_opts()
capabilities = {
- 'name': 'stream',
- 'listen_host': cfg.CONF.stream.host,
- 'listen_port': cfg.CONF.stream.port,
- 'type': 'active'
+ "name": "stream",
+ "listen_host": cfg.CONF.stream.host,
+ "listen_port": cfg.CONF.stream.port,
+ "type": "active",
}
    # This should be called in the gunicorn case because we only want
    # workers to connect to db, rabbitmq, etc. In the standalone HTTP
# server case, this setup would have already occurred.
- common_setup(service='stream', config=st2stream_config, setup_db=True,
- register_mq_exchanges=True,
- register_signal_handlers=True,
- register_internal_trigger_types=False,
- run_migrations=False,
- service_registry=True,
- capabilities=capabilities,
- config_args=config.get('config_args', None))
+ common_setup(
+ service="stream",
+ config=st2stream_config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ register_internal_trigger_types=False,
+ run_migrations=False,
+ service_registry=True,
+ capabilities=capabilities,
+ config_args=config.get("config_args", None),
+ )
- router = Router(debug=cfg.CONF.stream.debug, auth=cfg.CONF.auth.enable,
- is_gunicorn=is_gunicorn)
+ router = Router(
+ debug=cfg.CONF.stream.debug, auth=cfg.CONF.auth.enable, is_gunicorn=is_gunicorn
+ )
- spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2')
- transforms = {
- '^/stream/v1/': ['/', '/v1/']
- }
+ spec = spec_loader.load_spec("st2common", "openapi.yaml.j2")
+ transforms = {"^/stream/v1/": ["/", "/v1/"]}
router.add_spec(spec, transforms=transforms)
app = router.as_wsgi
@@ -87,8 +90,8 @@ def setup_app(config={}):
app = ErrorHandlingMiddleware(app)
app = CorsMiddleware(app)
app = LoggingMiddleware(app, router)
- app = ResponseInstrumentationMiddleware(app, router, service_name='stream')
+ app = ResponseInstrumentationMiddleware(app, router, service_name="stream")
app = RequestIDMiddleware(app)
- app = RequestInstrumentationMiddleware(app, router, service_name='stream')
+ app = RequestInstrumentationMiddleware(app, router, service_name="stream")
return app
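
The wrapping order in setup_app() matters: each middleware wraps the app built so far, so the one applied last (RequestInstrumentationMiddleware) is outermost and sees requests first, while ErrorHandlingMiddleware sits closest to the router. A generic sketch of the pattern, with illustrative names:

    def make_middleware(tag):
        def middleware(app):
            def wrapped(environ, start_response):
                print("enter", tag)  # outermost wrapper logs first
                return app(environ, start_response)
            return wrapped
        return middleware

    def router_app(environ, start_response):
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"ok"]

    app = router_app
    for wrap in (make_middleware("error"), make_middleware("request-id")):
        app = wrap(app)  # "request-id" is applied last, so it runs first
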
diff --git a/st2stream/st2stream/cmd/__init__.py b/st2stream/st2stream/cmd/__init__.py
index 4d6cd0332d..85b1f07d71 100644
--- a/st2stream/st2stream/cmd/__init__.py
+++ b/st2stream/st2stream/cmd/__init__.py
@@ -15,4 +15,4 @@
from st2stream.cmd import api
-__all__ = ['api']
+__all__ = ["api"]
diff --git a/st2stream/st2stream/cmd/api.py b/st2stream/st2stream/cmd/api.py
index cc1eec7d17..b4ce963ea5 100644
--- a/st2stream/st2stream/cmd/api.py
+++ b/st2stream/st2stream/cmd/api.py
@@ -14,6 +14,7 @@
# limitations under the License.
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
import os
@@ -30,20 +31,20 @@
from st2common.util.wsgi import shutdown_server_kill_pending_requests
from st2stream.signal_handlers import register_stream_signal_handlers
from st2stream import config
+
config.register_opts()
from st2stream import app
-__all__ = [
- 'main'
-]
+__all__ = ["main"]
eventlet.monkey_patch(
os=True,
select=True,
socket=True,
- thread=False if '--use-debugger' in sys.argv else True,
- time=True)
+ thread=False if "--use-debugger" in sys.argv else True,
+ time=True,
+)
LOG = logging.getLogger(__name__)
@@ -53,29 +54,43 @@
def _setup():
capabilities = {
- 'name': 'stream',
- 'listen_host': cfg.CONF.stream.host,
- 'listen_port': cfg.CONF.stream.port,
- 'type': 'active'
+ "name": "stream",
+ "listen_host": cfg.CONF.stream.host,
+ "listen_port": cfg.CONF.stream.port,
+ "type": "active",
}
- common_setup(service='stream', config=config, setup_db=True, register_mq_exchanges=True,
- register_signal_handlers=True, register_internal_trigger_types=False,
- run_migrations=False, service_registry=True, capabilities=capabilities)
+ common_setup(
+ service="stream",
+ config=config,
+ setup_db=True,
+ register_mq_exchanges=True,
+ register_signal_handlers=True,
+ register_internal_trigger_types=False,
+ run_migrations=False,
+ service_registry=True,
+ capabilities=capabilities,
+ )
def _run_server():
host = cfg.CONF.stream.host
port = cfg.CONF.stream.port
- LOG.info('(PID=%s) ST2 Stream API is serving on http://%s:%s.', os.getpid(), host, port)
+ LOG.info(
+ "(PID=%s) ST2 Stream API is serving on http://%s:%s.", os.getpid(), host, port
+ )
max_pool_size = eventlet.wsgi.DEFAULT_MAX_SIMULTANEOUS_REQUESTS
worker_pool = eventlet.GreenPool(max_pool_size)
sock = eventlet.listen((host, port))
def queue_shutdown(signal_number, stack_frame):
- eventlet.spawn_n(shutdown_server_kill_pending_requests, sock=sock,
- worker_pool=worker_pool, wait_time=WSGI_SERVER_REQUEST_SHUTDOWN_TIME)
+ eventlet.spawn_n(
+ shutdown_server_kill_pending_requests,
+ sock=sock,
+ worker_pool=worker_pool,
+ wait_time=WSGI_SERVER_REQUEST_SHUTDOWN_TIME,
+ )
# We register a custom SIGINT handler which allows us to kill long running active requests.
# Note: Eventually we will support draining (waiting for short-running requests), but we
@@ -97,12 +112,12 @@ def main():
except SystemExit as exit_code:
sys.exit(exit_code)
except KeyboardInterrupt:
- listener = get_listener_if_set(name='stream')
+ listener = get_listener_if_set(name="stream")
if listener:
listener.shutdown()
except Exception:
- LOG.exception('(PID=%s) ST2 Stream API quit due to exception.', os.getpid())
+ LOG.exception("(PID=%s) ST2 Stream API quit due to exception.", os.getpid())
return 1
finally:
_teardown()
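
queue_shutdown() above follows the usual eventlet rule: do no real work inside the signal handler, only spawn a green thread that performs the shutdown. A distilled sketch of the pattern:

    import signal

    import eventlet

    def _do_shutdown():
        # Real teardown (closing the socket, draining the worker pool) goes
        # here, safely outside the signal-handler context.
        print("shutting down")

    def queue_shutdown(signal_number, stack_frame):
        eventlet.spawn_n(_do_shutdown)  # schedule only; never block in a handler

    signal.signal(signal.SIGINT, queue_shutdown)
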
diff --git a/st2stream/st2stream/config.py b/st2stream/st2stream/config.py
index fe068dc0b2..bc117b556a 100644
--- a/st2stream/st2stream/config.py
+++ b/st2stream/st2stream/config.py
@@ -32,8 +32,11 @@
def parse_args(args=None):
- cfg.CONF(args=args, version=VERSION_STRING,
- default_config_files=[DEFAULT_CONFIG_FILE_PATH])
+ cfg.CONF(
+ args=args,
+ version=VERSION_STRING,
+ default_config_files=[DEFAULT_CONFIG_FILE_PATH],
+ )
def register_opts():
@@ -54,17 +57,15 @@ def _register_app_opts():
# config since they are also used outside st2stream
api_opts = [
cfg.StrOpt(
- 'host', default='127.0.0.1',
- help='StackStorm stream API server host'),
- cfg.IntOpt(
- 'port', default=9102,
- help='StackStorm API stream, server port'),
- cfg.BoolOpt(
- 'debug', default=False,
- help='Specify to enable debug mode.'),
+ "host", default="127.0.0.1", help="StackStorm stream API server host"
+ ),
+ cfg.IntOpt("port", default=9102, help="StackStorm API stream, server port"),
+ cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."),
cfg.StrOpt(
- 'logging', default='/etc/st2/logging.stream.conf',
- help='location of the logging.conf file')
+ "logging",
+ default="/etc/st2/logging.stream.conf",
+ help="location of the logging.conf file",
+ ),
]
- CONF.register_opts(api_opts, group='stream')
+ CONF.register_opts(api_opts, group="stream")
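
The options above follow the standard oslo.config pattern: declare typed Opt objects, register them under a group, and read them back as CONF.<group>.<name>. A minimal self-contained sketch (assuming oslo.config is installed):

    from oslo_config import cfg

    opts = [
        cfg.StrOpt("host", default="127.0.0.1", help="StackStorm stream API server host"),
        cfg.IntOpt("port", default=9102, help="StackStorm stream API server port"),
    ]

    conf = cfg.ConfigOpts()
    conf.register_opts(opts, group="stream")
    conf([])  # parse an empty argument list
    print(conf.stream.host, conf.stream.port)  # 127.0.0.1 9102
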
diff --git a/st2stream/st2stream/controllers/v1/executions.py b/st2stream/st2stream/controllers/v1/executions.py
index 379491e978..70023b8745 100644
--- a/st2stream/st2stream/controllers/v1/executions.py
+++ b/st2stream/st2stream/controllers/v1/executions.py
@@ -30,47 +30,46 @@
from st2common.rbac.types import PermissionType
from st2common.stream.listener import get_listener
-__all__ = [
- 'ActionExecutionOutputStreamController'
-]
+__all__ = ["ActionExecutionOutputStreamController"]
LOG = logging.getLogger(__name__)
# Event which is returned when no more data will be produced on this stream endpoint before closing
# the connection.
-NO_MORE_DATA_EVENT = 'event: EOF\ndata: \'\'\n\n'
+NO_MORE_DATA_EVENT = "event: EOF\ndata: ''\n\n"
class ActionExecutionOutputStreamController(ResourceController):
model = ActionExecutionAPI
access = ActionExecution
- supported_filters = {
- 'output_type': 'output_type'
- }
+ supported_filters = {"output_type": "output_type"}
CLOSE_STREAM_LIVEACTION_STATES = action_constants.LIVEACTION_COMPLETED_STATES + [
action_constants.LIVEACTION_STATUS_PAUSING,
- action_constants.LIVEACTION_STATUS_RESUMING
+ action_constants.LIVEACTION_STATUS_RESUMING,
]
- def get_one(self, id, output_type='all', requester_user=None):
+ def get_one(self, id, output_type="all", requester_user=None):
# Special case for id == "last"
- if id == 'last':
- execution_db = ActionExecution.query().order_by('-id').limit(1).first()
+ if id == "last":
+ execution_db = ActionExecution.query().order_by("-id").limit(1).first()
if not execution_db:
- raise ValueError('No executions found in the database')
+ raise ValueError("No executions found in the database")
id = str(execution_db.id)
- execution_db = self._get_one_by_id(id=id, requester_user=requester_user,
- permission_type=PermissionType.EXECUTION_VIEW)
+ execution_db = self._get_one_by_id(
+ id=id,
+ requester_user=requester_user,
+ permission_type=PermissionType.EXECUTION_VIEW,
+ )
execution_id = str(execution_db.id)
query_filters = {}
- if output_type and output_type != 'all':
- query_filters['output_type'] = output_type
+ if output_type and output_type != "all":
+ query_filters["output_type"] = output_type
def format_output_object(output_db_or_api):
if isinstance(output_db_or_api, ActionExecutionOutputDB):
@@ -78,25 +77,27 @@ def format_output_object(output_db_or_api):
elif isinstance(output_db_or_api, ActionExecutionOutputAPI):
data = output_db_or_api
else:
- raise ValueError('Unsupported format: %s' % (type(output_db_or_api)))
+ raise ValueError("Unsupported format: %s" % (type(output_db_or_api)))
- event = 'st2.execution.output__create'
- result = 'event: %s\ndata: %s\n\n' % (event, json_encode(data, indent=None))
+ event = "st2.execution.output__create"
+ result = "event: %s\ndata: %s\n\n" % (event, json_encode(data, indent=None))
return result
def existing_output_iter():
# Consume and return all of the existing lines
- output_dbs = ActionExecutionOutput.query(execution_id=execution_id, **query_filters)
+ output_dbs = ActionExecutionOutput.query(
+ execution_id=execution_id, **query_filters
+ )
            # Note: We return all at once instead of yielding line by line to avoid multiple socket
# writes and to achieve better performance
output = [format_output_object(output_db) for output_db in output_dbs]
- output = ''.join(output)
- yield six.binary_type(output.encode('utf-8'))
+ output = "".join(output)
+ yield six.binary_type(output.encode("utf-8"))
def new_output_iter():
def noop_gen():
- yield six.binary_type(NO_MORE_DATA_EVENT.encode('utf-8'))
+ yield six.binary_type(NO_MORE_DATA_EVENT.encode("utf-8"))
# Bail out if execution has already completed / been paused
if execution_db.status in self.CLOSE_STREAM_LIVEACTION_STATES:
@@ -104,7 +105,9 @@ def noop_gen():
# Wait for and return any new line which may come in
execution_ids = [execution_id]
- listener = get_listener(name='execution_output') # pylint: disable=no-member
+ listener = get_listener(
+ name="execution_output"
+ ) # pylint: disable=no-member
gen = listener.generator(execution_ids=execution_ids)
def format(gen):
@@ -117,28 +120,37 @@ def format(gen):
                # Note: gunicorn wsgi handler expects bytes, not unicode
# pylint: disable=no-member
if isinstance(model_api, ActionExecutionOutputAPI):
- if output_type and output_type != 'all' and \
- model_api.output_type != output_type:
+ if (
+ output_type
+ and output_type != "all"
+ and model_api.output_type != output_type
+ ):
continue
- output = format_output_object(model_api).encode('utf-8')
+ output = format_output_object(model_api).encode("utf-8")
yield six.binary_type(output)
elif isinstance(model_api, ActionExecutionAPI):
if model_api.status in self.CLOSE_STREAM_LIVEACTION_STATES:
- yield six.binary_type(NO_MORE_DATA_EVENT.encode('utf-8'))
+ yield six.binary_type(
+ NO_MORE_DATA_EVENT.encode("utf-8")
+ )
break
else:
- LOG.debug('Unrecognized message type: %s' % (model_api))
+ LOG.debug("Unrecognized message type: %s" % (model_api))
gen = format(gen)
return gen
def make_response():
app_iter = itertools.chain(existing_output_iter(), new_output_iter())
- res = Response(headerlist=[("X-Accel-Buffering", "no"),
- ('Cache-Control', 'no-cache'),
- ("Content-Type", "text/event-stream; charset=UTF-8")],
- app_iter=app_iter)
+ res = Response(
+ headerlist=[
+ ("X-Accel-Buffering", "no"),
+ ("Cache-Control", "no-cache"),
+ ("Content-Type", "text/event-stream; charset=UTF-8"),
+ ],
+ app_iter=app_iter,
+ )
return res
res = make_response()
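
Both stream controllers emit Server-Sent Events: each record is an "event:" line, a "data:" line carrying JSON, and a blank-line terminator, and the EOF sentinel tells clients no further output is coming. A self-contained sketch of the framing (json_encode is st2's own wrapper; the standard json module stands in for it here):

    import json

    def sse_event(event, body):
        # One SSE record: event name, JSON payload, blank-line terminator.
        return ("event: %s\ndata: %s\n\n" % (event, json.dumps(body))).encode("utf-8")

    print(sse_event("st2.execution.output__create", {"data": "line 1\n"}))
    print(b"event: EOF\ndata: ''\n\n")  # the NO_MORE_DATA_EVENT sentinel
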
diff --git a/st2stream/st2stream/controllers/v1/root.py b/st2stream/st2stream/controllers/v1/root.py
index c9873127a6..2b9178f785 100644
--- a/st2stream/st2stream/controllers/v1/root.py
+++ b/st2stream/st2stream/controllers/v1/root.py
@@ -15,9 +15,7 @@
from st2stream.controllers.v1.stream import StreamController
-__all__ = [
- 'RootController'
-]
+__all__ = ["RootController"]
class RootController(object):
diff --git a/st2stream/st2stream/controllers/v1/stream.py b/st2stream/st2stream/controllers/v1/stream.py
index 19c7d71b1d..f6995c3300 100644
--- a/st2stream/st2stream/controllers/v1/stream.py
+++ b/st2stream/st2stream/controllers/v1/stream.py
@@ -21,58 +21,70 @@
from st2common.util.jsonify import json_encode
from st2common.stream.listener import get_listener
-__all__ = [
- 'StreamController'
-]
+__all__ = ["StreamController"]
LOG = logging.getLogger(__name__)
DEFAULT_EVENTS_WHITELIST = [
- 'st2.announcement__*',
-
- 'st2.execution__create',
- 'st2.execution__update',
- 'st2.execution__delete',
-
- 'st2.liveaction__create',
- 'st2.liveaction__update',
- 'st2.liveaction__delete',
+ "st2.announcement__*",
+ "st2.execution__create",
+ "st2.execution__update",
+ "st2.execution__delete",
+ "st2.liveaction__create",
+ "st2.liveaction__update",
+ "st2.liveaction__delete",
]
def format(gen):
- message = '''event: %s\ndata: %s\n\n'''
+ message = """event: %s\ndata: %s\n\n"""
for pack in gen:
if not pack:
            # Note: gunicorn wsgi handler expects bytes, not unicode
- yield six.binary_type(b'\n')
+ yield six.binary_type(b"\n")
else:
(event, body) = pack
            # Note: gunicorn wsgi handler expects bytes, not unicode
- yield six.binary_type((message % (event, json_encode(body,
- indent=None))).encode('utf-8'))
+ yield six.binary_type(
+ (message % (event, json_encode(body, indent=None))).encode("utf-8")
+ )
class StreamController(object):
- def get_all(self, end_execution_id=None, end_event=None,
- events=None, action_refs=None, execution_ids=None, requester_user=None):
+ def get_all(
+ self,
+ end_execution_id=None,
+ end_event=None,
+ events=None,
+ action_refs=None,
+ execution_ids=None,
+ requester_user=None,
+ ):
events = events if events else DEFAULT_EVENTS_WHITELIST
action_refs = action_refs if action_refs else None
execution_ids = execution_ids if execution_ids else None
def make_response():
- listener = get_listener(name='stream')
- app_iter = format(listener.generator(events=events,
- action_refs=action_refs,
- end_event=end_event,
- end_statuses=action_constants.LIVEACTION_COMPLETED_STATES,
- end_execution_id=end_execution_id,
- execution_ids=execution_ids))
- res = Response(headerlist=[("X-Accel-Buffering", "no"),
- ('Cache-Control', 'no-cache'),
- ("Content-Type", "text/event-stream; charset=UTF-8")],
- app_iter=app_iter)
+ listener = get_listener(name="stream")
+ app_iter = format(
+ listener.generator(
+ events=events,
+ action_refs=action_refs,
+ end_event=end_event,
+ end_statuses=action_constants.LIVEACTION_COMPLETED_STATES,
+ end_execution_id=end_execution_id,
+ execution_ids=execution_ids,
+ )
+ )
+ res = Response(
+ headerlist=[
+ ("X-Accel-Buffering", "no"),
+ ("Cache-Control", "no-cache"),
+ ("Content-Type", "text/event-stream; charset=UTF-8"),
+ ],
+ app_iter=app_iter,
+ )
return res
stream = make_response()
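
make_response() returns a streaming response whose app_iter is the generator itself, so events are written as they are produced. "X-Accel-Buffering: no" asks nginx not to buffer the stream, and "Cache-Control: no-cache" keeps proxies from serving stale events. A reduced sketch, assuming the Response class here is webob's:

    from webob import Response  # assumption: matches the Response used above

    def make_streaming_response(app_iter):
        return Response(
            headerlist=[
                ("X-Accel-Buffering", "no"),  # disable proxy (nginx) buffering
                ("Cache-Control", "no-cache"),
                ("Content-Type", "text/event-stream; charset=UTF-8"),
            ],
            app_iter=app_iter,  # consumed lazily, one SSE record at a time
        )
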
diff --git a/st2stream/st2stream/signal_handlers.py b/st2stream/st2stream/signal_handlers.py
index 56bc06450a..b292d8b67b 100644
--- a/st2stream/st2stream/signal_handlers.py
+++ b/st2stream/st2stream/signal_handlers.py
@@ -15,9 +15,7 @@
import signal
-__all__ = [
- 'register_stream_signal_handlers'
-]
+__all__ = ["register_stream_signal_handlers"]
def register_stream_signal_handlers(handler_func):
diff --git a/st2stream/st2stream/wsgi.py b/st2stream/st2stream/wsgi.py
index c177572ba1..14d847e2a1 100644
--- a/st2stream/st2stream/wsgi.py
+++ b/st2stream/st2stream/wsgi.py
@@ -18,8 +18,11 @@
from st2stream import app
config = {
- 'is_gunicorn': True,
- 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')]
+ "is_gunicorn": True,
+ "config_args": [
+ "--config-file",
+ os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"),
+ ],
}
application = app.setup_app(config)
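
wsgi.py builds the WSGI application at import time, taking the config path from ST2_CONFIG_PATH. A hedged usage sketch (the serving command shown in the comment is a deployment assumption, not taken from this repository):

    import os

    # Point the app at a config file before the module is imported.
    os.environ.setdefault("ST2_CONFIG_PATH", "/etc/st2/st2.conf")

    # Typical serving command (assumption):
    #   gunicorn st2stream.wsgi:application -k eventlet -b 127.0.0.1:9102
    from st2stream import wsgi  # importing builds `application` as a side effect
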
diff --git a/st2stream/tests/unit/controllers/v1/base.py b/st2stream/tests/unit/controllers/v1/base.py
index 24a59a5cd0..4f6e2ca336 100644
--- a/st2stream/tests/unit/controllers/v1/base.py
+++ b/st2stream/tests/unit/controllers/v1/base.py
@@ -16,9 +16,7 @@
from st2stream import app
from st2tests.api import BaseFunctionalTest
-__all__ = [
- 'FunctionalTest'
-]
+__all__ = ["FunctionalTest"]
class FunctionalTest(BaseFunctionalTest):
diff --git a/st2stream/tests/unit/controllers/v1/test_stream.py b/st2stream/tests/unit/controllers/v1/test_stream.py
index 7ff7e62f3d..c67f3e2782 100644
--- a/st2stream/tests/unit/controllers/v1/test_stream.py
+++ b/st2stream/tests/unit/controllers/v1/test_stream.py
@@ -34,88 +34,72 @@
RUNNER_TYPE_1 = {
- 'description': '',
- 'enabled': True,
- 'name': 'local-shell-cmd',
- 'runner_module': 'local_runner',
- 'runner_parameters': {}
+ "description": "",
+ "enabled": True,
+ "name": "local-shell-cmd",
+ "runner_module": "local_runner",
+ "runner_parameters": {},
}
ACTION_1 = {
- 'name': 'st2.dummy.action1',
- 'description': 'test description',
- 'enabled': True,
- 'entry_point': '/tmp/test/action1.sh',
- 'pack': 'sixpack',
- 'runner_type': 'local-shell-cmd',
- 'parameters': {
- 'a': {
- 'type': 'string',
- 'default': 'abc'
- },
- 'b': {
- 'type': 'number',
- 'default': 123
- },
- 'c': {
- 'type': 'number',
- 'default': 123,
- 'immutable': True
- },
- 'd': {
- 'type': 'string',
- 'secret': True
- }
- }
+ "name": "st2.dummy.action1",
+ "description": "test description",
+ "enabled": True,
+ "entry_point": "/tmp/test/action1.sh",
+ "pack": "sixpack",
+ "runner_type": "local-shell-cmd",
+ "parameters": {
+ "a": {"type": "string", "default": "abc"},
+ "b": {"type": "number", "default": 123},
+ "c": {"type": "number", "default": 123, "immutable": True},
+ "d": {"type": "string", "secret": True},
+ },
}
LIVE_ACTION_1 = {
- 'action': 'sixpack.st2.dummy.action1',
- 'parameters': {
- 'hosts': 'localhost',
- 'cmd': 'uname -a',
- 'd': SUPER_SECRET_PARAMETER
- }
+ "action": "sixpack.st2.dummy.action1",
+ "parameters": {
+ "hosts": "localhost",
+ "cmd": "uname -a",
+ "d": SUPER_SECRET_PARAMETER,
+ },
}
EXECUTION_1 = {
- 'id': '598dbf0c0640fd54bffc688b',
- 'action': {
- 'ref': 'sixpack.st2.dummy.action1'
+ "id": "598dbf0c0640fd54bffc688b",
+ "action": {"ref": "sixpack.st2.dummy.action1"},
+ "parameters": {
+ "hosts": "localhost",
+ "cmd": "uname -a",
+ "d": SUPER_SECRET_PARAMETER,
},
- 'parameters': {
- 'hosts': 'localhost',
- 'cmd': 'uname -a',
- 'd': SUPER_SECRET_PARAMETER
- }
}
STDOUT_1 = {
- 'execution_id': '598dbf0c0640fd54bffc688b',
- 'action_ref': 'dummy.action1',
- 'output_type': 'stdout'
+ "execution_id": "598dbf0c0640fd54bffc688b",
+ "action_ref": "dummy.action1",
+ "output_type": "stdout",
}
STDERR_1 = {
- 'execution_id': '598dbf0c0640fd54bffc688b',
- 'action_ref': 'dummy.action1',
- 'output_type': 'stderr'
+ "execution_id": "598dbf0c0640fd54bffc688b",
+ "action_ref": "dummy.action1",
+ "output_type": "stderr",
}
class META(object):
delivery_info = {}
- def __init__(self, exchange='some', routing_key='thing'):
- self.delivery_info['exchange'] = exchange
- self.delivery_info['routing_key'] = routing_key
+ def __init__(self, exchange="some", routing_key="thing"):
+ self.delivery_info["exchange"] = exchange
+ self.delivery_info["routing_key"] = routing_key
def ack(self):
pass
class TestStreamController(FunctionalTest):
-
@classmethod
def setUpClass(cls):
super(TestStreamController, cls).setUpClass()
@@ -126,33 +110,35 @@ def setUpClass(cls):
instance = ActionAPI(**ACTION_1)
Action.add_or_update(ActionAPI.to_model(instance))
- @mock.patch.object(st2common.stream.listener, 'listen', mock.Mock())
- @mock.patch('st2stream.controllers.v1.stream.DEFAULT_EVENTS_WHITELIST', None)
+ @mock.patch.object(st2common.stream.listener, "listen", mock.Mock())
+ @mock.patch("st2stream.controllers.v1.stream.DEFAULT_EVENTS_WHITELIST", None)
def test_get_all(self):
resp = stream.StreamController().get_all()
- self.assertEqual(resp._status, '200 OK')
- self.assertIn(('Content-Type', 'text/event-stream; charset=UTF-8'), resp._headerlist)
+ self.assertEqual(resp._status, "200 OK")
+ self.assertIn(
+ ("Content-Type", "text/event-stream; charset=UTF-8"), resp._headerlist
+ )
- listener = st2common.stream.listener.get_listener(name='stream')
+ listener = st2common.stream.listener.get_listener(name="stream")
process = listener.processor(LiveActionAPI)
message = None
for message in resp._app_iter:
- message = message.decode('utf-8')
- if message != '\n':
+ message = message.decode("utf-8")
+ if message != "\n":
break
process(LiveActionDB(**LIVE_ACTION_1), META())
- self.assertIn('event: some__thing', message)
+ self.assertIn("event: some__thing", message)
self.assertIn('data: {"', message)
self.assertNotIn(SUPER_SECRET_PARAMETER, message)
- @mock.patch.object(st2common.stream.listener, 'listen', mock.Mock())
+ @mock.patch.object(st2common.stream.listener, "listen", mock.Mock())
def test_get_all_with_filters(self):
- cfg.CONF.set_override(name='heartbeat', group='stream', override=0.1)
+ cfg.CONF.set_override(name="heartbeat", group="stream", override=0.1)
- listener = st2common.stream.listener.get_listener(name='stream')
+ listener = st2common.stream.listener.get_listener(name="stream")
process_execution = listener.processor(ActionExecutionAPI)
process_liveaction = listener.processor(LiveActionAPI)
process_output = listener.processor(ActionExecutionOutputAPI)
@@ -164,50 +150,50 @@ def test_get_all_with_filters(self):
output_api_stderr = ActionExecutionOutputDB(**STDERR_1)
def dispatch_and_handle_mock_data(resp):
- received_messages_data = ''
+ received_messages_data = ""
for index, message in enumerate(resp._app_iter):
if message.strip():
- received_messages_data += message.decode('utf-8')
+ received_messages_data += message.decode("utf-8")
# Dispatch some mock events
if index == 0:
- meta = META('st2.execution', 'create')
+ meta = META("st2.execution", "create")
process_execution(execution_api, meta)
elif index == 1:
- meta = META('st2.execution', 'update')
+ meta = META("st2.execution", "update")
process_execution(execution_api, meta)
elif index == 2:
- meta = META('st2.execution', 'delete')
+ meta = META("st2.execution", "delete")
process_execution(execution_api, meta)
elif index == 3:
- meta = META('st2.liveaction', 'create')
+ meta = META("st2.liveaction", "create")
process_liveaction(liveaction_api, meta)
elif index == 4:
- meta = META('st2.liveaction', 'create')
+ meta = META("st2.liveaction", "create")
process_liveaction(liveaction_api, meta)
elif index == 5:
- meta = META('st2.liveaction', 'delete')
+ meta = META("st2.liveaction", "delete")
process_liveaction(liveaction_api, meta)
elif index == 6:
- meta = META('st2.liveaction', 'delete')
+ meta = META("st2.liveaction", "delete")
process_liveaction(liveaction_api, meta)
elif index == 7:
- meta = META('st2.announcement', 'chatops')
+ meta = META("st2.announcement", "chatops")
process_no_api_model({}, meta)
elif index == 8:
- meta = META('st2.execution.output', 'create')
+ meta = META("st2.execution.output", "create")
process_output(output_api_stdout, meta)
elif index == 9:
- meta = META('st2.execution.output', 'create')
+ meta = META("st2.execution.output", "create")
process_output(output_api_stderr, meta)
elif index == 10:
- meta = META('st2.announcement', 'errbot')
+ meta = META("st2.announcement", "errbot")
process_no_api_model({}, meta)
else:
break
- received_messages = received_messages_data.split('\n\n')
+ received_messages = received_messages_data.split("\n\n")
received_messages = [message for message in received_messages if message]
return received_messages
@@ -217,10 +203,10 @@ def dispatch_and_handle_mock_data(resp):
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 9)
- self.assertIn('st2.execution__create', received_messages[0])
- self.assertIn('st2.liveaction__delete', received_messages[5])
- self.assertIn('st2.announcement__chatops', received_messages[7])
- self.assertIn('st2.announcement__errbot', received_messages[8])
+ self.assertIn("st2.execution__create", received_messages[0])
+ self.assertIn("st2.liveaction__delete", received_messages[5])
+ self.assertIn("st2.announcement__chatops", received_messages[7])
+ self.assertIn("st2.announcement__errbot", received_messages[8])
# 1. ?events= filter
# No filter provided - all messages should be received
@@ -229,79 +215,79 @@ def dispatch_and_handle_mock_data(resp):
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 11)
- self.assertIn('st2.execution__create', received_messages[0])
- self.assertIn('st2.announcement__chatops', received_messages[7])
- self.assertIn('st2.execution.output__create', received_messages[8])
- self.assertIn('st2.execution.output__create', received_messages[9])
- self.assertIn('st2.announcement__errbot', received_messages[10])
+ self.assertIn("st2.execution__create", received_messages[0])
+ self.assertIn("st2.announcement__chatops", received_messages[7])
+ self.assertIn("st2.execution.output__create", received_messages[8])
+ self.assertIn("st2.execution.output__create", received_messages[9])
+ self.assertIn("st2.announcement__errbot", received_messages[10])
# Filter provided, only three messages should be received
- events = ['st2.execution__create', 'st2.liveaction__delete']
+ events = ["st2.execution__create", "st2.liveaction__delete"]
resp = stream.StreamController().get_all(events=events)
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 3)
- self.assertIn('st2.execution__create', received_messages[0])
- self.assertIn('st2.liveaction__delete', received_messages[1])
- self.assertIn('st2.liveaction__delete', received_messages[2])
+ self.assertIn("st2.execution__create", received_messages[0])
+ self.assertIn("st2.liveaction__delete", received_messages[1])
+ self.assertIn("st2.liveaction__delete", received_messages[2])
# Filter provided, only four messages should be received
- events = ['st2.liveaction__create', 'st2.liveaction__delete']
+ events = ["st2.liveaction__create", "st2.liveaction__delete"]
resp = stream.StreamController().get_all(events=events)
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 4)
- self.assertIn('st2.liveaction__create', received_messages[0])
- self.assertIn('st2.liveaction__create', received_messages[1])
- self.assertIn('st2.liveaction__delete', received_messages[2])
- self.assertIn('st2.liveaction__delete', received_messages[3])
+ self.assertIn("st2.liveaction__create", received_messages[0])
+ self.assertIn("st2.liveaction__create", received_messages[1])
+ self.assertIn("st2.liveaction__delete", received_messages[2])
+ self.assertIn("st2.liveaction__delete", received_messages[3])
# Glob filter
- events = ['st2.announcement__*']
+ events = ["st2.announcement__*"]
resp = stream.StreamController().get_all(events=events)
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 2)
- self.assertIn('st2.announcement__chatops', received_messages[0])
- self.assertIn('st2.announcement__errbot', received_messages[1])
+ self.assertIn("st2.announcement__chatops", received_messages[0])
+ self.assertIn("st2.announcement__errbot", received_messages[1])
# Filter provided, only the two output messages should be received
- events = ['st2.execution.output__create']
+ events = ["st2.execution.output__create"]
resp = stream.StreamController().get_all(events=events)
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 2)
- self.assertIn('st2.execution.output__create', received_messages[0])
- self.assertIn('st2.execution.output__create', received_messages[1])
+ self.assertIn("st2.execution.output__create", received_messages[0])
+ self.assertIn("st2.execution.output__create", received_messages[1])
# Filter provided but invalid, no messages should be received
- events = ['invalid1', 'invalid2']
+ events = ["invalid1", "invalid2"]
resp = stream.StreamController().get_all(events=events)
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 0)
# 2. ?action_refs= filter
- action_refs = ['invalid1', 'invalid2']
+ action_refs = ["invalid1", "invalid2"]
resp = stream.StreamController().get_all(action_refs=action_refs)
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 0)
- action_refs = ['dummy.action1']
+ action_refs = ["dummy.action1"]
resp = stream.StreamController().get_all(action_refs=action_refs)
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 2)
# 3. ?execution_ids= filter
- execution_ids = ['invalid1', 'invalid2']
+ execution_ids = ["invalid1", "invalid2"]
resp = stream.StreamController().get_all(execution_ids=execution_ids)
received_messages = dispatch_and_handle_mock_data(resp)
self.assertEqual(len(received_messages), 0)
- execution_ids = [EXECUTION_1['id']]
+ execution_ids = [EXECUTION_1["id"]]
resp = stream.StreamController().get_all(execution_ids=execution_ids)
received_messages = dispatch_and_handle_mock_data(resp)
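For reference, the ?events= glob filters exercised above ("st2.announcement__*" and friends) behave like shell-style pattern matching. A minimal, hypothetical sketch of that matching using fnmatch — not the actual st2stream implementation:

# Hypothetical sketch of glob-based event filtering, similar in spirit to the
# ?events= filter in the tests above. Not the actual st2stream implementation.
import fnmatch


def filter_events(event_names, patterns):
    # Keep an event if it matches any of the provided glob patterns.
    return [
        name
        for name in event_names
        if any(fnmatch.fnmatch(name, pattern) for pattern in patterns)
    ]


events = [
    "st2.execution__create",
    "st2.announcement__chatops",
    "st2.announcement__errbot",
]
assert filter_events(events, ["st2.announcement__*"]) == [
    "st2.announcement__chatops",
    "st2.announcement__errbot",
]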
diff --git a/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py b/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py
index deb76b4e97..d14dd029e8 100644
--- a/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py
+++ b/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py
@@ -30,50 +30,54 @@
from .base import FunctionalTest
-__all__ = [
- 'ActionExecutionOutputStreamControllerTestCase'
-]
+__all__ = ["ActionExecutionOutputStreamControllerTestCase"]
class ActionExecutionOutputStreamControllerTestCase(FunctionalTest):
def test_get_one_id_last_no_executions_in_the_database(self):
ActionExecution.query().delete()
- resp = self.app.get('/v1/executions/last/output', expect_errors=True)
+ resp = self.app.get("/v1/executions/last/output", expect_errors=True)
self.assertEqual(resp.status_int, http_client.BAD_REQUEST)
- self.assertEqual(resp.json['faultstring'], 'No executions found in the database')
+ self.assertEqual(
+ resp.json["faultstring"], "No executions found in the database"
+ )
def test_get_output_running_execution(self):
# Retrieve the listener instance to avoid a race with the listener connection not being
# established early enough for tests to pass.
# NOTE: This only affects tests where listeners are not pre-initialized.
- listener = get_listener(name='execution_output')
+ listener = get_listener(name="execution_output")
eventlet.sleep(1.0)
# Test the execution output API endpoint for execution which is running (blocking)
status = action_constants.LIVEACTION_STATUS_RUNNING
timestamp = date_utils.get_datetime_utc_now()
- action_execution_db = ActionExecutionDB(start_timestamp=timestamp,
- end_timestamp=timestamp,
- status=status,
- action={'ref': 'core.local'},
- runner={'name': 'local-shell-cmd'},
- liveaction={'ref': 'foo'})
+ action_execution_db = ActionExecutionDB(
+ start_timestamp=timestamp,
+ end_timestamp=timestamp,
+ status=status,
+ action={"ref": "core.local"},
+ runner={"name": "local-shell-cmd"},
+ liveaction={"ref": "foo"},
+ )
action_execution_db = ActionExecution.add_or_update(action_execution_db)
- output_params = dict(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stdout',
- data='stdout before start\n')
+ output_params = dict(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stdout",
+ data="stdout before start\n",
+ )
# Insert mock output object
output_db = ActionExecutionOutputDB(**output_params)
ActionExecutionOutput.add_or_update(output_db, publish=False)
def insert_mock_data():
- output_params['data'] = 'stdout mid 1\n'
+ output_params["data"] = "stdout mid 1\n"
output_db = ActionExecutionOutputDB(**output_params)
ActionExecutionOutput.add_or_update(output_db)
@@ -81,7 +85,7 @@ def insert_mock_data():
# spawn an eventlet which eventually finishes the action.
def publish_action_finished(action_execution_db):
# Insert mock output object
- output_params['data'] = 'stdout pre finish 1\n'
+ output_params["data"] = "stdout pre finish 1\n"
output_db = ActionExecutionOutputDB(**output_params)
ActionExecutionOutput.add_or_update(output_db)
@@ -96,28 +100,32 @@ def publish_action_finished(action_execution_db):
# Retrieve data while execution is running - the endpoint returns new data once it's
# available and blocks until the execution finishes
- resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)),
- expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions/%s/output" % (str(action_execution_db.id)),
+ expect_errors=False,
+ )
self.assertEqual(resp.status_int, 200)
events = self._parse_response(resp.text)
self.assertEqual(len(events), 4)
- self.assertEqual(events[0][1]['data'], 'stdout before start\n')
- self.assertEqual(events[1][1]['data'], 'stdout mid 1\n')
- self.assertEqual(events[2][1]['data'], 'stdout pre finish 1\n')
- self.assertEqual(events[3][0], 'EOF')
+ self.assertEqual(events[0][1]["data"], "stdout before start\n")
+ self.assertEqual(events[1][1]["data"], "stdout mid 1\n")
+ self.assertEqual(events[2][1]["data"], "stdout pre finish 1\n")
+ self.assertEqual(events[3][0], "EOF")
# Once the execution is in completed state, existing output should be returned immediately
- resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)),
- expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions/%s/output" % (str(action_execution_db.id)),
+ expect_errors=False,
+ )
self.assertEqual(resp.status_int, 200)
events = self._parse_response(resp.text)
self.assertEqual(len(events), 4)
- self.assertEqual(events[0][1]['data'], 'stdout before start\n')
- self.assertEqual(events[1][1]['data'], 'stdout mid 1\n')
- self.assertEqual(events[2][1]['data'], 'stdout pre finish 1\n')
- self.assertEqual(events[3][0], 'EOF')
+ self.assertEqual(events[0][1]["data"], "stdout before start\n")
+ self.assertEqual(events[1][1]["data"], "stdout mid 1\n")
+ self.assertEqual(events[2][1]["data"], "stdout pre finish 1\n")
+ self.assertEqual(events[3][0], "EOF")
listener.shutdown()
@@ -127,49 +135,57 @@ def test_get_output_finished_execution(self):
# Insert mock execution and output objects
status = action_constants.LIVEACTION_STATUS_SUCCEEDED
timestamp = date_utils.get_datetime_utc_now()
- action_execution_db = ActionExecutionDB(start_timestamp=timestamp,
- end_timestamp=timestamp,
- status=status,
- action={'ref': 'core.local'},
- runner={'name': 'local-shell-cmd'},
- liveaction={'ref': 'foo'})
+ action_execution_db = ActionExecutionDB(
+ start_timestamp=timestamp,
+ end_timestamp=timestamp,
+ status=status,
+ action={"ref": "core.local"},
+ runner={"name": "local-shell-cmd"},
+ liveaction={"ref": "foo"},
+ )
action_execution_db = ActionExecution.add_or_update(action_execution_db)
for i in range(1, 6):
- stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stdout',
- data='stdout %s\n' % (i))
+ stdout_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stdout",
+ data="stdout %s\n" % (i),
+ )
ActionExecutionOutput.add_or_update(stdout_db)
for i in range(10, 15):
- stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id),
- action_ref='core.local',
- runner_ref='dummy',
- timestamp=timestamp,
- output_type='stderr',
- data='stderr %s\n' % (i))
+ stderr_db = ActionExecutionOutputDB(
+ execution_id=str(action_execution_db.id),
+ action_ref="core.local",
+ runner_ref="dummy",
+ timestamp=timestamp,
+ output_type="stderr",
+ data="stderr %s\n" % (i),
+ )
ActionExecutionOutput.add_or_update(stderr_db)
- resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)),
- expect_errors=False)
+ resp = self.app.get(
+ "/v1/executions/%s/output" % (str(action_execution_db.id)),
+ expect_errors=False,
+ )
self.assertEqual(resp.status_int, 200)
events = self._parse_response(resp.text)
self.assertEqual(len(events), 11)
- self.assertEqual(events[0][1]['data'], 'stdout 1\n')
- self.assertEqual(events[9][1]['data'], 'stderr 14\n')
- self.assertEqual(events[10][0], 'EOF')
+ self.assertEqual(events[0][1]["data"], "stdout 1\n")
+ self.assertEqual(events[9][1]["data"], "stderr 14\n")
+ self.assertEqual(events[10][0], "EOF")
# Verify "last" short-hand id works
- resp = self.app.get('/v1/executions/last/output', expect_errors=False)
+ resp = self.app.get("/v1/executions/last/output", expect_errors=False)
self.assertEqual(resp.status_int, 200)
events = self._parse_response(resp.text)
self.assertEqual(len(events), 11)
- self.assertEqual(events[10][0], 'EOF')
+ self.assertEqual(events[10][0], "EOF")
def _parse_response(self, response):
"""
@@ -177,12 +193,12 @@ def _parse_response(self, response):
"""
events = []
- lines = response.strip().split('\n')
+ lines = response.strip().split("\n")
for index, line in enumerate(lines):
- if 'data:' in line:
+ if "data:" in line:
e_line = lines[index - 1]
- event_name = e_line[e_line.find('event: ') + len('event:'):].strip()
- event_data = line[line.find('data: ') + len('data :'):].strip()
+ event_name = e_line[e_line.find("event: ") + len("event:") :].strip()
+ event_data = line[line.find("data: ") + len("data :") :].strip()
event_data = json.loads(event_data) if len(event_data) > 2 else {}
events.append((event_name, event_data))
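As a standalone illustration of the SSE parsing that _parse_response performs above, here is a minimal sketch; it assumes the same "event:"/"data:" line pairing the stream endpoint emits and is not the test helper itself:

# Minimal sketch of parsing Server-Sent Events text into (event, data) pairs,
# mirroring _parse_response above. Assumes each "data:" line is directly
# preceded by its "event:" line, as in these tests.
import json


def parse_sse(text):
    events = []
    lines = text.strip().split("\n")
    for index, line in enumerate(lines):
        if line.startswith("data:"):
            event_name = lines[index - 1].split("event:", 1)[1].strip()
            payload = line.split("data:", 1)[1].strip()
            events.append((event_name, json.loads(payload) if payload else {}))
    return events


sample = 'event: st2.execution__create\ndata: {"id": "abc"}'
assert parse_sse(sample) == [("st2.execution__create", {"id": "abc"})]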
diff --git a/st2tests/dist_utils.py b/st2tests/dist_utils.py
index a6f62c8cc2..2f2043cf29 100644
--- a/st2tests/dist_utils.py
+++ b/st2tests/dist_utils.py
@@ -43,17 +43,17 @@
if PY3:
text_type = str
else:
- text_type = unicode # noqa # pylint: disable=E0602
+ text_type = unicode # noqa # pylint: disable=E0602
-GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python'
+GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python"
__all__ = [
- 'check_pip_is_installed',
- 'check_pip_version',
- 'fetch_requirements',
- 'apply_vagrant_workaround',
- 'get_version_string',
- 'parse_version_string'
+ "check_pip_is_installed",
+ "check_pip_version",
+ "fetch_requirements",
+ "apply_vagrant_workaround",
+ "get_version_string",
+ "parse_version_string",
]
@@ -64,15 +64,15 @@ def check_pip_is_installed():
try:
import pip # NOQA
except ImportError as e:
- print('Failed to import pip: %s' % (text_type(e)))
- print('')
- print('Download pip:\n%s' % (GET_PIP))
+ print("Failed to import pip: %s" % (text_type(e)))
+ print("")
+ print("Download pip:\n%s" % (GET_PIP))
sys.exit(1)
return True
-def check_pip_version(min_version='6.0.0'):
+def check_pip_version(min_version="6.0.0"):
"""
Ensure that a minimum supported version of pip is installed.
"""
@@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'):
import pip
if StrictVersion(pip.__version__) < StrictVersion(min_version):
- print("Upgrade pip, your version '{0}' "
- "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__,
- min_version,
- GET_PIP))
+ print(
+ "Upgrade pip, your version '{0}' "
+ "is outdated. Minimum required version is '{1}':\n{2}".format(
+ pip.__version__, min_version, GET_PIP
+ )
+ )
sys.exit(1)
return True
@@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path):
reqs = []
def _get_link(line):
- vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+']
+ vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"]
for vcs_prefix in vcs_prefixes:
- if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)):
- req_name = re.findall('.*#egg=(.+)([&|@]).*$', line)
+ if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)):
+ req_name = re.findall(".*#egg=(.+)([&|@]).*$", line)
if not req_name:
- req_name = re.findall('.*#egg=(.+?)$', line)
+ req_name = re.findall(".*#egg=(.+?)$", line)
else:
req_name = req_name[0]
if not req_name:
- raise ValueError('Line "%s" is missing "#egg="' % (line))
+ raise ValueError(
+ 'Line "%s" is missing "#egg="' % (line)
+ )
- link = line.replace('-e ', '').strip()
+ link = line.replace("-e ", "").strip()
return link, req_name[0]
return None, None
- with open(requirements_file_path, 'r') as fp:
+ with open(requirements_file_path, "r") as fp:
for line in fp.readlines():
line = line.strip()
- if line.startswith('#') or not line:
+ if line.startswith("#") or not line:
continue
link, req_name = _get_link(line=line)
@@ -131,8 +135,8 @@ def _get_link(line):
else:
req_name = line
- if ';' in req_name:
- req_name = req_name.split(';')[0].strip()
+ if ";" in req_name:
+ req_name = req_name.split(";")[0].strip()
reqs.append(req_name)
@@ -146,7 +150,7 @@ def apply_vagrant_workaround():
Note: Without this workaround, setup.py sdist will fail when running inside a shared directory
(nfs / virtualbox shared folders).
"""
- if os.environ.get('USER', None) == 'vagrant':
+ if os.environ.get("USER", None) == "vagrant":
del os.link
@@ -155,14 +159,13 @@ def get_version_string(init_file):
Read __version__ string for an init file.
"""
- with open(init_file, 'r') as fp:
+ with open(init_file, "r") as fp:
content = fp.read()
- version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
- content, re.M)
+ version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
if version_match:
return version_match.group(1)
- raise RuntimeError('Unable to find version string in %s.' % (init_file))
+ raise RuntimeError("Unable to find version string in %s." % (init_file))
# alias for get_version_string
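A quick usage sketch of the __version__-extraction regex from get_version_string above; the file content here is made up for illustration:

# Hypothetical usage of the version-extraction regex shown in dist_utils above.
import re

content = '__version__ = "3.4.0"\n'
version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M)
assert version_match is not None
assert version_match.group(1) == "3.4.0"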
diff --git a/st2tests/integration/orquesta/base.py b/st2tests/integration/orquesta/base.py
index 52e2277e4c..f5f13cce04 100644
--- a/st2tests/integration/orquesta/base.py
+++ b/st2tests/integration/orquesta/base.py
@@ -30,7 +30,7 @@
LIVEACTION_LAUNCHED_STATUSES = [
action_constants.LIVEACTION_STATUS_REQUESTED,
action_constants.LIVEACTION_STATUS_SCHEDULED,
- action_constants.LIVEACTION_STATUS_RUNNING
+ action_constants.LIVEACTION_STATUS_RUNNING,
]
DEFAULT_WAIT_FIXED = 500
@@ -42,10 +42,9 @@ def retry_on_exceptions(exc):
class WorkflowControlTestCaseMixin(object):
-
def _create_temp_file(self):
_, temp_file_path = tempfile.mkstemp()
- os.chmod(temp_file_path, 0o755) # nosec
+ os.chmod(temp_file_path, 0o755) # nosec
return temp_file_path
def _delete_temp_file(self, temp_file_path):
@@ -57,18 +56,23 @@ def _delete_temp_file(self, temp_file_path):
class TestWorkflowExecution(unittest2.TestCase):
-
@classmethod
def setUpClass(cls):
- cls.st2client = st2.Client(base_url='http://127.0.0.1')
+ cls.st2client = st2.Client(base_url="http://127.0.0.1")
- def _execute_workflow(self, action, parameters=None, execute_async=True,
- expected_status=None, expected_result=None):
+ def _execute_workflow(
+ self,
+ action,
+ parameters=None,
+ execute_async=True,
+ expected_status=None,
+ expected_result=None,
+ ):
ex = models.LiveAction(action=action, parameters=(parameters or {}))
ex = self.st2client.executions.create(ex)
self.assertIsNotNone(ex.id)
- self.assertEqual(ex.action['ref'], action)
+ self.assertEqual(ex.action["ref"], action)
self.assertIn(ex.status, LIVEACTION_LAUNCHED_STATUSES)
if execute_async:
@@ -88,14 +92,16 @@ def _execute_workflow(self, action, parameters=None, execute_async=True,
@retrying.retry(
retry_on_exception=retry_on_exceptions,
- wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY)
+ wait_fixed=DEFAULT_WAIT_FIXED,
+ stop_max_delay=DEFAULT_STOP_MAX_DELAY,
+ )
def _wait_for_state(self, ex, states):
if isinstance(states, six.string_types):
states = [states]
for state in states:
if state not in action_constants.LIVEACTION_STATUSES:
- raise ValueError('Status %s is not valid.' % state)
+ raise ValueError("Status %s is not valid." % state)
try:
ex = self.st2client.executions.get_by_id(ex.id)
@@ -104,8 +110,7 @@ def _wait_for_state(self, ex, states):
if ex.status in action_constants.LIVEACTION_COMPLETED_STATES:
raise Exception(
'Execution is in completed state "%s" and '
- 'does not match expected state(s). %s' %
- (ex.status, ex.result)
+ "does not match expected state(s). %s" % (ex.status, ex.result)
)
else:
raise
@@ -117,13 +122,16 @@ def _get_children(self, ex):
@retrying.retry(
retry_on_exception=retry_on_exceptions,
- wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY)
+ wait_fixed=DEFAULT_WAIT_FIXED,
+ stop_max_delay=DEFAULT_STOP_MAX_DELAY,
+ )
def _wait_for_task(self, ex, task, status=None, num_task_exs=1):
ex = self.st2client.executions.get_by_id(ex.id)
task_exs = [
- task_ex for task_ex in self._get_children(ex)
- if task_ex.context.get('orquesta', {}).get('task_name', '') == task
+ task_ex
+ for task_ex in self._get_children(ex)
+ if task_ex.context.get("orquesta", {}).get("task_name", "") == task
]
try:
@@ -131,8 +139,9 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1):
except:
if ex.status in action_constants.LIVEACTION_COMPLETED_STATES:
raise Exception(
- 'Execution is in completed state and does not match expected number of '
- 'tasks. Expected: %s Actual: %s' % (str(num_task_exs), str(len(task_exs)))
+ "Execution is in completed state and does not match expected number of "
+ "tasks. Expected: %s Actual: %s"
+ % (str(num_task_exs), str(len(task_exs)))
)
else:
raise
@@ -143,7 +152,7 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1):
except:
if ex.status in action_constants.LIVEACTION_COMPLETED_STATES:
raise Exception(
- 'Execution is in completed state and not all tasks '
+ "Execution is in completed state and not all tasks "
'match expected status "%s".' % status
)
else:
@@ -153,17 +162,19 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1):
@retrying.retry(
retry_on_exception=retry_on_exceptions,
- wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY)
+ wait_fixed=DEFAULT_WAIT_FIXED,
+ stop_max_delay=DEFAULT_STOP_MAX_DELAY,
+ )
def _wait_for_completion(self, ex):
ex = self._wait_for_state(ex, action_constants.LIVEACTION_COMPLETED_STATES)
try:
- self.assertTrue(hasattr(ex, 'result'))
+ self.assertTrue(hasattr(ex, "result"))
except:
if ex.status in action_constants.LIVEACTION_COMPLETED_STATES:
raise Exception(
- 'Execution is in completed state and does not '
- 'contain expected result.'
+ "Execution is in completed state and does not "
+ "contain expected result."
)
else:
raise
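All of the _wait_for_* helpers above share one polling pattern: retry on AssertionError at a fixed interval until a hard deadline. A self-contained sketch of that pattern, assuming the retrying package; get_status is a hypothetical stand-in for fetching the current execution status:

# Sketch of the polling pattern used by the _wait_for_* helpers above:
# retry on AssertionError every wait_fixed ms until stop_max_delay ms elapse.
import retrying


def retry_on_exceptions(exc):
    return isinstance(exc, AssertionError)


@retrying.retry(
    retry_on_exception=retry_on_exceptions,
    wait_fixed=500,  # milliseconds between attempts
    stop_max_delay=300000,  # give up after 5 minutes
)
def wait_for_state(get_status, expected_states):
    # get_status is a hypothetical callable that returns the current status.
    status = get_status()
    assert status in expected_states, "Status %s not in %s" % (status, expected_states)
    return status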
diff --git a/st2tests/integration/orquesta/test_performance.py b/st2tests/integration/orquesta/test_performance.py
index e68ecc7f5f..899b3090f9 100644
--- a/st2tests/integration/orquesta/test_performance.py
+++ b/st2tests/integration/orquesta/test_performance.py
@@ -27,34 +27,35 @@
class WiringTest(base.TestWorkflowExecution):
-
def test_concurrent_load(self):
load_count = 3
delay_poll = load_count * 5
- wf_name = 'examples.orquesta-mock-create-vm'
- wf_input = {'vm_name': 'demo1', 'meta': {'demo1.itests.org': '10.3.41.99'}}
+ wf_name = "examples.orquesta-mock-create-vm"
+ wf_input = {"vm_name": "demo1", "meta": {"demo1.itests.org": "10.3.41.99"}}
exs = [self._execute_workflow(wf_name, wf_input) for i in range(load_count)]
eventlet.sleep(delay_poll)
for ex in exs:
e = self._wait_for_completion(ex)
- self.assertEqual(e.status, ac_const.LIVEACTION_STATUS_SUCCEEDED, json.dumps(e.result))
- self.assertIn('output', e.result)
- self.assertIn('vm_id', e.result['output'])
+ self.assertEqual(
+ e.status, ac_const.LIVEACTION_STATUS_SUCCEEDED, json.dumps(e.result)
+ )
+ self.assertIn("output", e.result)
+ self.assertIn("vm_id", e.result["output"])
def test_with_items_load(self):
- wf_name = 'examples.orquesta-with-items-concurrency'
+ wf_name = "examples.orquesta-with-items-concurrency"
num_items = 10
concurrency = 10
members = [str(i).zfill(5) for i in range(0, num_items)]
- wf_input = {'members': members, 'concurrency': concurrency}
+ wf_input = {"members": members, "concurrency": concurrency}
- message = '%s, resistance is futile!'
- expected_output = {'items': [message % i for i in members]}
- expected_result = {'output': expected_output}
+ message = "%s, resistance is futile!"
+ expected_output = {"items": [message % i for i in members]}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
diff --git a/st2tests/integration/orquesta/test_wiring.py b/st2tests/integration/orquesta/test_wiring.py
index f542c0d779..3e07d7b3fe 100644
--- a/st2tests/integration/orquesta/test_wiring.py
+++ b/st2tests/integration/orquesta/test_wiring.py
@@ -23,13 +23,12 @@
class WiringTest(base.TestWorkflowExecution):
-
def test_sequential(self):
- wf_name = 'examples.orquesta-sequential'
- wf_input = {'name': 'Thanos'}
+ wf_name = "examples.orquesta-sequential"
+ wf_input = {"name": "Thanos"}
- expected_output = {'greeting': 'Thanos, All your base are belong to us!'}
- expected_result = {'output': expected_output}
+ expected_output = {"greeting": "Thanos, All your base are belong to us!"}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
@@ -38,18 +37,18 @@ def test_sequential(self):
self.assertDictEqual(ex.result, expected_result)
def test_join(self):
- wf_name = 'examples.orquesta-join'
+ wf_name = "examples.orquesta-join"
expected_output = {
- 'messages': [
- 'Fee fi fo fum',
- 'I smell the blood of an English man',
- 'Be alive, or be he dead',
- 'I\'ll grind his bones to make my bread'
+ "messages": [
+ "Fee fi fo fum",
+ "I smell the blood of an English man",
+ "Be alive, or be he dead",
+ "I'll grind his bones to make my bread",
]
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
@@ -58,10 +57,10 @@ def test_join(self):
self.assertDictEqual(ex.result, expected_result)
def test_cycle(self):
- wf_name = 'examples.orquesta-rollback-retry'
+ wf_name = "examples.orquesta-rollback-retry"
expected_output = None
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
@@ -70,12 +69,12 @@ def test_cycle(self):
self.assertDictEqual(ex.result, expected_result)
def test_action_less(self):
- wf_name = 'examples.orquesta-test-action-less-tasks'
- wf_input = {'name': 'Thanos'}
+ wf_name = "examples.orquesta-test-action-less-tasks"
+ wf_input = {"name": "Thanos"}
- message = 'Thanos, All your base are belong to us!'
- expected_output = {'greeting': message.upper()}
- expected_result = {'output': expected_output}
+ message = "Thanos, All your base are belong to us!"
+ expected_output = {"greeting": message.upper()}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
@@ -84,73 +83,72 @@ def test_action_less(self):
self.assertDictEqual(ex.result, expected_result)
def test_st2_runtime_context(self):
- wf_name = 'examples.orquesta-st2-ctx'
+ wf_name = "examples.orquesta-st2-ctx"
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
- expected_output = {'callback': 'http://127.0.0.1:9101/v1/executions/%s' % str(ex.id)}
- expected_result = {'output': expected_output}
+ expected_output = {
+ "callback": "http://127.0.0.1:9101/v1/executions/%s" % str(ex.id)
+ }
+ expected_result = {"output": expected_output}
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertDictEqual(ex.result, expected_result)
def test_subworkflow(self):
- wf_name = 'examples.orquesta-subworkflow'
+ wf_name = "examples.orquesta-subworkflow"
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self._wait_for_task(ex, 'start', ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ self._wait_for_task(ex, "start", ac_const.LIVEACTION_STATUS_SUCCEEDED)
- t2_ex = self._wait_for_task(ex, 'subworkflow', ac_const.LIVEACTION_STATUS_SUCCEEDED)[0]
- self._wait_for_task(t2_ex, 'task1', ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self._wait_for_task(t2_ex, 'task2', ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self._wait_for_task(t2_ex, 'task3', ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ t2_ex = self._wait_for_task(
+ ex, "subworkflow", ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )[0]
+ self._wait_for_task(t2_ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ self._wait_for_task(t2_ex, "task2", ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ self._wait_for_task(t2_ex, "task3", ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self._wait_for_task(ex, 'finish', ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ self._wait_for_task(ex, "finish", ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_output_on_error(self):
- wf_name = 'examples.orquesta-output-on-error'
+ wf_name = "examples.orquesta-output-on-error"
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
- expected_output = {
- 'progress': 25
- }
+ expected_output = {"progress": 25}
expected_errors = [
{
- 'type': 'error',
- 'task_id': 'task2',
- 'message': 'Execution failed. See result for details.',
- 'result': {
- 'failed': True,
- 'return_code': 1,
- 'stderr': '',
- 'stdout': '',
- 'succeeded': False
- }
+ "type": "error",
+ "task_id": "task2",
+ "message": "Execution failed. See result for details.",
+ "result": {
+ "failed": True,
+ "return_code": 1,
+ "stderr": "",
+ "stdout": "",
+ "succeeded": False,
+ },
}
]
- expected_result = {
- 'errors': expected_errors,
- 'output': expected_output
- }
+ expected_result = {"errors": expected_errors, "output": expected_output}
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
self.assertDictEqual(ex.result, expected_result)
def test_config_context_renders(self):
config_value = "Testing"
- wf_name = 'examples.render_config_context'
+ wf_name = "examples.render_config_context"
- expected_output = {'context_value': config_value}
- expected_result = {'output': expected_output}
+ expected_output = {"context_value": config_value}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
@@ -159,21 +157,21 @@ def test_config_context_renders(self):
self.assertDictEqual(ex.result, expected_result)
def test_field_escaping(self):
- wf_name = 'examples.orquesta-test-field-escaping'
+ wf_name = "examples.orquesta-test-field-escaping"
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
expected_output = {
- 'wf.hostname.with.periods': {
- 'hostname.domain.tld': 'vars.value.with.periods',
- 'hostname2.domain.tld': {
- 'stdout': 'vars.nested.value.with.periods',
+ "wf.hostname.with.periods": {
+ "hostname.domain.tld": "vars.value.with.periods",
+ "hostname2.domain.tld": {
+ "stdout": "vars.nested.value.with.periods",
},
},
- 'wf.output.with.periods': 'vars.nested.value.with.periods',
+ "wf.output.with.periods": "vars.nested.value.with.periods",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
self.assertDictEqual(ex.result, expected_result)
diff --git a/st2tests/integration/orquesta/test_wiring_cancel.py b/st2tests/integration/orquesta/test_wiring_cancel.py
index ff9d0d378f..0e4edaf918 100644
--- a/st2tests/integration/orquesta/test_wiring_cancel.py
+++ b/st2tests/integration/orquesta/test_wiring_cancel.py
@@ -22,7 +22,9 @@
from st2common.constants import action as ac_const
-class CancellationWiringTest(base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin):
+class CancellationWiringTest(
+ base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin
+):
temp_file_path = None
@@ -44,9 +46,9 @@ def test_cancellation(self):
self.assertTrue(os.path.exists(path))
# Launch the workflow. The workflow will wait for the temp file to be deleted.
- params = {'tempfile': path, 'message': 'foobar'}
- ex = self._execute_workflow('examples.orquesta-test-cancel', params)
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
+ params = {"tempfile": path, "message": "foobar"}
+ ex = self._execute_workflow("examples.orquesta-test-cancel", params)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
# Cancel the workflow before the temp file is deleted. The workflow will be paused
# but task1 will still be running to allow for graceful exit.
@@ -63,7 +65,7 @@ def test_cancellation(self):
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
# Task is completed successfully for graceful exit.
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Get the updated execution with task result.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
@@ -74,15 +76,15 @@ def test_task_cancellation(self):
self.assertTrue(os.path.exists(path))
# Launch the workflow. The workflow will wait for the temp file to be deleted.
- params = {'tempfile': path, 'message': 'foobar'}
- ex = self._execute_workflow('examples.orquesta-test-cancel', params)
- task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
+ params = {"tempfile": path, "message": "foobar"}
+ ex = self._execute_workflow("examples.orquesta-test-cancel", params)
+ task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
# Cancel the task execution.
self.st2client.executions.delete(task_exs[0])
# Wait for the task and parent workflow to be canceled.
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED)
# Get the updated execution with task result.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
@@ -93,10 +95,10 @@ def test_cancellation_cascade_down_to_subworkflow(self):
self.assertTrue(os.path.exists(path))
# Launch the workflow. The workflow will wait for the temp file to be deleted.
- params = {'tempfile': path, 'message': 'foobar'}
- action_ref = 'examples.orquesta-test-cancel-subworkflow'
+ params = {"tempfile": path, "message": "foobar"}
+ action_ref = "examples.orquesta-test-cancel-subworkflow"
ex = self._execute_workflow(action_ref, params)
- task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
+ task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
subwf_ex = task_exs[0]
# Cancel the workflow before the temp file is deleted. The workflow will be canceled
@@ -123,10 +125,10 @@ def test_cancellation_cascade_up_from_subworkflow(self):
self.assertTrue(os.path.exists(path))
# Launch the workflow. The workflow will wait for the temp file to be deleted.
- params = {'tempfile': path, 'message': 'foobar'}
- action_ref = 'examples.orquesta-test-cancel-subworkflow'
+ params = {"tempfile": path, "message": "foobar"}
+ action_ref = "examples.orquesta-test-cancel-subworkflow"
ex = self._execute_workflow(action_ref, params)
- task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
+ task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
subwf_ex = task_exs[0]
# Cancel the workflow before the temp file is deleted. The workflow will be canceled
@@ -155,12 +157,12 @@ def test_cancellation_cascade_up_to_workflow_with_other_subworkflow(self):
self.assertTrue(os.path.exists(path))
# Launch the workflow. The workflow will wait for the temp file to be deleted.
- params = {'file1': path, 'file2': path}
- action_ref = 'examples.orquesta-test-cancel-subworkflows'
+ params = {"file1": path, "file2": path}
+ action_ref = "examples.orquesta-test-cancel-subworkflows"
ex = self._execute_workflow(action_ref, params)
- task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
+ task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
subwf_ex_1 = task_exs[0]
- task_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING)
+ task_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING)
subwf_ex_2 = task_exs[0]
# Cancel the workflow before the temp file is deleted. The workflow will be canceled
@@ -168,19 +170,27 @@ def test_cancellation_cascade_up_to_workflow_with_other_subworkflow(self):
self.st2client.executions.delete(subwf_ex_1)
# Assert subworkflow is canceling.
- subwf_ex_1 = self._wait_for_state(subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELING)
+ subwf_ex_1 = self._wait_for_state(
+ subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELING
+ )
# Assert main workflow and the other subworkflow is canceling.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELING)
- subwf_ex_2 = self._wait_for_state(subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELING)
+ subwf_ex_2 = self._wait_for_state(
+ subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELING
+ )
# Delete the temporary file.
os.remove(path)
self.assertFalse(os.path.exists(path))
# Assert subworkflows are canceled.
- subwf_ex_1 = self._wait_for_state(subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELED)
- subwf_ex_2 = self._wait_for_state(subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELED)
+ subwf_ex_1 = self._wait_for_state(
+ subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELED
+ )
+ subwf_ex_2 = self._wait_for_state(
+ subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELED
+ )
# Assert main workflow is canceled.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
diff --git a/st2tests/integration/orquesta/test_wiring_data_flow.py b/st2tests/integration/orquesta/test_wiring_data_flow.py
index a9569cf693..ed5fbfa23a 100644
--- a/st2tests/integration/orquesta/test_wiring_data_flow.py
+++ b/st2tests/integration/orquesta/test_wiring_data_flow.py
@@ -27,13 +27,12 @@
class WiringTest(base.TestWorkflowExecution):
-
def test_data_flow(self):
- wf_name = 'examples.orquesta-data-flow'
- wf_input = {'a1': 'fee fi fo fum'}
+ wf_name = "examples.orquesta-data-flow"
+ wf_input = {"a1": "fee fi fo fum"}
- expected_output = {'a5': wf_input['a1'], 'b5': wf_input['a1']}
- expected_result = {'output': expected_output}
+ expected_output = {"a5": wf_input["a1"], "b5": wf_input["a1"]}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
@@ -42,15 +41,15 @@ def test_data_flow(self):
self.assertDictEqual(ex.result, expected_result)
def test_data_flow_unicode(self):
- wf_name = 'examples.orquesta-data-flow'
- wf_input = {'a1': '床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉'}
+ wf_name = "examples.orquesta-data-flow"
+ wf_input = {"a1": "床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉"}
expected_output = {
- 'a5': wf_input['a1'].decode('utf-8') if six.PY2 else wf_input['a1'],
- 'b5': wf_input['a1'].decode('utf-8') if six.PY2 else wf_input['a1']
+ "a5": wf_input["a1"].decode("utf-8") if six.PY2 else wf_input["a1"],
+ "b5": wf_input["a1"].decode("utf-8") if six.PY2 else wf_input["a1"],
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
@@ -59,16 +58,15 @@ def test_data_flow_unicode(self):
self.assertDictEqual(ex.result, expected_result)
def test_data_flow_unicode_concat_with_ascii(self):
- wf_name = 'examples.orquesta-sequential'
- wf_input = {'name': '薩諾斯'}
+ wf_name = "examples.orquesta-sequential"
+ wf_input = {"name": "薩諾斯"}
expected_output = {
- 'greeting': '%s, All your base are belong to us!' % (
- wf_input['name'].decode('utf-8') if six.PY2 else wf_input['name']
- )
+ "greeting": "%s, All your base are belong to us!"
+ % (wf_input["name"].decode("utf-8") if six.PY2 else wf_input["name"])
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
@@ -77,15 +75,17 @@ def test_data_flow_unicode_concat_with_ascii(self):
self.assertDictEqual(ex.result, expected_result)
def test_data_flow_big_data_size(self):
- wf_name = 'examples.orquesta-data-flow'
+ wf_name = "examples.orquesta-data-flow"
data_length = 100000
- data = ''.join(random.choice(string.ascii_lowercase) for _ in range(data_length))
+ data = "".join(
+ random.choice(string.ascii_lowercase) for _ in range(data_length)
+ )
- wf_input = {'a1': data}
+ wf_input = {"a1": data}
- expected_output = {'a5': wf_input['a1'], 'b5': wf_input['a1']}
- expected_result = {'output': expected_output}
+ expected_output = {"a5": wf_input["a1"], "b5": wf_input["a1"]}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
diff --git a/st2tests/integration/orquesta/test_wiring_delay.py b/st2tests/integration/orquesta/test_wiring_delay.py
index f825475479..32b923b923 100644
--- a/st2tests/integration/orquesta/test_wiring_delay.py
+++ b/st2tests/integration/orquesta/test_wiring_delay.py
@@ -23,13 +23,12 @@
class TaskDelayWiringTest(base.TestWorkflowExecution):
-
def test_task_delay(self):
- wf_name = 'examples.orquesta-delay'
- wf_input = {'name': 'Thanos', 'delay': 1}
+ wf_name = "examples.orquesta-delay"
+ wf_input = {"name": "Thanos", "delay": 1}
- expected_output = {'greeting': 'Thanos, All your base are belong to us!'}
- expected_result = {'output': expected_output}
+ expected_output = {"greeting": "Thanos, All your base are belong to us!"}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
@@ -38,12 +37,12 @@ def test_task_delay(self):
self.assertDictEqual(ex.result, expected_result)
def test_task_delay_workflow_cancellation(self):
- wf_name = 'examples.orquesta-delay'
- wf_input = {'name': 'Thanos', 'delay': 300}
+ wf_name = "examples.orquesta-delay"
+ wf_input = {"name": "Thanos", "delay": 300}
# Launch the workflow; task1 should be delayed.
ex = self._execute_workflow(wf_name, wf_input)
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_DELAYED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_DELAYED)
# Cancel the workflow while task1 is still delayed. The workflow and task1
# will then be canceled gracefully.
@@ -53,24 +52,24 @@ def test_task_delay_workflow_cancellation(self):
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
# Task execution should be canceled.
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED)
# Get the updated execution with task result.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
def test_task_delay_task_cancellation(self):
- wf_name = 'examples.orquesta-delay'
- wf_input = {'name': 'Thanos', 'delay': 300}
+ wf_name = "examples.orquesta-delay"
+ wf_input = {"name": "Thanos", "delay": 300}
# Launch the workflow; task1 should be delayed.
ex = self._execute_workflow(wf_name, wf_input)
- task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_DELAYED)
+ task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_DELAYED)
# Cancel the task execution.
self.st2client.executions.delete(task_exs[0])
# Wait for the task and parent workflow to be canceled.
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED)
# Get the updated execution with task result.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
diff --git a/st2tests/integration/orquesta/test_wiring_error_handling.py b/st2tests/integration/orquesta/test_wiring_error_handling.py
index f3c9f87fdd..130a68c7c5 100644
--- a/st2tests/integration/orquesta/test_wiring_error_handling.py
+++ b/st2tests/integration/orquesta/test_wiring_error_handling.py
@@ -22,236 +22,235 @@
class ErrorHandlingTest(base.TestWorkflowExecution):
-
def test_inspection_error(self):
expected_errors = [
{
- 'type': 'content',
- 'message': 'The action "std.noop" is not registered in the database.',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action',
- 'spec_path': 'tasks.task3.action'
+ "type": "content",
+ "message": 'The action "std.noop" is not registered in the database.',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action",
+ "spec_path": "tasks.task3.action",
},
{
- 'type': 'context',
- 'language': 'yaql',
- 'expression': '<% ctx().foobar %>',
- 'message': 'Variable "foobar" is referenced before assignment.',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input',
- 'spec_path': 'tasks.task1.input',
+ "type": "context",
+ "language": "yaql",
+ "expression": "<% ctx().foobar %>",
+ "message": 'Variable "foobar" is referenced before assignment.',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input",
+ "spec_path": "tasks.task1.input",
},
{
- 'type': 'expression',
- 'language': 'yaql',
- 'expression': '<% <% succeeded() %>',
- 'message': (
- 'Parse error: unexpected \'<\' at '
- 'position 0 of expression \'<% succeeded()\''
+ "type": "expression",
+ "language": "yaql",
+ "expression": "<% <% succeeded() %>",
+ "message": (
+ "Parse error: unexpected '<' at "
+ "position 0 of expression '<% succeeded()'"
),
- 'schema_path': (
- r'properties.tasks.patternProperties.^\w+$.'
- 'properties.next.items.properties.when'
+ "schema_path": (
+ r"properties.tasks.patternProperties.^\w+$."
+ "properties.next.items.properties.when"
),
- 'spec_path': 'tasks.task2.next[0].when'
+ "spec_path": "tasks.task2.next[0].when",
},
{
- 'type': 'syntax',
- 'message': (
- '[{\'cmd\': \'echo <% ctx().macro %>\'}] is '
- 'not valid under any of the given schemas'
+ "type": "syntax",
+ "message": (
+ "[{'cmd': 'echo <% ctx().macro %>'}] is "
+ "not valid under any of the given schemas"
),
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf',
- 'spec_path': 'tasks.task2.input'
- }
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf",
+ "spec_path": "tasks.task2.input",
+ },
]
- ex = self._execute_workflow('examples.orquesta-fail-inspection')
+ ex = self._execute_workflow("examples.orquesta-fail-inspection")
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None})
+ self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None})
def test_input_error(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% abs(8).value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% abs(8).value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
- )
+ ),
}
]
- ex = self._execute_workflow('examples.orquesta-fail-input-rendering')
+ ex = self._execute_workflow("examples.orquesta-fail-input-rendering")
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None})
+ self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None})
def test_vars_error(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% abs(8).value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% abs(8).value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
- )
+ ),
}
]
- ex = self._execute_workflow('examples.orquesta-fail-vars-rendering')
+ ex = self._execute_workflow("examples.orquesta-fail-vars-rendering")
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None})
+ self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None})
def test_start_task_error(self):
self.maxDiff = None
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% ctx().name.value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% ctx().name.value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
),
- 'task_id': 'task1',
- 'route': 0
+ "task_id": "task1",
+ "route": 0,
},
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to resolve key \'greeting\' '
- 'in expression \'<% ctx().greeting %>\' from context.'
- )
- }
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to resolve key 'greeting' "
+ "in expression '<% ctx().greeting %>' from context."
+ ),
+ },
]
- ex = self._execute_workflow('examples.orquesta-fail-start-task')
+ ex = self._execute_workflow("examples.orquesta-fail-start-task")
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None})
+ self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None})
def test_task_transition_error(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to resolve key \'value\' '
- 'in expression \'<% succeeded() and result().value %>\' from context.'
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to resolve key 'value' "
+ "in expression '<% succeeded() and result().value %>' from context."
),
- 'task_transition_id': 'task2__t0',
- 'task_id': 'task1',
- 'route': 0
+ "task_transition_id": "task2__t0",
+ "task_id": "task1",
+ "route": 0,
}
]
- expected_output = {
- 'greeting': None
- }
+ expected_output = {"greeting": None}
- ex = self._execute_workflow('examples.orquesta-fail-task-transition')
+ ex = self._execute_workflow("examples.orquesta-fail-task-transition")
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output})
+ self.assertDictEqual(
+ ex.result, {"errors": expected_errors, "output": expected_output}
+ )
def test_task_publish_error(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to resolve key \'value\' '
- 'in expression \'<% result().value %>\' from context.'
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to resolve key 'value' "
+ "in expression '<% result().value %>' from context."
),
- 'task_transition_id': 'task2__t0',
- 'task_id': 'task1',
- 'route': 0
+ "task_transition_id": "task2__t0",
+ "task_id": "task1",
+ "route": 0,
}
]
- expected_output = {
- 'greeting': None
- }
+ expected_output = {"greeting": None}
- ex = self._execute_workflow('examples.orquesta-fail-task-publish')
+ ex = self._execute_workflow("examples.orquesta-fail-task-publish")
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output})
+ self.assertDictEqual(
+ ex.result, {"errors": expected_errors, "output": expected_output}
+ )
def test_output_error(self):
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% abs(8).value %>\'. NoFunctionRegisteredException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% abs(8).value %>'. NoFunctionRegisteredException: "
'Unknown function "#property#value"'
- )
+ ),
}
]
- ex = self._execute_workflow('examples.orquesta-fail-output-rendering')
+ ex = self._execute_workflow("examples.orquesta-fail-output-rendering")
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None})
+ self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None})
def test_task_content_errors(self):
expected_errors = [
{
- 'type': 'content',
- 'message': 'The action reference "echo" is not formatted correctly.',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action',
- 'spec_path': 'tasks.task1.action'
+ "type": "content",
+ "message": 'The action reference "echo" is not formatted correctly.',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action",
+ "spec_path": "tasks.task1.action",
},
{
- 'type': 'content',
- 'message': 'The action "core.echoz" is not registered in the database.',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action',
- 'spec_path': 'tasks.task2.action'
+ "type": "content",
+ "message": 'The action "core.echoz" is not registered in the database.',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action",
+ "spec_path": "tasks.task2.action",
},
{
- 'type': 'content',
- 'message': 'Action "core.echo" is missing required input "message".',
- 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input',
- 'spec_path': 'tasks.task3.input'
+ "type": "content",
+ "message": 'Action "core.echo" is missing required input "message".',
+ "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input",
+ "spec_path": "tasks.task3.input",
},
{
- 'type': 'content',
- 'message': 'Action "core.echo" has unexpected input "messages".',
- 'schema_path': (
- r'properties.tasks.patternProperties.^\w+$.properties.input.'
- r'patternProperties.^\w+$'
+ "type": "content",
+ "message": 'Action "core.echo" has unexpected input "messages".',
+ "schema_path": (
+ r"properties.tasks.patternProperties.^\w+$.properties.input."
+ r"patternProperties.^\w+$"
),
- 'spec_path': 'tasks.task3.input.messages'
- }
+ "spec_path": "tasks.task3.input.messages",
+ },
]
- ex = self._execute_workflow('examples.orquesta-fail-inspection-task-contents')
+ ex = self._execute_workflow("examples.orquesta-fail-inspection-task-contents")
ex = self._wait_for_completion(ex)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None})
+ self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None})
def test_remediate_then_fail(self):
expected_errors = [
{
- 'task_id': 'task1',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.',
- 'result': {
- 'failed': True,
- 'return_code': 1,
- 'stderr': '',
- 'stdout': '',
- 'succeeded': False
- }
+ "task_id": "task1",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ "result": {
+ "failed": True,
+ "return_code": 1,
+ "stderr": "",
+ "stdout": "",
+ "succeeded": False,
+ },
},
{
- 'task_id': 'fail',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.'
- }
+ "task_id": "fail",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ },
]
- ex = self._execute_workflow('examples.orquesta-remediate-then-fail')
+ ex = self._execute_workflow("examples.orquesta-remediate-then-fail")
ex = self._wait_for_completion(ex)
# Assert that the log task is executed.
@@ -261,93 +260,95 @@ def test_remediate_then_fail(self):
# tasks is reached (with some hard limit) before failing
eventlet.sleep(2)
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED)
- self._wait_for_task(ex, 'log', ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED)
+ self._wait_for_task(ex, "log", ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Assert workflow status and result.
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None})
+ self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None})
def test_fail_manually(self):
expected_errors = [
{
- 'task_id': 'task1',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.',
- 'result': {
- 'failed': True,
- 'return_code': 1,
- 'stderr': '',
- 'stdout': '',
- 'succeeded': False
- }
+ "task_id": "task1",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ "result": {
+ "failed": True,
+ "return_code": 1,
+ "stderr": "",
+ "stdout": "",
+ "succeeded": False,
+ },
},
{
- 'task_id': 'fail',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.'
- }
+ "task_id": "fail",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ },
]
- expected_output = {
- 'message': '$%#&@#$!!!'
- }
+ expected_output = {"message": "$%#&@#$!!!"}
- wf_input = {'cmd': 'exit 1'}
- ex = self._execute_workflow('examples.orquesta-error-handling-fail-manually', wf_input)
+ wf_input = {"cmd": "exit 1"}
+ ex = self._execute_workflow(
+ "examples.orquesta-error-handling-fail-manually", wf_input
+ )
ex = self._wait_for_completion(ex)
# Assert task status.
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED)
- self._wait_for_task(ex, 'task3', ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED)
+ self._wait_for_task(ex, "task3", ac_const.LIVEACTION_STATUS_SUCCEEDED)
# Assert workflow status and result.
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output})
+ self.assertDictEqual(
+ ex.result, {"errors": expected_errors, "output": expected_output}
+ )
def test_fail_continue(self):
expected_errors = [
{
- 'task_id': 'task1',
- 'type': 'error',
- 'message': 'Execution failed. See result for details.',
- 'result': {
- 'failed': True,
- 'return_code': 1,
- 'stderr': '',
- 'stdout': '',
- 'succeeded': False
- }
+ "task_id": "task1",
+ "type": "error",
+ "message": "Execution failed. See result for details.",
+ "result": {
+ "failed": True,
+ "return_code": 1,
+ "stderr": "",
+ "stdout": "",
+ "succeeded": False,
+ },
}
]
- expected_output = {
- 'message': '$%#&@#$!!!'
- }
+ expected_output = {"message": "$%#&@#$!!!"}
- wf_input = {'cmd': 'exit 1'}
- ex = self._execute_workflow('examples.orquesta-error-handling-continue', wf_input)
+ wf_input = {"cmd": "exit 1"}
+ ex = self._execute_workflow(
+ "examples.orquesta-error-handling-continue", wf_input
+ )
ex = self._wait_for_completion(ex)
# Assert task status.
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED)
# Assert workflow status and result.
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
- self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output})
+ self.assertDictEqual(
+ ex.result, {"errors": expected_errors, "output": expected_output}
+ )
def test_fail_noop(self):
- expected_output = {
- 'message': '$%#&@#$!!!'
- }
+ expected_output = {"message": "$%#&@#$!!!"}
- wf_input = {'cmd': 'exit 1'}
- ex = self._execute_workflow('examples.orquesta-error-handling-noop', wf_input)
+ wf_input = {"cmd": "exit 1"}
+ ex = self._execute_workflow("examples.orquesta-error-handling-noop", wf_input)
ex = self._wait_for_completion(ex)
# Assert task status.
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED)
# Assert workflow status and result.
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertDictEqual(ex.result, {'output': expected_output})
+ self.assertDictEqual(ex.result, {"output": expected_output})
diff --git a/st2tests/integration/orquesta/test_wiring_functions.py b/st2tests/integration/orquesta/test_wiring_functions.py
index 91da108d39..538bf9ddd7 100644
--- a/st2tests/integration/orquesta/test_wiring_functions.py
+++ b/st2tests/integration/orquesta/test_wiring_functions.py
@@ -19,165 +19,174 @@
class FunctionsWiringTest(base.TestWorkflowExecution):
-
def test_data_functions_in_yaql(self):
- wf_name = 'examples.orquesta-test-yaql-data-functions'
+ wf_name = "examples.orquesta-test-yaql-data-functions"
expected_output = {
- 'data_json_str_1': '{"foo": {"bar": "foobar"}}',
- 'data_json_str_2': '{"foo": {"bar": "foobar"}}',
- 'data_json_str_3': '{"foo": {"bar": "foobar"}}',
- 'data_json_obj_1': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_2': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_3': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_4': {'foo': {'bar': 'foobar'}},
- 'data_yaml_str_1': 'foo:\n bar: foobar\n',
- 'data_yaml_str_2': 'foo:\n bar: foobar\n',
- 'data_query_1': ['foobar'],
- 'data_none_str': '%*****__%NONE%__*****%',
- 'data_str': 'foobar'
+ "data_json_str_1": '{"foo": {"bar": "foobar"}}',
+ "data_json_str_2": '{"foo": {"bar": "foobar"}}',
+ "data_json_str_3": '{"foo": {"bar": "foobar"}}',
+ "data_json_obj_1": {"foo": {"bar": "foobar"}},
+ "data_json_obj_2": {"foo": {"bar": "foobar"}},
+ "data_json_obj_3": {"foo": {"bar": "foobar"}},
+ "data_json_obj_4": {"foo": {"bar": "foobar"}},
+ "data_yaml_str_1": "foo:\n bar: foobar\n",
+ "data_yaml_str_2": "foo:\n bar: foobar\n",
+ "data_query_1": ["foobar"],
+ "data_none_str": "%*****__%NONE%__*****%",
+ "data_str": "foobar",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_data_functions_in_jinja(self):
- wf_name = 'examples.orquesta-test-jinja-data-functions'
+ wf_name = "examples.orquesta-test-jinja-data-functions"
expected_output = {
- 'data_json_str_1': '{"foo": {"bar": "foobar"}}',
- 'data_json_str_2': '{"foo": {"bar": "foobar"}}',
- 'data_json_str_3': '{"foo": {"bar": "foobar"}}',
- 'data_json_obj_1': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_2': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_3': {'foo': {'bar': 'foobar'}},
- 'data_json_obj_4': {'foo': {'bar': 'foobar'}},
- 'data_yaml_str_1': 'foo:\n bar: foobar\n',
- 'data_yaml_str_2': 'foo:\n bar: foobar\n',
- 'data_query_1': ['foobar'],
- 'data_pipe_str_1': '{"foo": {"bar": "foobar"}}',
- 'data_none_str': '%*****__%NONE%__*****%',
- 'data_str': 'foobar',
- 'data_list_str': '- a: 1\n b: 2\n- x: 3\n y: 4\n'
+ "data_json_str_1": '{"foo": {"bar": "foobar"}}',
+ "data_json_str_2": '{"foo": {"bar": "foobar"}}',
+ "data_json_str_3": '{"foo": {"bar": "foobar"}}',
+ "data_json_obj_1": {"foo": {"bar": "foobar"}},
+ "data_json_obj_2": {"foo": {"bar": "foobar"}},
+ "data_json_obj_3": {"foo": {"bar": "foobar"}},
+ "data_json_obj_4": {"foo": {"bar": "foobar"}},
+ "data_yaml_str_1": "foo:\n bar: foobar\n",
+ "data_yaml_str_2": "foo:\n bar: foobar\n",
+ "data_query_1": ["foobar"],
+ "data_pipe_str_1": '{"foo": {"bar": "foobar"}}',
+ "data_none_str": "%*****__%NONE%__*****%",
+ "data_str": "foobar",
+ "data_list_str": "- a: 1\n b: 2\n- x: 3\n y: 4\n",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_path_functions_in_yaql(self):
- wf_name = 'examples.orquesta-test-yaql-path-functions'
+ wf_name = "examples.orquesta-test-yaql-path-functions"
- expected_output = {
- 'basename': 'file.txt',
- 'dirname': '/path/to/some'
- }
+ expected_output = {"basename": "file.txt", "dirname": "/path/to/some"}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_path_functions_in_jinja(self):
- wf_name = 'examples.orquesta-test-jinja-path-functions'
+ wf_name = "examples.orquesta-test-jinja-path-functions"
- expected_output = {
- 'basename': 'file.txt',
- 'dirname': '/path/to/some'
- }
+ expected_output = {"basename": "file.txt", "dirname": "/path/to/some"}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_regex_functions_in_yaql(self):
- wf_name = 'examples.orquesta-test-yaql-regex-functions'
+ wf_name = "examples.orquesta-test-yaql-regex-functions"
expected_output = {
- 'match': True,
- 'replace': 'wxyz',
- 'search': True,
- 'substring': '668 Infinite Dr'
+ "match": True,
+ "replace": "wxyz",
+ "search": True,
+ "substring": "668 Infinite Dr",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_regex_functions_in_jinja(self):
- wf_name = 'examples.orquesta-test-jinja-regex-functions'
+ wf_name = "examples.orquesta-test-jinja-regex-functions"
expected_output = {
- 'match': True,
- 'replace': 'wxyz',
- 'search': True,
- 'substring': '668 Infinite Dr'
+ "match": True,
+ "replace": "wxyz",
+ "search": True,
+ "substring": "668 Infinite Dr",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_time_functions_in_yaql(self):
- wf_name = 'examples.orquesta-test-yaql-time-functions'
+ wf_name = "examples.orquesta-test-yaql-time-functions"
- expected_output = {
- 'time': '3h25m45s'
- }
+ expected_output = {"time": "3h25m45s"}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_time_functions_in_jinja(self):
- wf_name = 'examples.orquesta-test-jinja-time-functions'
+ wf_name = "examples.orquesta-test-jinja-time-functions"
- expected_output = {
- 'time': '3h25m45s'
- }
+ expected_output = {"time": "3h25m45s"}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_version_functions_in_yaql(self):
- wf_name = 'examples.orquesta-test-yaql-version-functions'
+ wf_name = "examples.orquesta-test-yaql-version-functions"
expected_output = {
- 'compare_equal': 0,
- 'compare_more_than': -1,
- 'compare_less_than': 1,
- 'equal': True,
- 'more_than': False,
- 'less_than': False,
- 'match': True,
- 'bump_major': '1.0.0',
- 'bump_minor': '0.11.0',
- 'bump_patch': '0.10.1',
- 'strip_patch': '0.10'
+ "compare_equal": 0,
+ "compare_more_than": -1,
+ "compare_less_than": 1,
+ "equal": True,
+ "more_than": False,
+ "less_than": False,
+ "match": True,
+ "bump_major": "1.0.0",
+ "bump_minor": "0.11.0",
+ "bump_patch": "0.10.1",
+ "strip_patch": "0.10",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_version_functions_in_jinja(self):
- wf_name = 'examples.orquesta-test-jinja-version-functions'
+ wf_name = "examples.orquesta-test-jinja-version-functions"
expected_output = {
- 'compare_equal': 0,
- 'compare_more_than': -1,
- 'compare_less_than': 1,
- 'equal': True,
- 'more_than': False,
- 'less_than': False,
- 'match': True,
- 'bump_major': '1.0.0',
- 'bump_minor': '0.11.0',
- 'bump_patch': '0.10.1',
- 'strip_patch': '0.10'
+ "compare_equal": 0,
+ "compare_more_than": -1,
+ "compare_less_than": 1,
+ "equal": True,
+ "more_than": False,
+ "less_than": False,
+ "match": True,
+ "bump_major": "1.0.0",
+ "bump_minor": "0.11.0",
+ "bump_patch": "0.10.1",
+ "strip_patch": "0.10",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
diff --git a/st2tests/integration/orquesta/test_wiring_functions_st2kv.py b/st2tests/integration/orquesta/test_wiring_functions_st2kv.py
index d02b8594c4..e4384c72cd 100644
--- a/st2tests/integration/orquesta/test_wiring_functions_st2kv.py
+++ b/st2tests/integration/orquesta/test_wiring_functions_st2kv.py
@@ -21,90 +21,76 @@
class DatastoreFunctionTest(base.TestWorkflowExecution):
@classmethod
- def set_kvp(cls, name, value, scope='system', secret=False):
+ def set_kvp(cls, name, value, scope="system", secret=False):
kvp = models.KeyValuePair(
- id=name,
- name=name,
- value=value,
- scope=scope,
- secret=secret
+ id=name, name=name, value=value, scope=scope, secret=secret
)
cls.st2client.keys.update(kvp)
@classmethod
- def del_kvp(cls, name, scope='system'):
- kvp = models.KeyValuePair(
- id=name,
- name=name,
- scope=scope
- )
+ def del_kvp(cls, name, scope="system"):
+ kvp = models.KeyValuePair(id=name, name=name, scope=scope)
cls.st2client.keys.delete(kvp)
def test_st2kv_system_scope(self):
- key = 'lakshmi'
- value = 'kanahansnasnasdlsajks'
+ key = "lakshmi"
+ value = "kanahansnasnasdlsajks"
self.set_kvp(key, value)
- wf_name = 'examples.orquesta-st2kv'
- wf_input = {'key_name': 'system.%s' % key}
+ wf_name = "examples.orquesta-st2kv"
+ wf_input = {"key_name": "system.%s" % key}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value', output.result['output'])
- self.assertEqual(value, output.result['output']['value'])
+ self.assertIn("output", output.result)
+ self.assertIn("value", output.result["output"])
+ self.assertEqual(value, output.result["output"]["value"])
self.del_kvp(key)
def test_st2kv_user_scope(self):
- key = 'winson'
- value = 'SoDiamondEng'
+ key = "winson"
+ value = "SoDiamondEng"
- self.set_kvp(key, value, 'user')
- wf_name = 'examples.orquesta-st2kv'
- wf_input = {'key_name': key}
+ self.set_kvp(key, value, "user")
+ wf_name = "examples.orquesta-st2kv"
+ wf_input = {"key_name": key}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value', output.result['output'])
- self.assertEqual(value, output.result['output']['value'])
+ self.assertIn("output", output.result)
+ self.assertIn("value", output.result["output"])
+ self.assertEqual(value, output.result["output"]["value"])
# self.del_kvp(key)
def test_st2kv_decrypt(self):
- key = 'kami'
- value = 'eggplant'
+ key = "kami"
+ value = "eggplant"
self.set_kvp(key, value, secret=True)
- wf_name = 'examples.orquesta-st2kv'
- wf_input = {
- 'key_name': 'system.%s' % key,
- 'decrypt': True
- }
+ wf_name = "examples.orquesta-st2kv"
+ wf_input = {"key_name": "system.%s" % key, "decrypt": True}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value', output.result['output'])
- self.assertEqual(value, output.result['output']['value'])
+ self.assertIn("output", output.result)
+ self.assertIn("value", output.result["output"])
+ self.assertEqual(value, output.result["output"]["value"])
self.del_kvp(key)
def test_st2kv_nonexistent(self):
- key = 'matt'
+ key = "matt"
- wf_name = 'examples.orquesta-st2kv'
- wf_input = {
- 'key_name': 'system.%s' % key,
- 'decrypt': True
- }
+ wf_name = "examples.orquesta-st2kv"
+ wf_input = {"key_name": "system.%s" % key, "decrypt": True}
execution = self._execute_workflow(wf_name, wf_input)
@@ -112,69 +98,71 @@ def test_st2kv_nonexistent(self):
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_FAILED)
- expected_error = 'The key "%s" does not exist in the StackStorm datastore.' % key
+ expected_error = (
+ 'The key "%s" does not exist in the StackStorm datastore.' % key
+ )
- self.assertIn(expected_error, output.result['errors'][0]['message'])
+ self.assertIn(expected_error, output.result["errors"][0]["message"])
def test_st2kv_default_value(self):
- key = 'matt'
+ key = "matt"
- wf_name = 'examples.orquesta-st2kv-default'
- wf_input = {
- 'key_name': 'system.%s' % key,
- 'decrypt': True,
- 'default': 'stone'
- }
+ wf_name = "examples.orquesta-st2kv-default"
+ wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": "stone"}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value_from_yaql', output.result['output'])
- self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql'])
- self.assertIn('value_from_jinja', output.result['output'])
- self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja'])
+ self.assertIn("output", output.result)
+ self.assertIn("value_from_yaql", output.result["output"])
+ self.assertEqual(
+ wf_input["default"], output.result["output"]["value_from_yaql"]
+ )
+ self.assertIn("value_from_jinja", output.result["output"])
+ self.assertEqual(
+ wf_input["default"], output.result["output"]["value_from_jinja"]
+ )
def test_st2kv_default_value_with_empty_string(self):
- key = 'matt'
+ key = "matt"
- wf_name = 'examples.orquesta-st2kv-default'
- wf_input = {
- 'key_name': 'system.%s' % key,
- 'decrypt': True,
- 'default': ''
- }
+ wf_name = "examples.orquesta-st2kv-default"
+ wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": ""}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value_from_yaql', output.result['output'])
- self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql'])
- self.assertIn('value_from_jinja', output.result['output'])
- self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja'])
+ self.assertIn("output", output.result)
+ self.assertIn("value_from_yaql", output.result["output"])
+ self.assertEqual(
+ wf_input["default"], output.result["output"]["value_from_yaql"]
+ )
+ self.assertIn("value_from_jinja", output.result["output"])
+ self.assertEqual(
+ wf_input["default"], output.result["output"]["value_from_jinja"]
+ )
def test_st2kv_default_value_with_null(self):
- key = 'matt'
+ key = "matt"
- wf_name = 'examples.orquesta-st2kv-default'
- wf_input = {
- 'key_name': 'system.%s' % key,
- 'decrypt': True,
- 'default': None
- }
+ wf_name = "examples.orquesta-st2kv-default"
+ wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": None}
execution = self._execute_workflow(wf_name, wf_input)
output = self._wait_for_completion(execution)
self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- self.assertIn('output', output.result)
- self.assertIn('value_from_yaql', output.result['output'])
- self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql'])
- self.assertIn('value_from_jinja', output.result['output'])
- self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja'])
+ self.assertIn("output", output.result)
+ self.assertIn("value_from_yaql", output.result["output"])
+ self.assertEqual(
+ wf_input["default"], output.result["output"]["value_from_yaql"]
+ )
+ self.assertIn("value_from_jinja", output.result["output"])
+ self.assertEqual(
+ wf_input["default"], output.result["output"]["value_from_jinja"]
+ )
diff --git a/st2tests/integration/orquesta/test_wiring_functions_task.py b/st2tests/integration/orquesta/test_wiring_functions_task.py
index 990b86752c..35d002c885 100644
--- a/st2tests/integration/orquesta/test_wiring_functions_task.py
+++ b/st2tests/integration/orquesta/test_wiring_functions_task.py
@@ -21,91 +21,94 @@
class FunctionsWiringTest(base.TestWorkflowExecution):
-
def test_task_functions_in_yaql(self):
- wf_name = 'examples.orquesta-test-yaql-task-functions'
+ wf_name = "examples.orquesta-test-yaql-task-functions"
expected_output = {
- 'last_task4_result': 'False',
- 'task9__1__parent': 'task8__1',
- 'task9__2__parent': 'task8__2',
- 'that_task_by_name': 'task1',
- 'this_task_by_name': 'task1',
- 'this_task_no_arg': 'task1'
+ "last_task4_result": "False",
+ "task9__1__parent": "task8__1",
+ "task9__2__parent": "task8__2",
+ "that_task_by_name": "task1",
+ "this_task_by_name": "task1",
+ "this_task_no_arg": "task1",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_task_functions_in_jinja(self):
- wf_name = 'examples.orquesta-test-jinja-task-functions'
+ wf_name = "examples.orquesta-test-jinja-task-functions"
expected_output = {
- 'last_task4_result': 'False',
- 'task9__1__parent': 'task8__1',
- 'task9__2__parent': 'task8__2',
- 'that_task_by_name': 'task1',
- 'this_task_by_name': 'task1',
- 'this_task_no_arg': 'task1'
+ "last_task4_result": "False",
+ "task9__1__parent": "task8__1",
+ "task9__2__parent": "task8__2",
+ "that_task_by_name": "task1",
+ "this_task_by_name": "task1",
+ "this_task_no_arg": "task1",
}
- expected_result = {'output': expected_output}
+ expected_result = {"output": expected_output}
- self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result)
+ self._execute_workflow(
+ wf_name, execute_async=False, expected_result=expected_result
+ )
def test_task_nonexistent_in_yaql(self):
- wf_name = 'examples.orquesta-test-yaql-task-nonexistent'
+ wf_name = "examples.orquesta-test-yaql-task-nonexistent"
expected_output = None
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'YaqlEvaluationException: Unable to evaluate expression '
- '\'<% task("task0") %>\'. ExpressionEvaluationException: '
+ "type": "error",
+ "message": (
+ "YaqlEvaluationException: Unable to evaluate expression "
+ "'<% task(\"task0\") %>'. ExpressionEvaluationException: "
'Unable to find task execution for "task0".'
),
- 'task_transition_id': 'continue__t0',
- 'task_id': 'task1',
- 'route': 0
+ "task_transition_id": "continue__t0",
+ "task_id": "task1",
+ "route": 0,
}
]
- expected_result = {'output': expected_output, 'errors': expected_errors}
+ expected_result = {"output": expected_output, "errors": expected_errors}
self._execute_workflow(
wf_name,
execute_async=False,
expected_status=action_constants.LIVEACTION_STATUS_FAILED,
- expected_result=expected_result
+ expected_result=expected_result,
)
def test_task_nonexistent_in_jinja(self):
- wf_name = 'examples.orquesta-test-jinja-task-nonexistent'
+ wf_name = "examples.orquesta-test-jinja-task-nonexistent"
expected_output = None
expected_errors = [
{
- 'type': 'error',
- 'message': (
- 'JinjaEvaluationException: Unable to evaluate expression '
- '\'{{ task("task0") }}\'. ExpressionEvaluationException: '
+ "type": "error",
+ "message": (
+ "JinjaEvaluationException: Unable to evaluate expression "
+ "'{{ task(\"task0\") }}'. ExpressionEvaluationException: "
'Unable to find task execution for "task0".'
),
- 'task_transition_id': 'continue__t0',
- 'task_id': 'task1',
- 'route': 0
+ "task_transition_id": "continue__t0",
+ "task_id": "task1",
+ "route": 0,
}
]
- expected_result = {'output': expected_output, 'errors': expected_errors}
+ expected_result = {"output": expected_output, "errors": expected_errors}
self._execute_workflow(
wf_name,
execute_async=False,
expected_status=action_constants.LIVEACTION_STATUS_FAILED,
- expected_result=expected_result
+ expected_result=expected_result,
)
diff --git a/st2tests/integration/orquesta/test_wiring_inquiry.py b/st2tests/integration/orquesta/test_wiring_inquiry.py
index 71d0ed9e96..688929c041 100644
--- a/st2tests/integration/orquesta/test_wiring_inquiry.py
+++ b/st2tests/integration/orquesta/test_wiring_inquiry.py
@@ -23,75 +23,88 @@
class InquiryWiringTest(base.TestWorkflowExecution):
-
def test_basic_inquiry(self):
# Launch the workflow. The workflow will pause at the pending task.
- ex = self._execute_workflow('examples.orquesta-ask-basic')
+ ex = self._execute_workflow("examples.orquesta-ask-basic")
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED)
# Respond to the inquiry.
- ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PENDING)
- self.st2client.inquiries.respond(ac_exs[0].id, {'approved': True})
+ ac_exs = self._wait_for_task(
+ ex, "get_approval", ac_const.LIVEACTION_STATUS_PENDING
+ )
+ self.st2client.inquiries.respond(ac_exs[0].id, {"approved": True})
# Wait for completion.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_consecutive_inquiries(self):
# Launch the workflow. The workflow will pause at the pending task.
- ex = self._execute_workflow('examples.orquesta-ask-consecutive')
+ ex = self._execute_workflow("examples.orquesta-ask-consecutive")
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED)
# Respond to the first inquiry.
- t1_ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PENDING)
- self.st2client.inquiries.respond(t1_ac_exs[0].id, {'approved': True})
+ t1_ac_exs = self._wait_for_task(
+ ex, "get_approval", ac_const.LIVEACTION_STATUS_PENDING
+ )
+ self.st2client.inquiries.respond(t1_ac_exs[0].id, {"approved": True})
# Wait for the workflow to pause again.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED)
# Respond to the second inquiry.
- t2_ac_exs = self._wait_for_task(ex, 'get_confirmation', ac_const.LIVEACTION_STATUS_PENDING)
- self.st2client.inquiries.respond(t2_ac_exs[0].id, {'approved': True})
+ t2_ac_exs = self._wait_for_task(
+ ex, "get_confirmation", ac_const.LIVEACTION_STATUS_PENDING
+ )
+ self.st2client.inquiries.respond(t2_ac_exs[0].id, {"approved": True})
# Wait for completion.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_parallel_inquiries(self):
# Launch the workflow. The workflow will pause at the pending task.
- ex = self._execute_workflow('examples.orquesta-ask-parallel')
+ ex = self._execute_workflow("examples.orquesta-ask-parallel")
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED)
# Respond to the first inquiry.
- t1_ac_exs = self._wait_for_task(ex, 'ask_jack', ac_const.LIVEACTION_STATUS_PENDING)
- self.st2client.inquiries.respond(t1_ac_exs[0].id, {'approved': True})
- t1_ac_exs = self._wait_for_task(ex, 'ask_jack', ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ t1_ac_exs = self._wait_for_task(
+ ex, "ask_jack", ac_const.LIVEACTION_STATUS_PENDING
+ )
+ self.st2client.inquiries.respond(t1_ac_exs[0].id, {"approved": True})
+ t1_ac_exs = self._wait_for_task(
+ ex, "ask_jack", ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
# Allow some time for the first inquiry to get processed.
eventlet.sleep(1)
# Respond to the second inquiry.
- t2_ac_exs = self._wait_for_task(ex, 'ask_jill', ac_const.LIVEACTION_STATUS_PENDING)
- self.st2client.inquiries.respond(t2_ac_exs[0].id, {'approved': True})
- t2_ac_exs = self._wait_for_task(ex, 'ask_jill', ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ t2_ac_exs = self._wait_for_task(
+ ex, "ask_jill", ac_const.LIVEACTION_STATUS_PENDING
+ )
+ self.st2client.inquiries.respond(t2_ac_exs[0].id, {"approved": True})
+ t2_ac_exs = self._wait_for_task(
+ ex, "ask_jill", ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
# Wait for completion.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_nested_inquiry(self):
# Launch the workflow. The workflow will pause at the pending task.
- ex = self._execute_workflow('examples.orquesta-ask-nested')
+ ex = self._execute_workflow("examples.orquesta-ask-nested")
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED)
# Get the action execution of the subworkflow.
- ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PAUSED)
+ ac_exs = self._wait_for_task(
+ ex, "get_approval", ac_const.LIVEACTION_STATUS_PAUSED
+ )
# Respond to the inquiry in the subworkflow.
t2_t2_ac_exs = self._wait_for_task(
- ac_exs[0],
- 'get_approval',
- ac_const.LIVEACTION_STATUS_PENDING
+ ac_exs[0], "get_approval", ac_const.LIVEACTION_STATUS_PENDING
)
- self.st2client.inquiries.respond(t2_t2_ac_exs[0].id, {'approved': True})
+ self.st2client.inquiries.respond(t2_t2_ac_exs[0].id, {"approved": True})
# Wait for completion.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
diff --git a/st2tests/integration/orquesta/test_wiring_pause_and_resume.py b/st2tests/integration/orquesta/test_wiring_pause_and_resume.py
index 52eca1490f..9779ee26b8 100644
--- a/st2tests/integration/orquesta/test_wiring_pause_and_resume.py
+++ b/st2tests/integration/orquesta/test_wiring_pause_and_resume.py
@@ -22,7 +22,9 @@
from st2common.constants import action as ac_const
-class PauseResumeWiringTest(base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin):
+class PauseResumeWiringTest(
+ base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin
+):
temp_file_path_x = None
temp_file_path_y = None
@@ -47,9 +49,9 @@ def test_pause_and_resume(self):
self.assertTrue(os.path.exists(path))
# Launch the workflow. The workflow will wait for the temp file to be deleted.
- params = {'tempfile': path}
- ex = self._execute_workflow('examples.orquesta-test-pause', params)
- self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
+ params = {"tempfile": path}
+ ex = self._execute_workflow("examples.orquesta-test-pause", params)
+ self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the workflow before the temp file is deleted. The workflow will be paused
# but task1 will still be running to allow for graceful exit.
@@ -77,10 +79,10 @@ def test_pause_and_resume_cascade_to_subworkflow(self):
self.assertTrue(os.path.exists(path))
# Launch the workflow. The workflow will wait for the temp file to be deleted.
- params = {'tempfile': path}
- ex = self._execute_workflow('examples.orquesta-test-pause-subworkflow', params)
+ params = {"tempfile": path}
+ ex = self._execute_workflow("examples.orquesta-test-pause-subworkflow", params)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING)
- tk_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
+ tk_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the workflow before the temp file is deleted. The workflow will be paused
# but task1 will still be running to allow for graceful exit.
@@ -113,11 +115,11 @@ def test_pause_and_resume_cascade_to_subworkflows(self):
self.assertTrue(os.path.exists(path2))
# Launch the workflow. The workflow will wait for the temp files to be deleted.
- params = {'file1': path1, 'file2': path2}
- ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params)
+ params = {"file1": path1, "file2": path2}
+ ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING)
- tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
- tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING)
+ tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
+ tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the workflow before the temp files are deleted. The workflow will be paused
# but task1 will still be running to allow for graceful exit.
@@ -150,8 +152,12 @@ def test_pause_and_resume_cascade_to_subworkflows(self):
ex = self.st2client.executions.resume(ex.id)
# Wait for completion.
- tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex = self._wait_for_state(
+ tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
+ tk2_ac_ex = self._wait_for_state(
+ tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_pause_and_resume_cascade_from_subworkflow(self):
@@ -160,10 +166,10 @@ def test_pause_and_resume_cascade_from_subworkflow(self):
self.assertTrue(os.path.exists(path))
# Launch the workflow. The workflow will wait for the temp file to be deleted.
- params = {'tempfile': path}
- ex = self._execute_workflow('examples.orquesta-test-pause-subworkflow', params)
+ params = {"tempfile": path}
+ ex = self._execute_workflow("examples.orquesta-test-pause-subworkflow", params)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING)
- tk_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
+ tk_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the subworkflow before the temp file is deleted. The task will be
# paused but the workflow will still be running.
@@ -188,7 +194,9 @@ def test_pause_and_resume_cascade_from_subworkflow(self):
tk_ac_ex = self._wait_for_state(tk_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_paused(self):
+ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_paused(
+ self,
+ ):
# Temp files are created during test setup. Ensure the temp files exist.
path1 = self.temp_file_path_x
self.assertTrue(os.path.exists(path1))
@@ -196,11 +204,11 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_pau
self.assertTrue(os.path.exists(path2))
# Launch the workflow. The workflow will wait for the temp files to be deleted.
- params = {'file1': path1, 'file2': path2}
- ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params)
+ params = {"file1": path1, "file2": path2}
+ ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING)
- tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
- tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING)
+ tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
+ tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the subworkflow before the temp file is deleted. The task will be
# paused but the workflow and the other subworkflow will still be running.
@@ -228,17 +236,25 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_pau
# The workflow will now be paused because no other task is running.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED)
tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_PAUSED)
- tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ tk2_ac_ex = self._wait_for_state(
+ tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
# Resume the subworkflow.
tk1_ac_ex = self.st2client.executions.resume(tk1_ac_ex.id)
# Wait for completion.
- tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex = self._wait_for_state(
+ tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
+ tk2_ac_ex = self._wait_for_state(
+ tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_running(self):
+ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_running(
+ self,
+ ):
# Temp files are created during test setup. Ensure the temp files exist.
path1 = self.temp_file_path_x
self.assertTrue(os.path.exists(path1))
@@ -246,11 +262,11 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru
self.assertTrue(os.path.exists(path2))
# Launch the workflow. The workflow will wait for the temp files to be deleted.
- params = {'file1': path1, 'file2': path2}
- ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params)
+ params = {"file1": path1, "file2": path2}
+ ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING)
- tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
- tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING)
+ tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
+ tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the subworkflow before the temp file is deleted. The task will be
# paused but the workflow and the other subworkflow will still be running.
@@ -276,7 +292,9 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru
# The subworkflow will succeed while the other subworkflow is still running.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING)
- tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex = self._wait_for_state(
+ tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_RUNNING)
# Delete the temporary file for the other subworkflow.
@@ -284,8 +302,12 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru
self.assertFalse(os.path.exists(path2))
# Wait for completion.
- tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex = self._wait_for_state(
+ tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
+ tk2_ac_ex = self._wait_for_state(
+ tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self):
@@ -296,11 +318,11 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self):
self.assertTrue(os.path.exists(path2))
# Launch the workflow. The workflow will wait for the temp files to be deleted.
- params = {'file1': path1, 'file2': path2}
- ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params)
+ params = {"file1": path1, "file2": path2}
+ ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING)
- tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
- tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING)
+ tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
+ tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the subworkflow before the temp file is deleted. The task will be
# paused but the workflow and the other subworkflow will still be running.
@@ -336,7 +358,9 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self):
tk1_ac_ex = self.st2client.executions.resume(tk1_ac_ex.id)
# The subworkflow will succeed while the other subworkflow is still paused.
- tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex = self._wait_for_state(
+ tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_PAUSED)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED)
@@ -344,8 +368,12 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self):
tk2_ac_ex = self.st2client.executions.resume(tk2_ac_ex.id)
# Wait for completion.
- tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex = self._wait_for_state(
+ tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
+ tk2_ac_ex = self._wait_for_state(
+ tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self):
@@ -356,11 +384,11 @@ def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self):
self.assertTrue(os.path.exists(path2))
# Launch the workflow. The workflow will wait for the temp files to be deleted.
- params = {'file1': path1, 'file2': path2}
- ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params)
+ params = {"file1": path1, "file2": path2}
+ ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params)
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING)
- tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING)
- tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING)
+ tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING)
+ tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING)
# Pause the subworkflow before the temp file is deleted. The task will be
# paused but the workflow and the other subworkflow will still be running.
@@ -396,6 +424,10 @@ def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self):
ex = self.st2client.executions.resume(ex.id)
# Wait for completion.
- tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
- tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
+ tk1_ac_ex = self._wait_for_state(
+ tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
+ tk2_ac_ex = self._wait_for_state(
+ tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED
+ )
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
diff --git a/st2tests/integration/orquesta/test_wiring_rerun.py b/st2tests/integration/orquesta/test_wiring_rerun.py
index b7a6de0efe..2fafee76e4 100644
--- a/st2tests/integration/orquesta/test_wiring_rerun.py
+++ b/st2tests/integration/orquesta/test_wiring_rerun.py
@@ -43,106 +43,104 @@ def tearDown(self):
def test_rerun_workflow(self):
path = self.temp_dir_path
- with open(path, 'w') as f:
- f.write('1')
+ with open(path, "w") as f:
+ f.write("1")
- params = {'tempfile': path}
- ex = self._execute_workflow('examples.orquesta-test-rerun', params)
+ params = {"tempfile": path}
+ ex = self._execute_workflow("examples.orquesta-test-rerun", params)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED)
orig_st2_ex_id = ex.id
- orig_wf_ex_id = ex.context['workflow_execution']
+ orig_wf_ex_id = ex.context["workflow_execution"]
- with open(path, 'w') as f:
- f.write('0')
+ with open(path, "w") as f:
+ f.write("0")
ex = self.st2client.executions.re_run(orig_st2_ex_id)
self.assertNotEqual(ex.id, orig_st2_ex_id)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertNotEqual(ex.context['workflow_execution'], orig_wf_ex_id)
+ self.assertNotEqual(ex.context["workflow_execution"], orig_wf_ex_id)
def test_rerun_task(self):
path = self.temp_dir_path
- with open(path, 'w') as f:
- f.write('1')
+ with open(path, "w") as f:
+ f.write("1")
- params = {'tempfile': path}
- ex = self._execute_workflow('examples.orquesta-test-rerun', params)
+ params = {"tempfile": path}
+ ex = self._execute_workflow("examples.orquesta-test-rerun", params)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED)
orig_st2_ex_id = ex.id
- orig_wf_ex_id = ex.context['workflow_execution']
+ orig_wf_ex_id = ex.context["workflow_execution"]
- with open(path, 'w') as f:
- f.write('0')
+ with open(path, "w") as f:
+ f.write("0")
- ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task2'])
+ ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task2"])
self.assertNotEqual(ex.id, orig_st2_ex_id)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id)
+ self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id)
def test_rerun_task_of_workflow_already_succeeded(self):
path = self.temp_dir_path
- with open(path, 'w') as f:
- f.write('0')
+ with open(path, "w") as f:
+ f.write("0")
- params = {'tempfile': path}
- ex = self._execute_workflow('examples.orquesta-test-rerun', params)
+ params = {"tempfile": path}
+ ex = self._execute_workflow("examples.orquesta-test-rerun", params)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED)
orig_st2_ex_id = ex.id
- orig_wf_ex_id = ex.context['workflow_execution']
+ orig_wf_ex_id = ex.context["workflow_execution"]
- ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task2'])
+ ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task2"])
self.assertNotEqual(ex.id, orig_st2_ex_id)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id)
+ self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id)
def test_rerun_and_reset_with_items_task(self):
path = self.temp_dir_path
- with open(path, 'w') as f:
- f.write('1')
+ with open(path, "w") as f:
+ f.write("1")
- params = {'tempfile': path}
- ex = self._execute_workflow('examples.orquesta-test-rerun-with-items', params)
+ params = {"tempfile": path}
+ ex = self._execute_workflow("examples.orquesta-test-rerun-with-items", params)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED)
orig_st2_ex_id = ex.id
- orig_wf_ex_id = ex.context['workflow_execution']
+ orig_wf_ex_id = ex.context["workflow_execution"]
- with open(path, 'w') as f:
- f.write('0')
+ with open(path, "w") as f:
+ f.write("0")
- ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task1'])
+ ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task1"])
self.assertNotEqual(ex.id, orig_st2_ex_id)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id)
+ self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id)
- children = self.st2client.executions.get_property(ex.id, 'children')
+ children = self.st2client.executions.get_property(ex.id, "children")
self.assertEqual(len(children), 4)
def test_rerun_and_resume_with_items_task(self):
path = self.temp_dir_path
- with open(path, 'w') as f:
- f.write('1')
+ with open(path, "w") as f:
+ f.write("1")
- params = {'tempfile': path}
- ex = self._execute_workflow('examples.orquesta-test-rerun-with-items', params)
+ params = {"tempfile": path}
+ ex = self._execute_workflow("examples.orquesta-test-rerun-with-items", params)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED)
orig_st2_ex_id = ex.id
- orig_wf_ex_id = ex.context['workflow_execution']
+ orig_wf_ex_id = ex.context["workflow_execution"]
- with open(path, 'w') as f:
- f.write('0')
+ with open(path, "w") as f:
+ f.write("0")
ex = self.st2client.executions.re_run(
- orig_st2_ex_id,
- tasks=['task1'],
- no_reset=['task1']
+ orig_st2_ex_id, tasks=["task1"], no_reset=["task1"]
)
self.assertNotEqual(ex.id, orig_st2_ex_id)
ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED)
- self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id)
+ self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id)
- children = self.st2client.executions.get_property(ex.id, 'children')
+ children = self.st2client.executions.get_property(ex.id, "children")
self.assertEqual(len(children), 2)
diff --git a/st2tests/integration/orquesta/test_wiring_task_retry.py b/st2tests/integration/orquesta/test_wiring_task_retry.py
index c8d3bd1889..7bb7f3f258 100644
--- a/st2tests/integration/orquesta/test_wiring_task_retry.py
+++ b/st2tests/integration/orquesta/test_wiring_task_retry.py
@@ -23,9 +23,8 @@
class TaskRetryWiringTest(base.TestWorkflowExecution):
-
def test_task_retry(self):
- wf_name = 'examples.orquesta-task-retry'
+ wf_name = "examples.orquesta-task-retry"
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
@@ -34,14 +33,15 @@ def test_task_retry(self):
# Assert there are retries for the task.
task_exs = [
- task_ex for task_ex in self._get_children(ex)
- if task_ex.context.get('orquesta', {}).get('task_name', '') == 'check'
+ task_ex
+ for task_ex in self._get_children(ex)
+ if task_ex.context.get("orquesta", {}).get("task_name", "") == "check"
]
self.assertGreater(len(task_exs), 1)
def test_task_retry_exhausted(self):
- wf_name = 'examples.orquesta-task-retry-exhausted'
+ wf_name = "examples.orquesta-task-retry-exhausted"
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
@@ -51,16 +51,18 @@ def test_task_retry_exhausted(self):
# Assert the task has exhausted the number of retries
task_exs = [
- task_ex for task_ex in self._get_children(ex)
- if task_ex.context.get('orquesta', {}).get('task_name', '') == 'check'
+ task_ex
+ for task_ex in self._get_children(ex)
+ if task_ex.context.get("orquesta", {}).get("task_name", "") == "check"
]
- self.assertListEqual(['failed'] * 3, [task_ex.status for task_ex in task_exs])
+ self.assertListEqual(["failed"] * 3, [task_ex.status for task_ex in task_exs])
# Assert the task following the retry task is not run.
task_exs = [
- task_ex for task_ex in self._get_children(ex)
- if task_ex.context.get('orquesta', {}).get('task_name', '') == 'delete'
+ task_ex
+ for task_ex in self._get_children(ex)
+ if task_ex.context.get("orquesta", {}).get("task_name", "") == "delete"
]
self.assertEqual(len(task_exs), 0)
diff --git a/st2tests/integration/orquesta/test_wiring_with_items.py b/st2tests/integration/orquesta/test_wiring_with_items.py
index b80e04e702..0bf83f1bf1 100644
--- a/st2tests/integration/orquesta/test_wiring_with_items.py
+++ b/st2tests/integration/orquesta/test_wiring_with_items.py
@@ -40,14 +40,14 @@ def tearDown(self):
super(WithItemsWiringTest, self).tearDown()
def test_with_items(self):
- wf_name = 'examples.orquesta-with-items'
+ wf_name = "examples.orquesta-with-items"
- members = ['Lakshmi', 'Lindsay', 'Tomaz', 'Matt', 'Drew']
- wf_input = {'members': members}
+ members = ["Lakshmi", "Lindsay", "Tomaz", "Matt", "Drew"]
+ wf_input = {"members": members}
- message = '%s, resistance is futile!'
- expected_output = {'items': [message % i for i in members]}
- expected_result = {'output': expected_output}
+ message = "%s, resistance is futile!"
+ expected_output = {"items": [message % i for i in members]}
+ expected_result = {"output": expected_output}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_completion(ex)
@@ -56,17 +56,17 @@ def test_with_items(self):
self.assertDictEqual(ex.result, expected_result)
def test_with_items_failure(self):
- wf_name = 'examples.orquesta-test-with-items-failure'
+ wf_name = "examples.orquesta-test-with-items-failure"
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
- self._wait_for_task(ex, 'task1', num_task_exs=10)
+ self._wait_for_task(ex, "task1", num_task_exs=10)
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED)
def test_with_items_concurrency(self):
- wf_name = 'examples.orquesta-test-with-items'
+ wf_name = "examples.orquesta-test-with-items"
concurrency = 2
num_items = 5
@@ -74,22 +74,22 @@ def test_with_items_concurrency(self):
for i in range(0, num_items):
_, f = tempfile.mkstemp()
- os.chmod(f, 0o755) # nosec
+ os.chmod(f, 0o755) # nosec
self.tempfiles.append(f)
- wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency}
+ wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING])
- self._wait_for_task(ex, 'task1', num_task_exs=2)
+ self._wait_for_task(ex, "task1", num_task_exs=2)
os.remove(self.tempfiles[0])
os.remove(self.tempfiles[1])
- self._wait_for_task(ex, 'task1', num_task_exs=4)
+ self._wait_for_task(ex, "task1", num_task_exs=4)
os.remove(self.tempfiles[2])
os.remove(self.tempfiles[3])
- self._wait_for_task(ex, 'task1', num_task_exs=5)
+ self._wait_for_task(ex, "task1", num_task_exs=5)
os.remove(self.tempfiles[4])
ex = self._wait_for_completion(ex)
@@ -97,7 +97,7 @@ def test_with_items_concurrency(self):
self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_with_items_cancellation(self):
- wf_name = 'examples.orquesta-test-with-items'
+ wf_name = "examples.orquesta-test-with-items"
concurrency = 2
num_items = 2
@@ -105,19 +105,16 @@ def test_with_items_cancellation(self):
for i in range(0, num_items):
_, f = tempfile.mkstemp()
- os.chmod(f, 0o755) # nosec
+ os.chmod(f, 0o755) # nosec
self.tempfiles.append(f)
- wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency}
+ wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING])
# Wait for action executions to run.
self._wait_for_task(
- ex,
- 'task1',
- ac_const.LIVEACTION_STATUS_RUNNING,
- num_task_exs=concurrency
+ ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING, num_task_exs=concurrency
)
# Cancel the workflow execution.
@@ -133,17 +130,14 @@ def test_with_items_cancellation(self):
# Task is completed successfully for graceful exit.
self._wait_for_task(
- ex,
- 'task1',
- ac_const.LIVEACTION_STATUS_SUCCEEDED,
- num_task_exs=concurrency
+ ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency
)
# Wait for the ex to be canceled.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
def test_with_items_concurrency_cancellation(self):
- wf_name = 'examples.orquesta-test-with-items'
+ wf_name = "examples.orquesta-test-with-items"
concurrency = 2
num_items = 4
@@ -151,19 +145,16 @@ def test_with_items_concurrency_cancellation(self):
for i in range(0, num_items):
_, f = tempfile.mkstemp()
- os.chmod(f, 0o755) # nosec
+ os.chmod(f, 0o755) # nosec
self.tempfiles.append(f)
- wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency}
+ wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING])
# Wait for action executions to run.
self._wait_for_task(
- ex,
- 'task1',
- ac_const.LIVEACTION_STATUS_RUNNING,
- num_task_exs=concurrency
+ ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING, num_task_exs=concurrency
)
# Cancel the workflow execution.
@@ -180,27 +171,24 @@ def test_with_items_concurrency_cancellation(self):
# Task is completed successfully for graceful exit.
self._wait_for_task(
- ex,
- 'task1',
- ac_const.LIVEACTION_STATUS_SUCCEEDED,
- num_task_exs=concurrency
+ ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency
)
# Wait for the ex to be canceled.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED)
def test_with_items_pause_and_resume(self):
- wf_name = 'examples.orquesta-test-with-items'
+ wf_name = "examples.orquesta-test-with-items"
num_items = 2
self.tempfiles = []
for i in range(0, num_items):
_, f = tempfile.mkstemp()
- os.chmod(f, 0o755) # nosec
+ os.chmod(f, 0o755) # nosec
self.tempfiles.append(f)
- wf_input = {'tempfiles': self.tempfiles}
+ wf_input = {"tempfiles": self.tempfiles}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING])
@@ -217,10 +205,7 @@ def test_with_items_pause_and_resume(self):
# Wait for action executions for task to succeed.
self._wait_for_task(
- ex,
- 'task1',
- ac_const.LIVEACTION_STATUS_SUCCEEDED,
- num_task_exs=num_items
+ ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=num_items
)
# Wait for the workflow execution to pause.
@@ -233,7 +218,7 @@ def test_with_items_pause_and_resume(self):
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_with_items_concurrency_pause_and_resume(self):
- wf_name = 'examples.orquesta-test-with-items'
+ wf_name = "examples.orquesta-test-with-items"
concurrency = 2
num_items = 4
@@ -241,10 +226,10 @@ def test_with_items_concurrency_pause_and_resume(self):
for i in range(0, num_items):
_, f = tempfile.mkstemp()
- os.chmod(f, 0o755) # nosec
+ os.chmod(f, 0o755) # nosec
self.tempfiles.append(f)
- wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency}
+ wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency}
ex = self._execute_workflow(wf_name, wf_input)
ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING])
@@ -261,10 +246,7 @@ def test_with_items_concurrency_pause_and_resume(self):
# Wait for action executions for task to succeed.
self._wait_for_task(
- ex,
- 'task1',
- ac_const.LIVEACTION_STATUS_SUCCEEDED,
- num_task_exs=concurrency
+ ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency
)
# Wait for the workflow execution to pause.
@@ -280,17 +262,14 @@ def test_with_items_concurrency_pause_and_resume(self):
# Wait for action executions for task to succeed.
self._wait_for_task(
- ex,
- 'task1',
- ac_const.LIVEACTION_STATUS_SUCCEEDED,
- num_task_exs=num_items
+ ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=num_items
)
# Wait for completion.
ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED)
def test_subworkflow_empty_with_items(self):
- wf_name = 'examples.orquesta-test-subworkflow-empty-with-items'
+ wf_name = "examples.orquesta-test-subworkflow-empty-with-items"
ex = self._execute_workflow(wf_name)
ex = self._wait_for_completion(ex)
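
All of the with-items tests reworked above share one setup: create executable temp files, hand them to the workflow together with a concurrency value, then poll via the base-class helpers. A minimal sketch of that shared pattern, assuming the same test base class and the helper signatures visible in the hunks (_execute_workflow, _wait_for_state, _wait_for_task; ac_const is the action constants module the test file already imports):

import os
import tempfile

def _start_with_items_workflow(self, concurrency=2, num_items=4):
    wf_name = "examples.orquesta-test-with-items"
    self.tempfiles = []
    for _ in range(num_items):
        _, path = tempfile.mkstemp()
        os.chmod(path, 0o755)  # nosec
        self.tempfiles.append(path)
    wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency}
    ex = self._execute_workflow(wf_name, wf_input)
    ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING])
    # At most `concurrency` item executions may be running at any point.
    self._wait_for_task(
        ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING, num_task_exs=concurrency
    )
    return ex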
diff --git a/st2tests/setup.py b/st2tests/setup.py
index 3d5947be04..f5e17bb3a3 100644
--- a/st2tests/setup.py
+++ b/st2tests/setup.py
@@ -23,10 +23,10 @@
from dist_utils import apply_vagrant_workaround
from dist_utils import get_version_string
-ST2_COMPONENT = 'st2tests'
+ST2_COMPONENT = "st2tests"
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt')
-INIT_FILE = os.path.join(BASE_DIR, 'st2tests/__init__.py')
+REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt")
+INIT_FILE = os.path.join(BASE_DIR, "st2tests/__init__.py")
install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE)
@@ -39,15 +39,17 @@
setup(
name=ST2_COMPONENT,
version=get_version_string(INIT_FILE),
- description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT),
- author='StackStorm',
- author_email='info@stackstorm.com',
- license='Apache License (2.0)',
- url='https://stackstorm.com/',
+ description="{} StackStorm event-driven automation platform component".format(
+ ST2_COMPONENT
+ ),
+ author="StackStorm",
+ author_email="info@stackstorm.com",
+ license="Apache License (2.0)",
+ url="https://stackstorm.com/",
install_requires=install_reqs,
dependency_links=dep_links,
test_suite=ST2_COMPONENT,
zip_safe=False,
include_package_data=True,
- packages=find_packages(exclude=['setuptools', 'tests'])
+ packages=find_packages(exclude=["setuptools", "tests"]),
)
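
The setup.py hunk above shows the two black behaviors that account for most of this diff: quote normalization to double quotes, and calls that exceed the line limit exploded one argument per line, ending with a trailing comma. That trailing comma is black's "magic trailing comma": once present, black keeps the call exploded even if it would fit on one line again. A minimal illustration, not taken from the repo:

from setuptools import find_packages

# Fits within the line limit, so black collapses it to one line:
packages = find_packages(exclude=["setuptools", "tests"])

# A pre-existing trailing comma pins the exploded form:
packages = find_packages(
    exclude=["setuptools", "tests"],
)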
diff --git a/st2tests/st2tests/__init__.py b/st2tests/st2tests/__init__.py
index 3813790cca..65ec867472 100644
--- a/st2tests/st2tests/__init__.py
+++ b/st2tests/st2tests/__init__.py
@@ -23,11 +23,11 @@
__all__ = [
- 'EventletTestCase',
- 'DbTestCase',
- 'ExecutionDbTestCase',
- 'DbModelTestCase',
- 'WorkflowTestCase'
+ "EventletTestCase",
+ "DbTestCase",
+ "ExecutionDbTestCase",
+ "DbModelTestCase",
+ "WorkflowTestCase",
]
-__version__ = '3.5dev'
+__version__ = "3.5dev"
diff --git a/st2tests/st2tests/action_aliases.py b/st2tests/st2tests/action_aliases.py
index 301fd9a20f..88f02f9642 100644
--- a/st2tests/st2tests/action_aliases.py
+++ b/st2tests/st2tests/action_aliases.py
@@ -25,13 +25,13 @@
from st2common.util.pack import get_pack_ref_from_metadata
from st2common.exceptions.content import ParseException
from st2common.bootstrap.aliasesregistrar import AliasesRegistrar
-from st2common.models.utils.action_alias_utils import extract_parameters_for_action_alias_db
+from st2common.models.utils.action_alias_utils import (
+ extract_parameters_for_action_alias_db,
+)
from st2common.models.utils.action_alias_utils import extract_parameters
from st2tests.pack_resource import BasePackResourceTestCase
-__all__ = [
- 'BaseActionAliasTestCase'
-]
+__all__ = ["BaseActionAliasTestCase"]
class BaseActionAliasTestCase(BasePackResourceTestCase):
@@ -48,7 +48,9 @@ def setUp(self):
if not self.action_alias_name:
raise ValueError('"action_alias_name" class attribute needs to be provided')
- self.action_alias_db = self._get_action_alias_db_by_name(name=self.action_alias_name)
+ self.action_alias_db = self._get_action_alias_db_by_name(
+ name=self.action_alias_name
+ )
def assertCommandMatchesExactlyOneFormatString(self, format_strings, command):
"""
@@ -58,19 +60,22 @@ def assertCommandMatchesExactlyOneFormatString(self, format_strings, command):
for format_string in format_strings:
try:
- extract_parameters(format_str=format_string,
- param_stream=command)
+ extract_parameters(format_str=format_string, param_stream=command)
except ParseException:
continue
matched_format_strings.append(format_string)
if len(matched_format_strings) == 0:
- msg = ('Command "%s" didn\'t match any of the provided format strings' % (command))
+ msg = 'Command "%s" didn\'t match any of the provided format strings' % (
+ command
+ )
raise AssertionError(msg)
elif len(matched_format_strings) > 1:
- msg = ('Command "%s" matched multiple format strings: %s' %
- (command, ', '.join(matched_format_strings)))
+ msg = 'Command "%s" matched multiple format strings: %s' % (
+ command,
+ ", ".join(matched_format_strings),
+ )
raise AssertionError(msg)
def assertExtractedParametersMatch(self, format_string, command, parameters):
@@ -83,11 +88,14 @@ def assertExtractedParametersMatch(self, format_string, command, parameters):
extracted_params = extract_parameters_for_action_alias_db(
action_alias_db=self.action_alias_db,
format_str=format_string,
- param_stream=command)
+ param_stream=command,
+ )
if extracted_params != parameters:
- msg = ('Extracted parameters from command string "%s" against format string "%s"'
- ' didn\'t match the provided parameters: ' % (command, format_string))
+ msg = (
+ 'Extracted parameters from command string "%s" against format string "%s"'
+ " didn't match the provided parameters: " % (command, format_string)
+ )
# Note: We intercept the exception so we can include a diff for the dictionaries
try:
@@ -117,13 +125,14 @@ def _get_action_alias_db_by_name(self, name):
pack_loader = ContentPackLoader()
registrar = AliasesRegistrar(use_pack_cache=False)
- aliases_path = pack_loader.get_content_from_pack(pack_dir=base_pack_path,
- content_type='aliases')
+ aliases_path = pack_loader.get_content_from_pack(
+ pack_dir=base_pack_path, content_type="aliases"
+ )
aliases = registrar._get_aliases_from_pack(aliases_dir=aliases_path)
for alias_path in aliases:
- action_alias_db = registrar._get_action_alias_db(pack=pack,
- action_alias=alias_path,
- ignore_metadata_file_error=True)
+ action_alias_db = registrar._get_action_alias_db(
+ pack=pack, action_alias=alias_path, ignore_metadata_file_error=True
+ )
if action_alias_db.name == name:
return action_alias_db
diff --git a/st2tests/st2tests/actions.py b/st2tests/st2tests/actions.py
index f6026bc8bd..9caec9bca9 100644
--- a/st2tests/st2tests/actions.py
+++ b/st2tests/st2tests/actions.py
@@ -19,9 +19,7 @@
from st2tests.mocks.action import MockActionService
from st2tests.pack_resource import BasePackResourceTestCase
-__all__ = [
- 'BaseActionTestCase'
-]
+__all__ = ["BaseActionTestCase"]
class BaseActionTestCase(BasePackResourceTestCase):
@@ -35,7 +33,7 @@ def setUp(self):
super(BaseActionTestCase, self).setUp()
class_name = self.action_cls.__name__
- action_wrapper = MockActionWrapper(pack='tests', class_name=class_name)
+ action_wrapper = MockActionWrapper(pack="tests", class_name=class_name)
self.action_service = MockActionService(action_wrapper=action_wrapper)
def get_action_instance(self, config=None):
@@ -43,7 +41,9 @@ def get_action_instance(self, config=None):
Retrieve instance of the action class.
"""
# pylint: disable=not-callable
- instance = get_action_class_instance(action_cls=self.action_cls,
- config=config,
- action_service=self.action_service)
+ instance = get_action_class_instance(
+ action_cls=self.action_cls,
+ config=config,
+ action_service=self.action_service,
+ )
return instance
diff --git a/st2tests/st2tests/api.py b/st2tests/st2tests/api.py
index 3b48df737a..7000ddd9a1 100644
--- a/st2tests/st2tests/api.py
+++ b/st2tests/st2tests/api.py
@@ -34,19 +34,19 @@
from st2tests import config as tests_config
__all__ = [
- 'BaseFunctionalTest',
-
- 'FunctionalTest',
- 'APIControllerWithIncludeAndExcludeFilterTestCase',
- 'BaseInquiryControllerTestCase',
-
- 'FakeResponse',
- 'TestApp'
+ "BaseFunctionalTest",
+ "FunctionalTest",
+ "APIControllerWithIncludeAndExcludeFilterTestCase",
+ "BaseInquiryControllerTestCase",
+ "FakeResponse",
+ "TestApp",
]
-SUPER_SECRET_PARAMETER = 'SUPER_SECRET_PARAMETER_THAT_SHOULD_NEVER_APPEAR_IN_RESPONSES_OR_LOGS'
-ANOTHER_SUPER_SECRET_PARAMETER = 'ANOTHER_SUPER_SECRET_PARAMETER_TO_TEST_OVERRIDING'
+SUPER_SECRET_PARAMETER = (
+ "SUPER_SECRET_PARAMETER_THAT_SHOULD_NEVER_APPEAR_IN_RESPONSES_OR_LOGS"
+)
+ANOTHER_SUPER_SECRET_PARAMETER = "ANOTHER_SUPER_SECRET_PARAMETER_TO_TEST_OVERRIDING"
class ResponseValidationError(ValueError):
@@ -61,32 +61,37 @@ class TestApp(webtest.TestApp):
def do_request(self, req, **kwargs):
self.cookiejar.clear()
- if req.environ['REQUEST_METHOD'] != 'OPTIONS':
+ if req.environ["REQUEST_METHOD"] != "OPTIONS":
# Making sure endpoint handles OPTIONS method properly
- self.options(req.environ['PATH_INFO'])
+ self.options(req.environ["PATH_INFO"])
res = super(TestApp, self).do_request(req, **kwargs)
- if res.headers.get('Warning', None):
- raise ResponseValidationError('Endpoint produced invalid response. Make sure the '
- 'response matches OpenAPI scheme for the endpoint.')
+ if res.headers.get("Warning", None):
+ raise ResponseValidationError(
+ "Endpoint produced invalid response. Make sure the "
+ "response matches OpenAPI scheme for the endpoint."
+ )
- if not kwargs.get('expect_errors', None):
+ if not kwargs.get("expect_errors", None):
try:
body = res.body
except AssertionError as e:
- if 'Iterator read after closed' in six.text_type(e):
- body = b''
+ if "Iterator read after closed" in six.text_type(e):
+ body = b""
else:
raise e
- if six.b(SUPER_SECRET_PARAMETER) in body or \
- six.b(ANOTHER_SUPER_SECRET_PARAMETER) in body:
- raise ResponseLeakError('Endpoint response contains secret parameter. '
- 'Find the leak.')
+ if (
+ six.b(SUPER_SECRET_PARAMETER) in body
+ or six.b(ANOTHER_SUPER_SECRET_PARAMETER) in body
+ ):
+ raise ResponseLeakError(
+ "Endpoint response contains secret parameter. " "Find the leak."
+ )
- if 'Access-Control-Allow-Origin' not in res.headers:
- raise ResponseValidationError('Response missing a required CORS header')
+ if "Access-Control-Allow-Origin" not in res.headers:
+ raise ResponseValidationError("Response missing a required CORS header")
return res
@@ -113,19 +118,19 @@ def tearDown(self):
super(BaseFunctionalTest, self).tearDown()
# Reset mock context for API requests
- if getattr(self, 'request_context_mock', None):
+ if getattr(self, "request_context_mock", None):
self.request_context_mock.stop()
- if hasattr(Router, 'mock_context'):
- del(Router.mock_context)
+ if hasattr(Router, "mock_context"):
+ del Router.mock_context
@classmethod
def _do_setUpClass(cls):
tests_config.parse_args()
- cfg.CONF.set_default('enable', cls.enable_auth, group='auth')
+ cfg.CONF.set_default("enable", cls.enable_auth, group="auth")
- cfg.CONF.set_override(name='enable', override=False, group='rbac')
+ cfg.CONF.set_override(name="enable", override=False, group="rbac")
# TODO(manas) : register action types here for now. RunnerType registration can be moved
# to posting to /runnertypes but that implies implementing POST.
@@ -142,11 +147,8 @@ def use_user(self, user_db):
raise ValueError('"user_db" is mandatory')
mock_context = {
- 'user': user_db,
- 'auth_info': {
- 'method': 'authentication token',
- 'location': 'header'
- }
+ "user": user_db,
+ "auth_info": {"method": "authentication token", "location": "header"},
}
self.request_context_mock = mock.PropertyMock(return_value=mock_context)
Router.mock_context = self.request_context_mock
@@ -184,40 +186,48 @@ class APIControllerWithIncludeAndExcludeFilterTestCase(object):
# True if those tests are running with rbac enabled
rbac_enabled = False
- def test_get_all_exclude_attributes_and_include_attributes_are_mutually_exclusive(self):
+ def test_get_all_exclude_attributes_and_include_attributes_are_mutually_exclusive(
+ self,
+ ):
if self.rbac_enabled:
- self.use_user(self.users['admin'])
+ self.use_user(self.users["admin"])
- url = self.get_all_path + '?include_attributes=id&exclude_attributes=id'
+ url = self.get_all_path + "?include_attributes=id&exclude_attributes=id"
resp = self.app.get(url, expect_errors=True)
self.assertEqual(resp.status_int, 400)
- expected_msg = ('exclude.*? and include.*? arguments are mutually exclusive. '
- 'You need to provide either one or another, but not both.')
- self.assertRegexpMatches(resp.json['faultstring'], expected_msg)
+ expected_msg = (
+ "exclude.*? and include.*? arguments are mutually exclusive. "
+ "You need to provide either one or another, but not both."
+ )
+ self.assertRegexpMatches(resp.json["faultstring"], expected_msg)
def test_get_all_invalid_exclude_and_include_parameter(self):
if self.rbac_enabled:
- self.use_user(self.users['admin'])
+ self.use_user(self.users["admin"])
# 1. Invalid exclude_attributes field
- url = self.get_all_path + '?exclude_attributes=invalid_field'
+ url = self.get_all_path + "?exclude_attributes=invalid_field"
resp = self.app.get(url, expect_errors=True)
- expected_msg = ('Invalid or unsupported exclude attribute specified: .*invalid_field.*')
+ expected_msg = (
+ "Invalid or unsupported exclude attribute specified: .*invalid_field.*"
+ )
self.assertEqual(resp.status_int, 400)
- self.assertRegexpMatches(resp.json['faultstring'], expected_msg)
+ self.assertRegexpMatches(resp.json["faultstring"], expected_msg)
# 2. Invalid include_attributes field
- url = self.get_all_path + '?include_attributes=invalid_field'
+ url = self.get_all_path + "?include_attributes=invalid_field"
resp = self.app.get(url, expect_errors=True)
- expected_msg = ('Invalid or unsupported include attribute specified: .*invalid_field.*')
+ expected_msg = (
+ "Invalid or unsupported include attribute specified: .*invalid_field.*"
+ )
self.assertEqual(resp.status_int, 400)
- self.assertRegexpMatches(resp.json['faultstring'], expected_msg)
+ self.assertRegexpMatches(resp.json["faultstring"], expected_msg)
def test_get_all_include_attributes_filter(self):
if self.rbac_enabled:
- self.use_user(self.users['admin'])
+ self.use_user(self.users["admin"])
mandatory_include_fields = self.controller_cls.mandatory_include_fields_response
@@ -226,8 +236,10 @@ def test_get_all_include_attributes_filter(self):
object_ids = self._insert_mock_models()
# Valid include attribute - mandatory field which should always be included
- resp = self.app.get('%s?include_attributes=%s' % (self.get_all_path,
- mandatory_include_fields[0]))
+ resp = self.app.get(
+ "%s?include_attributes=%s"
+ % (self.get_all_path, mandatory_include_fields[0])
+ )
self.assertEqual(resp.status_int, 200)
self.assertTrue(len(resp.json) >= 1)
@@ -245,7 +257,9 @@ def test_get_all_include_attributes_filter(self):
include_field = self.include_attribute_field_name
assert include_field not in mandatory_include_fields
- resp = self.app.get('%s?include_attributes=%s' % (self.get_all_path, include_field))
+ resp = self.app.get(
+ "%s?include_attributes=%s" % (self.get_all_path, include_field)
+ )
self.assertEqual(resp.status_int, 200)
self.assertTrue(len(resp.json) >= 1)
@@ -263,7 +277,7 @@ def test_get_all_include_attributes_filter(self):
def test_get_all_exclude_attributes_filter(self):
if self.rbac_enabled:
- self.use_user(self.users['admin'])
+ self.use_user(self.users["admin"])
# Create any resources needed by those tests (if not already created inside setUp /
# setUpClass)
@@ -285,8 +299,9 @@ def test_get_all_exclude_attributes_filter(self):
# 2. Verify attribute is excluded when filter is provided
exclude_attribute = self.exclude_attribute_field_name
- resp = self.app.get('%s?exclude_attributes=%s' % (self.get_all_path,
- exclude_attribute))
+ resp = self.app.get(
+ "%s?exclude_attributes=%s" % (self.get_all_path, exclude_attribute)
+ )
self.assertEqual(resp.status_int, 200)
self.assertTrue(len(resp.json) >= 1)
@@ -300,8 +315,8 @@ def test_get_all_exclude_attributes_filter(self):
def assertResponseObjectContainsField(self, resp_item, field):
# Handle "." and nested fields
- if '.' in field:
- split = field.split('.')
+ if "." in field:
+ split = field.split(".")
for index, field_part in enumerate(split):
self.assertIn(field_part, resp_item)
@@ -336,7 +351,6 @@ def _do_delete(self, object_id):
class FakeResponse(object):
-
def __init__(self, text, status_code, reason):
self.text = text
self.status_code = status_code
@@ -354,24 +368,27 @@ class BaseActionExecutionControllerTestCase(object):
@staticmethod
def _get_actionexecution_id(resp):
- return resp.json['id']
+ return resp.json["id"]
@staticmethod
def _get_liveaction_id(resp):
- return resp.json['liveaction']['id']
+ return resp.json["liveaction"]["id"]
def _do_get_one(self, actionexecution_id, *args, **kwargs):
- return self.app.get('/v1/executions/%s' % actionexecution_id, *args, **kwargs)
+ return self.app.get("/v1/executions/%s" % actionexecution_id, *args, **kwargs)
def _do_post(self, liveaction, *args, **kwargs):
- return self.app.post_json('/v1/executions', liveaction, *args, **kwargs)
+ return self.app.post_json("/v1/executions", liveaction, *args, **kwargs)
def _do_delete(self, actionexecution_id, expect_errors=False):
- return self.app.delete('/v1/executions/%s' % actionexecution_id,
- expect_errors=expect_errors)
+ return self.app.delete(
+ "/v1/executions/%s" % actionexecution_id, expect_errors=expect_errors
+ )
def _do_put(self, actionexecution_id, updates, *args, **kwargs):
- return self.app.put_json('/v1/executions/%s' % actionexecution_id, updates, *args, **kwargs)
+ return self.app.put_json(
+ "/v1/executions/%s" % actionexecution_id, updates, *args, **kwargs
+ )
class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase):
@@ -380,6 +397,7 @@ class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase):
Inherits from CleanDbTestCase to preserve atomicity between tests
"""
+
from st2api import app
enable_auth = False
@@ -387,26 +405,27 @@ class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase):
@staticmethod
def _get_inquiry_id(resp):
- return resp.json['id']
+ return resp.json["id"]
def _do_get_execution(self, actionexecution_id, *args, **kwargs):
- return self.app.get('/v1/executions/%s' % actionexecution_id, *args, **kwargs)
+ return self.app.get("/v1/executions/%s" % actionexecution_id, *args, **kwargs)
def _do_get_one(self, inquiry_id, *args, **kwargs):
- return self.app.get('/v1/inquiries/%s' % inquiry_id, *args, **kwargs)
+ return self.app.get("/v1/inquiries/%s" % inquiry_id, *args, **kwargs)
def _do_get_all(self, limit=50, *args, **kwargs):
- return self.app.get('/v1/inquiries/?limit=%s' % limit, *args, **kwargs)
+ return self.app.get("/v1/inquiries/?limit=%s" % limit, *args, **kwargs)
def _do_respond(self, inquiry_id, response, *args, **kwargs):
- payload = {
- "id": inquiry_id,
- "response": response
- }
- return self.app.put_json('/v1/inquiries/%s' % inquiry_id, payload, *args, **kwargs)
+ payload = {"id": inquiry_id, "response": response}
+ return self.app.put_json(
+ "/v1/inquiries/%s" % inquiry_id, payload, *args, **kwargs
+ )
- def _do_create_inquiry(self, liveaction, result, status='pending', *args, **kwargs):
- post_resp = self.app.post_json('/v1/executions', liveaction, *args, **kwargs)
+ def _do_create_inquiry(self, liveaction, result, status="pending", *args, **kwargs):
+ post_resp = self.app.post_json("/v1/executions", liveaction, *args, **kwargs)
inquiry_id = self._get_inquiry_id(post_resp)
- updates = {'status': status, 'result': result}
- return self.app.put_json('/v1/executions/%s' % inquiry_id, updates, *args, **kwargs)
+ updates = {"status": status, "result": result}
+ return self.app.put_json(
+ "/v1/executions/%s" % inquiry_id, updates, *args, **kwargs
+ )
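
With the helpers above, an inquiry round-trip in a concrete test collapses to a few lines. A minimal sketch, where the action ref, result payload, and response body are all hypothetical:

class InquiryRoundTripTestCase(BaseInquiryControllerTestCase):
    def test_respond_to_inquiry(self):
        liveaction = {"action": "examples.ask-approval"}  # hypothetical ref
        result = {"schema": {}, "roles": [], "users": [], "route": "", "ttl": 1440}
        post_resp = self._do_create_inquiry(liveaction, result)
        inquiry_id = self._get_inquiry_id(post_resp)
        put_resp = self._do_respond(inquiry_id, {"approved": True})
        self.assertEqual(put_resp.status_int, 200)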
diff --git a/st2tests/st2tests/base.py b/st2tests/st2tests/base.py
index 75a8f7ce02..4a4964763d 100644
--- a/st2tests/st2tests/base.py
+++ b/st2tests/st2tests/base.py
@@ -19,6 +19,7 @@
# NOTE: We need to perform the monkeypatch before importing the ssl module, otherwise tests will fail.
# See https://github.com/StackStorm/st2/pull/4834 for details
from st2common.util.monkey_patch import monkey_patch
+
monkey_patch()
try:
@@ -50,6 +51,7 @@
# parse_args when BaseDbTestCase runs class setup. If that is removed, unit tests
# will fail due to a conflict with duplicate DB keys.
import st2tests.config as tests_config
+
tests_config.parse_args()
from st2common.util.api import get_full_public_api_url
@@ -95,26 +97,23 @@
__all__ = [
- 'EventletTestCase',
- 'DbTestCase',
- 'DbModelTestCase',
- 'CleanDbTestCase',
- 'CleanFilesTestCase',
- 'IntegrationTestCase',
- 'RunnerTestCase',
- 'ExecutionDbTestCase',
- 'WorkflowTestCase',
-
+ "EventletTestCase",
+ "DbTestCase",
+ "DbModelTestCase",
+ "CleanDbTestCase",
+ "CleanFilesTestCase",
+ "IntegrationTestCase",
+ "RunnerTestCase",
+ "ExecutionDbTestCase",
+ "WorkflowTestCase",
# Pack test classes
- 'BaseSensorTestCase',
- 'BaseActionTestCase',
- 'BaseActionAliasTestCase',
-
- 'get_fixtures_path',
- 'get_resources_path',
-
- 'blocking_eventlet_spawn',
- 'make_mock_stream_readline'
+ "BaseSensorTestCase",
+ "BaseActionTestCase",
+ "BaseActionAliasTestCase",
+ "get_fixtures_path",
+ "get_resources_path",
+ "blocking_eventlet_spawn",
+ "make_mock_stream_readline",
]
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -135,7 +134,7 @@
ALL_MODELS.extend(rule_enforcement_model.MODELS)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
-TESTS_CONFIG_PATH = os.path.join(BASE_DIR, '../conf/st2.conf')
+TESTS_CONFIG_PATH = os.path.join(BASE_DIR, "../conf/st2.conf")
class RunnerTestCase(unittest2.TestCase):
@@ -148,17 +147,15 @@ def assertCommonSt2EnvVarsAvailableInEnv(self, env):
"""
for var_name in COMMON_ACTION_ENV_VARIABLES:
self.assertIn(var_name, env)
- self.assertEqual(env['ST2_ACTION_API_URL'], get_full_public_api_url())
+ self.assertEqual(env["ST2_ACTION_API_URL"], get_full_public_api_url())
self.assertIsNotNone(env[AUTH_TOKEN_ENV_VARIABLE_NAME])
def loader(self, path):
- """ Load the runner config
- """
+ """Load the runner config"""
return self.meta_loader.load(path)
class BaseTestCase(TestCase):
-
@classmethod
def _register_packs(self):
"""
@@ -173,7 +170,9 @@ def _register_pack_configs(self, validate_configs=False):
"""
Register all the packs inside the fixtures directory.
"""
- registrar = ConfigsRegistrar(use_pack_cache=False, validate_configs=validate_configs)
+ registrar = ConfigsRegistrar(
+ use_pack_cache=False, validate_configs=validate_configs
+ )
registrar.register_from_packs(base_dirs=get_packs_base_paths())
@@ -189,18 +188,14 @@ def setUpClass(cls):
os=True,
select=True,
socket=True,
- thread=False if '--use-debugger' in sys.argv else True,
- time=True
+ thread=False if "--use-debugger" in sys.argv else True,
+ time=True,
)
@classmethod
def tearDownClass(cls):
eventlet.monkey_patch(
- os=False,
- select=False,
- socket=False,
- thread=False,
- time=False
+ os=False, select=False, socket=False, thread=False, time=False
)
@@ -222,17 +217,29 @@ def setUpClass(cls):
tests_config.parse_args()
if cls.DISPLAY_LOG_MESSAGES:
- config_path = os.path.join(BASE_DIR, '../conf/logging.conf')
- logging.config.fileConfig(config_path,
- disable_existing_loggers=False)
+ config_path = os.path.join(BASE_DIR, "../conf/logging.conf")
+ logging.config.fileConfig(config_path, disable_existing_loggers=False)
@classmethod
def _establish_connection_and_re_create_db(cls):
- username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None
- password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None
+ username = (
+ cfg.CONF.database.username
+ if hasattr(cfg.CONF.database, "username")
+ else None
+ )
+ password = (
+ cfg.CONF.database.password
+ if hasattr(cfg.CONF.database, "password")
+ else None
+ )
cls.db_connection = db_setup(
- cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port,
- username=username, password=password, ensure_indexes=False)
+ cfg.CONF.database.db_name,
+ cfg.CONF.database.host,
+ cfg.CONF.database.port,
+ username=username,
+ password=password,
+ ensure_indexes=False,
+ )
cls._drop_collections()
cls.db_connection.drop_database(cfg.CONF.database.db_name)
@@ -242,12 +249,17 @@ def _establish_connection_and_re_create_db(cls):
# NOTE: This is only needed in distributed scenarios (production deployments) where
# multiple services can start up at the same time and race conditions are possible.
if cls.ensure_indexes:
- if len(cls.ensure_indexes_models) == 0 or len(cls.ensure_indexes_models) > 1:
- msg = ('Ensuring indexes for all the models, this could significantly slow down '
- 'the tests')
- print('#' * len(msg), file=sys.stderr)
+ if (
+ len(cls.ensure_indexes_models) == 0
+ or len(cls.ensure_indexes_models) > 1
+ ):
+ msg = (
+ "Ensuring indexes for all the models, this could significantly slow down "
+ "the tests"
+ )
+ print("#" * len(msg), file=sys.stderr)
print(msg, file=sys.stderr)
- print('#' * len(msg), file=sys.stderr)
+ print("#" * len(msg), file=sys.stderr)
db_ensure_indexes(cls.ensure_indexes_models)
@@ -319,19 +331,19 @@ def run(self, result=None):
class ExecutionDbTestCase(DbTestCase):
- """"
+ """ "
Base test class for tests which test various execution related code paths.
This class offers some utility methods for waiting on execution status, etc.
"""
ensure_indexes = True
- ensure_indexes_models = [
- ActionExecutionSchedulingQueueItemDB
- ]
+ ensure_indexes_models = [ActionExecutionSchedulingQueueItemDB]
- def _wait_on_status(self, liveaction_db, status, retries=300, delay=0.1, raise_exc=True):
- assert isinstance(status, six.string_types), '%s is not of text type' % (status)
+ def _wait_on_status(
+ self, liveaction_db, status, retries=300, delay=0.1, raise_exc=True
+ ):
+ assert isinstance(status, six.string_types), "%s is not of text type" % (status)
for _ in range(0, retries):
eventlet.sleep(delay)
@@ -344,8 +356,12 @@ def _wait_on_status(self, liveaction_db, status, retries=300, delay=0.1, raise_e
return liveaction_db
- def _wait_on_statuses(self, liveaction_db, statuses, retries=300, delay=0.1, raise_exc=True):
- assert isinstance(statuses, (list, tuple)), '%s is not of list type' % (statuses)
+ def _wait_on_statuses(
+ self, liveaction_db, statuses, retries=300, delay=0.1, raise_exc=True
+ ):
+ assert isinstance(statuses, (list, tuple)), "%s is not of list type" % (
+ statuses
+ )
for _ in range(0, retries):
eventlet.sleep(delay)
@@ -358,7 +374,9 @@ def _wait_on_statuses(self, liveaction_db, statuses, retries=300, delay=0.1, rai
return liveaction_db
- def _wait_on_ac_ex_status(self, execution_db, status, retries=300, delay=0.1, raise_exc=True):
+ def _wait_on_ac_ex_status(
+ self, execution_db, status, retries=300, delay=0.1, raise_exc=True
+ ):
for _ in range(0, retries):
eventlet.sleep(delay)
execution_db = ex_db_access.ActionExecution.get_by_id(str(execution_db.id))
@@ -370,7 +388,9 @@ def _wait_on_ac_ex_status(self, execution_db, status, retries=300, delay=0.1, ra
return execution_db
- def _wait_on_call_count(self, mocked, expected_count, retries=100, delay=0.1, raise_exc=True):
+ def _wait_on_call_count(
+ self, mocked, expected_count, retries=100, delay=0.1, raise_exc=True
+ ):
for _ in range(0, retries):
eventlet.sleep(delay)
if mocked.call_count == expected_count:
@@ -395,12 +415,14 @@ def setUpClass(cls):
def _assert_fields_equal(self, a, b, exclude=None):
exclude = exclude or []
- fields = {k: v for k, v in six.iteritems(self.db_type._fields) if k not in exclude}
+ fields = {
+ k: v for k, v in six.iteritems(self.db_type._fields) if k not in exclude
+ }
assert_funcs = {
- 'mongoengine.fields.DictField': self.assertDictEqual,
- 'mongoengine.fields.ListField': self.assertListEqual,
- 'mongoengine.fields.SortedListField': self.assertListEqual
+ "mongoengine.fields.DictField": self.assertDictEqual,
+ "mongoengine.fields.ListField": self.assertListEqual,
+ "mongoengine.fields.SortedListField": self.assertListEqual,
}
for k, v in six.iteritems(fields):
@@ -410,10 +432,7 @@ def _assert_fields_equal(self, a, b, exclude=None):
def _assert_values_equal(self, a, values=None):
values = values or {}
- assert_funcs = {
- 'dict': self.assertDictEqual,
- 'list': self.assertListEqual
- }
+ assert_funcs = {"dict": self.assertDictEqual, "list": self.assertListEqual}
for k, v in six.iteritems(values):
assert_func = assert_funcs.get(type(v).__name__, self.assertEqual)
@@ -421,7 +440,7 @@ def _assert_values_equal(self, a, values=None):
def _assert_crud(self, instance, defaults=None, updates=None):
# Assert instance is not already in the database.
- self.assertIsNone(getattr(instance, 'id', None))
+ self.assertIsNone(getattr(instance, "id", None))
# Assert default values are assigned.
self._assert_values_equal(instance, values=defaults)
@@ -429,7 +448,7 @@ def _assert_crud(self, instance, defaults=None, updates=None):
# Assert instance is created in the database.
saved = self.access_type.add_or_update(instance)
self.assertIsNotNone(saved.id)
- self._assert_fields_equal(instance, saved, exclude=['id'])
+ self._assert_fields_equal(instance, saved, exclude=["id"])
retrieved = self.access_type.get_by_id(saved.id)
self._assert_fields_equal(saved, retrieved)
@@ -443,22 +462,23 @@ def _assert_crud(self, instance, defaults=None, updates=None):
# Assert instance is deleted from the database.
retrieved = self.access_type.get_by_id(instance.id)
retrieved.delete()
- self.assertRaises(StackStormDBObjectNotFoundError,
- self.access_type.get_by_id, instance.id)
+ self.assertRaises(
+ StackStormDBObjectNotFoundError, self.access_type.get_by_id, instance.id
+ )
def _assert_unique_key_constraint(self, instance):
# Assert instance is not already in the database.
- self.assertIsNone(getattr(instance, 'id', None))
+ self.assertIsNone(getattr(instance, "id", None))
# Assert instance is created in the database.
saved = self.access_type.add_or_update(instance)
self.assertIsNotNone(saved.id)
# Assert an exception is thrown if we try to create the same instance again.
- delattr(instance, 'id')
- self.assertRaises(StackStormDBObjectConflictError,
- self.access_type.add_or_update,
- instance)
+ delattr(instance, "id")
+ self.assertRaises(
+ StackStormDBObjectConflictError, self.access_type.add_or_update, instance
+ )
class CleanDbTestCase(BaseDbTestCase):
@@ -486,6 +506,7 @@ class CleanFilesTestCase(TestCase):
"""
Base test class which deletes specified files and directories on setUp and tearDown.
"""
+
to_delete_files = []
to_delete_directories = []
@@ -555,8 +576,8 @@ def tearDown(self):
stderr = None
print('Process "%s"' % (process.pid))
- print('Stdout: %s' % (stdout))
- print('Stderr: %s' % (stderr))
+ print("Stdout: %s" % (stdout))
+ print("Stderr: %s" % (stderr))
def add_process(self, process):
"""
@@ -578,7 +599,7 @@ def assertProcessIsRunning(self, process):
has successfully started and is running.
"""
if not process:
- raise ValueError('process is None')
+ raise ValueError("process is None")
return_code = process.poll()
@@ -586,24 +607,27 @@ def assertProcessIsRunning(self, process):
if process.stdout:
stdout = process.stdout.read()
else:
- stdout = ''
+ stdout = ""
if process.stderr:
stderr = process.stderr.read()
else:
- stderr = ''
+ stderr = ""
- msg = ('Process exited with code=%s.\nStdout:\n%s\n\nStderr:\n%s' %
- (return_code, stdout, stderr))
+ msg = "Process exited with code=%s.\nStdout:\n%s\n\nStderr:\n%s" % (
+ return_code,
+ stdout,
+ stderr,
+ )
self.fail(msg)
def assertProcessExited(self, proc):
try:
status = proc.status()
except psutil.NoSuchProcess:
- status = 'exited'
+ status = "exited"
- if status not in ['exited', 'zombie']:
+ if status not in ["exited", "zombie"]:
self.fail('Process with pid "%s" is still running' % (proc.pid))
@@ -613,49 +637,49 @@ class WorkflowTestCase(ExecutionDbTestCase):
"""
def get_wf_fixture_meta_data(self, fixture_pack_path, wf_meta_file_name):
- wf_meta_file_path = fixture_pack_path + '/actions/' + wf_meta_file_name
+ wf_meta_file_path = fixture_pack_path + "/actions/" + wf_meta_file_name
wf_meta_content = loader.load_meta_file(wf_meta_file_path)
- wf_name = wf_meta_content['pack'] + '.' + wf_meta_content['name']
+ wf_name = wf_meta_content["pack"] + "." + wf_meta_content["name"]
return {
- 'file_name': wf_meta_file_name,
- 'file_path': wf_meta_file_path,
- 'content': wf_meta_content,
- 'name': wf_name
+ "file_name": wf_meta_file_name,
+ "file_path": wf_meta_file_path,
+ "content": wf_meta_content,
+ "name": wf_name,
}
def get_wf_def(self, test_pack_path, wf_meta):
- rel_wf_def_path = wf_meta['content']['entry_point']
- abs_wf_def_path = os.path.join(test_pack_path, 'actions', rel_wf_def_path)
+ rel_wf_def_path = wf_meta["content"]["entry_point"]
+ abs_wf_def_path = os.path.join(test_pack_path, "actions", rel_wf_def_path)
- with open(abs_wf_def_path, 'r') as def_file:
+ with open(abs_wf_def_path, "r") as def_file:
return def_file.read()
def mock_st2_context(self, ac_ex_db, context=None):
st2_ctx = {
- 'st2': {
- 'api_url': api_util.get_full_public_api_url(),
- 'action_execution_id': str(ac_ex_db.id),
- 'user': 'stanley',
- 'action': ac_ex_db.action['ref'],
- 'runner': ac_ex_db.runner['name']
+ "st2": {
+ "api_url": api_util.get_full_public_api_url(),
+ "action_execution_id": str(ac_ex_db.id),
+ "user": "stanley",
+ "action": ac_ex_db.action["ref"],
+ "runner": ac_ex_db.runner["name"],
}
}
if context:
- st2_ctx['parent'] = context
+ st2_ctx["parent"] = context
return st2_ctx
def prep_wf_ex(self, wf_ex_db):
data = {
- 'spec': wf_ex_db.spec,
- 'graph': wf_ex_db.graph,
- 'input': wf_ex_db.input,
- 'context': wf_ex_db.context,
- 'state': wf_ex_db.state,
- 'output': wf_ex_db.output,
- 'errors': wf_ex_db.errors
+ "spec": wf_ex_db.spec,
+ "graph": wf_ex_db.graph,
+ "input": wf_ex_db.input,
+ "context": wf_ex_db.context,
+ "state": wf_ex_db.state,
+ "output": wf_ex_db.output,
+ "errors": wf_ex_db.errors,
}
conductor = conducting.WorkflowConductor.deserialize(data)
@@ -663,7 +687,7 @@ def prep_wf_ex(self, wf_ex_db):
for task in conductor.get_next_tasks():
ac_ex_event = events.ActionExecutionEvent(wf_statuses.RUNNING)
- conductor.update_task_state(task['id'], task['route'], ac_ex_event)
+ conductor.update_task_state(task["id"], task["route"], ac_ex_event)
wf_ex_db.status = conductor.get_workflow_status()
wf_ex_db.state = conductor.workflow_state.serialize()
@@ -672,7 +696,9 @@ def prep_wf_ex(self, wf_ex_db):
return wf_ex_db
def get_task_ex(self, task_id, route):
- task_ex_dbs = wf_db_access.TaskExecution.query(task_id=task_id, task_route=route)
+ task_ex_dbs = wf_db_access.TaskExecution.query(
+ task_id=task_id, task_route=route
+ )
self.assertGreater(len(task_ex_dbs), 0)
return task_ex_dbs[0]
@@ -686,21 +712,29 @@ def get_action_ex(self, task_ex_id):
self.assertEqual(len(ac_ex_dbs), 1)
return ac_ex_dbs[0]
- def run_workflow_step(self, wf_ex_db, task_id, route, ctx=None,
- expected_ac_ex_db_status=ac_const.LIVEACTION_STATUS_SUCCEEDED,
- expected_tk_ex_db_status=wf_statuses.SUCCEEDED):
- spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog'])
+ def run_workflow_step(
+ self,
+ wf_ex_db,
+ task_id,
+ route,
+ ctx=None,
+ expected_ac_ex_db_status=ac_const.LIVEACTION_STATUS_SUCCEEDED,
+ expected_tk_ex_db_status=wf_statuses.SUCCEEDED,
+ ):
+ spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"])
wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec)
- st2_ctx = {'execution_id': wf_ex_db.action_execution}
+ st2_ctx = {"execution_id": wf_ex_db.action_execution}
task_spec = wf_spec.tasks.get_task(task_id)
- task_actions = [{'action': task_spec.action, 'input': getattr(task_spec, 'input', {})}]
+ task_actions = [
+ {"action": task_spec.action, "input": getattr(task_spec, "input", {})}
+ ]
task_req = {
- 'id': task_id,
- 'route': route,
- 'spec': task_spec,
- 'ctx': ctx or {},
- 'actions': task_actions
+ "id": task_id,
+ "route": route,
+ "spec": task_spec,
+ "ctx": ctx or {},
+ "actions": task_actions,
}
task_ex_db = wf_svc.request_task_execution(wf_ex_db, st2_ctx, task_req)
@@ -712,10 +746,12 @@ def run_workflow_step(self, wf_ex_db, task_id, route, ctx=None,
self.assertEqual(task_ex_db.status, expected_tk_ex_db_status)
def sort_workflow_errors(self, errors):
- return sorted(errors, key=lambda x: x.get('task_id', None))
+ return sorted(errors, key=lambda x: x.get("task_id", None))
def assert_task_not_started(self, task_id, route):
- task_ex_dbs = wf_db_access.TaskExecution.query(task_id=task_id, task_route=route)
+ task_ex_dbs = wf_db_access.TaskExecution.query(
+ task_id=task_id, task_route=route
+ )
self.assertEqual(len(task_ex_dbs), 0)
def assert_task_running(self, task_id, route):
@@ -734,7 +770,6 @@ def assert_workflow_completed(self, wf_ex_id, status=None):
class FakeResponse(object):
-
def __init__(self, text, status_code, reason):
self.text = text
self.status_code = status_code
@@ -748,11 +783,11 @@ def raise_for_status(self):
def get_fixtures_path():
- return os.path.join(os.path.dirname(__file__), 'fixtures')
+ return os.path.join(os.path.dirname(__file__), "fixtures")
def get_resources_path():
- return os.path.join(os.path.dirname(__file__), 'resources')
+ return os.path.join(os.path.dirname(__file__), "resources")
def blocking_eventlet_spawn(func, *args, **kwargs):
diff --git a/st2tests/st2tests/config.py b/st2tests/st2tests/config.py
index 7fa4ad7b6e..b140357839 100644
--- a/st2tests/st2tests/config.py
+++ b/st2tests/st2tests/config.py
@@ -77,57 +77,66 @@ def _register_config_opts():
def _override_db_opts():
- CONF.set_override(name='db_name', override='st2-test', group='database')
- CONF.set_override(name='host', override='127.0.0.1', group='database')
+ CONF.set_override(name="db_name", override="st2-test", group="database")
+ CONF.set_override(name="host", override="127.0.0.1", group="database")
def _override_common_opts():
packs_base_path = get_fixtures_packs_base_path()
- CONF.set_override(name='base_path', override=packs_base_path, group='system')
- CONF.set_override(name='validate_output_schema', override=True, group='system')
- CONF.set_override(name='system_packs_base_path', override=packs_base_path, group='content')
- CONF.set_override(name='packs_base_paths', override=packs_base_path, group='content')
- CONF.set_override(name='api_url', override='http://127.0.0.1', group='auth')
- CONF.set_override(name='mask_secrets', override=True, group='log')
- CONF.set_override(name='stream_output', override=False, group='actionrunner')
+ CONF.set_override(name="base_path", override=packs_base_path, group="system")
+ CONF.set_override(name="validate_output_schema", override=True, group="system")
+ CONF.set_override(
+ name="system_packs_base_path", override=packs_base_path, group="content"
+ )
+ CONF.set_override(
+ name="packs_base_paths", override=packs_base_path, group="content"
+ )
+ CONF.set_override(name="api_url", override="http://127.0.0.1", group="auth")
+ CONF.set_override(name="mask_secrets", override=True, group="log")
+ CONF.set_override(name="stream_output", override=False, group="actionrunner")
def _override_api_opts():
- CONF.set_override(name='allow_origin', override=['http://127.0.0.1:3000', 'http://dev'],
- group='api')
+ CONF.set_override(
+ name="allow_origin",
+ override=["http://127.0.0.1:3000", "http://dev"],
+ group="api",
+ )
def _override_keyvalue_opts():
current_file_path = os.path.dirname(__file__)
- rel_st2_base_path = os.path.join(current_file_path, '../..')
+ rel_st2_base_path = os.path.join(current_file_path, "../..")
abs_st2_base_path = os.path.abspath(rel_st2_base_path)
- rel_enc_key_path = 'st2tests/conf/st2_kvstore_tests.crypto.key.json'
+ rel_enc_key_path = "st2tests/conf/st2_kvstore_tests.crypto.key.json"
ovr_enc_key_path = os.path.join(abs_st2_base_path, rel_enc_key_path)
- CONF.set_override(name='encryption_key_path', override=ovr_enc_key_path, group='keyvalue')
+ CONF.set_override(
+ name="encryption_key_path", override=ovr_enc_key_path, group="keyvalue"
+ )
def _override_scheduler_opts():
- CONF.set_override(name='sleep_interval', group='scheduler', override=0.01)
+ CONF.set_override(name="sleep_interval", group="scheduler", override=0.01)
def _override_coordinator_opts(noop=False):
- driver = None if noop else 'zake://'
- CONF.set_override(name='url', override=driver, group='coordination')
- CONF.set_override(name='lock_timeout', override=1, group='coordination')
+ driver = None if noop else "zake://"
+ CONF.set_override(name="url", override=driver, group="coordination")
+ CONF.set_override(name="lock_timeout", override=1, group="coordination")
def _override_workflow_engine_opts():
- cfg.CONF.set_override('retry_stop_max_msec', 500, group='workflow_engine')
- cfg.CONF.set_override('retry_wait_fixed_msec', 100, group='workflow_engine')
- cfg.CONF.set_override('retry_max_jitter_msec', 100, group='workflow_engine')
- cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine')
+ cfg.CONF.set_override("retry_stop_max_msec", 500, group="workflow_engine")
+ cfg.CONF.set_override("retry_wait_fixed_msec", 100, group="workflow_engine")
+ cfg.CONF.set_override("retry_max_jitter_msec", 100, group="workflow_engine")
+ cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine")
def _register_common_opts():
try:
common_config.register_opts(ignore_errors=True)
except:
- LOG.exception('Common config registration failed.')
+ LOG.exception("Common config registration failed.")
def _register_api_opts():
@@ -135,225 +144,292 @@ def _register_api_opts():
# Brittle!
pecan_opts = [
cfg.StrOpt(
- 'root', default='st2api.controllers.root.RootController',
- help='Pecan root controller'),
- cfg.StrOpt('template_path', default='%(confdir)s/st2api/st2api/templates'),
- cfg.ListOpt('modules', default=['st2api']),
- cfg.BoolOpt('debug', default=True),
- cfg.BoolOpt('auth_enable', default=True),
- cfg.DictOpt('errors', default={404: '/error/404', '__force_dict__': True})
+ "root",
+ default="st2api.controllers.root.RootController",
+ help="Pecan root controller",
+ ),
+ cfg.StrOpt("template_path", default="%(confdir)s/st2api/st2api/templates"),
+ cfg.ListOpt("modules", default=["st2api"]),
+ cfg.BoolOpt("debug", default=True),
+ cfg.BoolOpt("auth_enable", default=True),
+ cfg.DictOpt("errors", default={404: "/error/404", "__force_dict__": True}),
]
- _register_opts(pecan_opts, group='api_pecan')
+ _register_opts(pecan_opts, group="api_pecan")
api_opts = [
- cfg.BoolOpt('debug', default=True),
+ cfg.BoolOpt("debug", default=True),
cfg.IntOpt(
- 'max_page_size', default=100,
- help='Maximum limit (page size) argument which can be specified by the user in a query '
- 'string. If a larger value is provided, it will default to this value.')
+ "max_page_size",
+ default=100,
+ help="Maximum limit (page size) argument which can be specified by the user in a query "
+ "string. If a larger value is provided, it will default to this value.",
+ ),
]
- _register_opts(api_opts, group='api')
+ _register_opts(api_opts, group="api")
messaging_opts = [
cfg.StrOpt(
- 'url', default='amqp://guest:guest@127.0.0.1:5672//',
- help='URL of the messaging server.'),
+ "url",
+ default="amqp://guest:guest@127.0.0.1:5672//",
+ help="URL of the messaging server.",
+ ),
cfg.ListOpt(
- 'cluster_urls', default=[],
- help='URL of all the nodes in a messaging service cluster.'),
+ "cluster_urls",
+ default=[],
+ help="URL of all the nodes in a messaging service cluster.",
+ ),
cfg.IntOpt(
- 'connection_retries', default=10,
- help='How many times should we retry connection before failing.'),
+ "connection_retries",
+ default=10,
+ help="How many times should we retry connection before failing.",
+ ),
cfg.IntOpt(
- 'connection_retry_wait', default=10000,
- help='How long should we wait between connection retries.'),
+ "connection_retry_wait",
+ default=10000,
+ help="How long should we wait between connection retries.",
+ ),
cfg.BoolOpt(
- 'ssl', default=False,
- help='Use SSL / TLS to connect to the messaging server. Same as '
- 'appending "?ssl=true" at the end of the connection URL string.'),
+ "ssl",
+ default=False,
+ help="Use SSL / TLS to connect to the messaging server. Same as "
+ 'appending "?ssl=true" at the end of the connection URL string.',
+ ),
cfg.StrOpt(
- 'ssl_keyfile', default=None,
- help='Private keyfile used to identify the local connection against RabbitMQ.'),
+ "ssl_keyfile",
+ default=None,
+ help="Private keyfile used to identify the local connection against RabbitMQ.",
+ ),
cfg.StrOpt(
- 'ssl_certfile', default=None,
- help='Certificate file used to identify the local connection (client).'),
+ "ssl_certfile",
+ default=None,
+ help="Certificate file used to identify the local connection (client).",
+ ),
cfg.StrOpt(
- 'ssl_cert_reqs', default=None, choices='none, optional, required',
- help='Specifies whether a certificate is required from the other side of the '
- 'connection, and whether it will be validated if provided.'),
+ "ssl_cert_reqs",
+ default=None,
+ choices="none, optional, required",
+ help="Specifies whether a certificate is required from the other side of the "
+ "connection, and whether it will be validated if provided.",
+ ),
cfg.StrOpt(
- 'ssl_ca_certs', default=None,
- help='ca_certs file contains a set of concatenated CA certificates, which are '
- 'used to validate certificates passed from RabbitMQ.'),
+ "ssl_ca_certs",
+ default=None,
+ help="ca_certs file contains a set of concatenated CA certificates, which are "
+ "used to validate certificates passed from RabbitMQ.",
+ ),
cfg.StrOpt(
- 'login_method', default=None,
- help='Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).')
+ "login_method",
+ default=None,
+ help="Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).",
+ ),
]
- _register_opts(messaging_opts, group='messaging')
+ _register_opts(messaging_opts, group="messaging")
ssh_runner_opts = [
cfg.StrOpt(
- 'remote_dir', default='/tmp',
- help='Location of the script on the remote filesystem.'),
+ "remote_dir",
+ default="/tmp",
+ help="Location of the script on the remote filesystem.",
+ ),
cfg.BoolOpt(
- 'allow_partial_failure', default=False,
- help='How partial success of actions run on multiple nodes should be treated.'),
+ "allow_partial_failure",
+ default=False,
+ help="How partial success of actions run on multiple nodes should be treated.",
+ ),
cfg.BoolOpt(
- 'use_ssh_config', default=False,
- help='Use the .ssh/config file. Useful to override ports etc.')
+ "use_ssh_config",
+ default=False,
+ help="Use the .ssh/config file. Useful to override ports etc.",
+ ),
]
- _register_opts(ssh_runner_opts, group='ssh_runner')
+ _register_opts(ssh_runner_opts, group="ssh_runner")
def _register_stream_opts():
stream_opts = [
cfg.IntOpt(
- 'heartbeat', default=25,
- help='Send empty message every N seconds to keep connection open'),
- cfg.BoolOpt(
- 'debug', default=False,
- help='Specify to enable debug mode.'),
+ "heartbeat",
+ default=25,
+ help="Send empty message every N seconds to keep connection open",
+ ),
+ cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."),
]
- _register_opts(stream_opts, group='stream')
+ _register_opts(stream_opts, group="stream")
def _register_auth_opts():
auth_opts = [
- cfg.StrOpt('host', default='127.0.0.1'),
- cfg.IntOpt('port', default=9100),
- cfg.BoolOpt('use_ssl', default=False),
- cfg.StrOpt('mode', default='proxy'),
- cfg.StrOpt('backend', default='flat_file'),
- cfg.StrOpt('backend_kwargs', default=None),
- cfg.StrOpt('logging', default='conf/logging.conf'),
- cfg.IntOpt('token_ttl', default=86400, help='Access token ttl in seconds.'),
- cfg.BoolOpt('sso', default=True),
- cfg.StrOpt('sso_backend', default='noop'),
- cfg.StrOpt('sso_backend_kwargs', default=None),
- cfg.BoolOpt('debug', default=True)
+ cfg.StrOpt("host", default="127.0.0.1"),
+ cfg.IntOpt("port", default=9100),
+ cfg.BoolOpt("use_ssl", default=False),
+ cfg.StrOpt("mode", default="proxy"),
+ cfg.StrOpt("backend", default="flat_file"),
+ cfg.StrOpt("backend_kwargs", default=None),
+ cfg.StrOpt("logging", default="conf/logging.conf"),
+ cfg.IntOpt("token_ttl", default=86400, help="Access token ttl in seconds."),
+ cfg.BoolOpt("sso", default=True),
+ cfg.StrOpt("sso_backend", default="noop"),
+ cfg.StrOpt("sso_backend_kwargs", default=None),
+ cfg.BoolOpt("debug", default=True),
]
- _register_opts(auth_opts, group='auth')
+ _register_opts(auth_opts, group="auth")
def _register_action_sensor_opts():
action_sensor_opts = [
cfg.BoolOpt(
- 'enable', default=True,
- help='Whether to enable or disable the ability to post a trigger on action.'),
+ "enable",
+ default=True,
+ help="Whether to enable or disable the ability to post a trigger on action.",
+ ),
cfg.StrOpt(
- 'triggers_base_url', default='http://127.0.0.1:9101/v1/triggertypes/',
- help='URL for action sensor to post TriggerType.'),
+ "triggers_base_url",
+ default="http://127.0.0.1:9101/v1/triggertypes/",
+ help="URL for action sensor to post TriggerType.",
+ ),
cfg.IntOpt(
- 'request_timeout', default=1,
- help='Timeout value of all httprequests made by action sensor.'),
+ "request_timeout",
+ default=1,
+ help="Timeout value of all httprequests made by action sensor.",
+ ),
cfg.IntOpt(
- 'max_attempts', default=10,
- help='No. of times to retry registration.'),
+ "max_attempts", default=10, help="No. of times to retry registration."
+ ),
cfg.IntOpt(
- 'retry_wait', default=1,
- help='Amount of time to wait prior to retrying a request.')
+ "retry_wait",
+ default=1,
+ help="Amount of time to wait prior to retrying a request.",
+ ),
]
- _register_opts(action_sensor_opts, group='action_sensor')
+ _register_opts(action_sensor_opts, group="action_sensor")
def _register_ssh_runner_opts():
ssh_runner_opts = [
cfg.BoolOpt(
- 'use_ssh_config', default=False,
- help='Use the .ssh/config file. Useful to override ports etc.'),
+ "use_ssh_config",
+ default=False,
+ help="Use the .ssh/config file. Useful to override ports etc.",
+ ),
cfg.StrOpt(
- 'remote_dir', default='/tmp',
- help='Location of the script on the remote filesystem.'),
+ "remote_dir",
+ default="/tmp",
+ help="Location of the script on the remote filesystem.",
+ ),
cfg.BoolOpt(
- 'allow_partial_failure', default=False,
- help='How partial success of actions run on multiple nodes should be treated.'),
+ "allow_partial_failure",
+ default=False,
+ help="How partial success of actions run on multiple nodes should be treated.",
+ ),
cfg.IntOpt(
- 'max_parallel_actions', default=50,
- help='Max number of parallel remote SSH actions that should be run. '
- 'Works only with Paramiko SSH runner.'),
+ "max_parallel_actions",
+ default=50,
+ help="Max number of parallel remote SSH actions that should be run. "
+ "Works only with Paramiko SSH runner.",
+ ),
]
- _register_opts(ssh_runner_opts, group='ssh_runner')
+ _register_opts(ssh_runner_opts, group="ssh_runner")
def _register_scheduler_opts():
scheduler_opts = [
cfg.FloatOpt(
- 'execution_scheduling_timeout_threshold_min', default=1,
- help='How long GC to search back in minutes for orphaned scheduled actions'),
+ "execution_scheduling_timeout_threshold_min",
+ default=1,
+ help="How long GC to search back in minutes for orphaned scheduled actions",
+ ),
cfg.IntOpt(
- 'pool_size', default=10,
- help='The size of the pool used by the scheduler for scheduling executions.'),
+ "pool_size",
+ default=10,
+ help="The size of the pool used by the scheduler for scheduling executions.",
+ ),
cfg.FloatOpt(
- 'sleep_interval', default=0.01,
- help='How long to sleep between each action scheduler main loop run interval (in ms).'),
+ "sleep_interval",
+ default=0.01,
+ help="How long to sleep between each action scheduler main loop run interval (in ms).",
+ ),
cfg.FloatOpt(
- 'gc_interval', default=5,
- help='How often to look for zombie executions before rescheduling them (in ms).'),
+ "gc_interval",
+ default=5,
+ help="How often to look for zombie executions before rescheduling them (in ms).",
+ ),
cfg.IntOpt(
- 'retry_max_attempt', default=3,
- help='The maximum number of attempts that the scheduler retries on error.'),
+ "retry_max_attempt",
+ default=3,
+ help="The maximum number of attempts that the scheduler retries on error.",
+ ),
cfg.IntOpt(
- 'retry_wait_msec', default=100,
- help='The number of milliseconds to wait in between retries.')
+ "retry_wait_msec",
+ default=100,
+ help="The number of milliseconds to wait in between retries.",
+ ),
]
- _register_opts(scheduler_opts, group='scheduler')
+ _register_opts(scheduler_opts, group="scheduler")
def _register_exporter_opts():
exporter_opts = [
cfg.StrOpt(
- 'dump_dir', default='/opt/stackstorm/exports/',
- help='Directory to dump data to.')
+ "dump_dir",
+ default="/opt/stackstorm/exports/",
+ help="Directory to dump data to.",
+ )
]
- _register_opts(exporter_opts, group='exporter')
+ _register_opts(exporter_opts, group="exporter")
def _register_sensor_container_opts():
partition_opts = [
cfg.StrOpt(
- 'sensor_node_name', default='sensornode1',
- help='name of the sensor node.'),
+ "sensor_node_name", default="sensornode1", help="name of the sensor node."
+ ),
cfg.Opt(
- 'partition_provider',
+ "partition_provider",
type=types.Dict(value_type=types.String()),
- default={'name': DEFAULT_PARTITION_LOADER},
- help='Provider of sensor node partition config.')
+ default={"name": DEFAULT_PARTITION_LOADER},
+ help="Provider of sensor node partition config.",
+ ),
]
- _register_opts(partition_opts, group='sensorcontainer')
+ _register_opts(partition_opts, group="sensorcontainer")
# Other options
other_opts = [
cfg.BoolOpt(
- 'single_sensor_mode', default=False,
- help='Run in a single sensor mode where parent process exits when a sensor crashes / '
- 'dies. This is useful in environments where partitioning, sensor process life '
- 'cycle and failover is handled by a 3rd party service such as kubernetes.')
+ "single_sensor_mode",
+ default=False,
+ help="Run in a single sensor mode where parent process exits when a sensor crashes / "
+ "dies. This is useful in environments where partitioning, sensor process life "
+ "cycle and failover is handled by a 3rd party service such as kubernetes.",
+ )
]
- _register_opts(other_opts, group='sensorcontainer')
+ _register_opts(other_opts, group="sensorcontainer")
# CLI options
cli_opts = [
cfg.StrOpt(
- 'sensor-ref',
- help='Only run sensor with the provided reference. Value is of the form '
- '<pack>.<class-name> (e.g. linux.FileWatchSensor).'),
+ "sensor-ref",
+ help="Only run sensor with the provided reference. Value is of the form "
+ ". (e.g. linux.FileWatchSensor).",
+ ),
cfg.BoolOpt(
- 'single-sensor-mode', default=False,
- help='Run in a single sensor mode where parent process exits when a sensor crashes / '
- 'dies. This is useful in environments where partitioning, sensor process life '
- 'cycle and failover is handled by a 3rd party service such as kubernetes.')
+ "single-sensor-mode",
+ default=False,
+ help="Run in a single sensor mode where parent process exits when a sensor crashes / "
+ "dies. This is useful in environments where partitioning, sensor process life "
+ "cycle and failover is handled by a 3rd party service such as kubernetes.",
+ ),
]
_register_cli_opts(cli_opts)
@@ -362,40 +438,52 @@ def _register_sensor_container_opts():
def _register_garbage_collector_opts():
common_opts = [
cfg.IntOpt(
- 'collection_interval', default=DEFAULT_COLLECTION_INTERVAL,
- help='How often to check database for old data and perform garbage collection.'),
+ "collection_interval",
+ default=DEFAULT_COLLECTION_INTERVAL,
+ help="How often to check database for old data and perform garbage collection.",
+ ),
cfg.FloatOpt(
- 'sleep_delay', default=DEFAULT_SLEEP_DELAY,
- help='How long to wait / sleep (in seconds) between '
- 'collection of different object types.')
+ "sleep_delay",
+ default=DEFAULT_SLEEP_DELAY,
+ help="How long to wait / sleep (in seconds) between "
+ "collection of different object types.",
+ ),
]
- _register_opts(common_opts, group='garbagecollector')
+ _register_opts(common_opts, group="garbagecollector")
ttl_opts = [
cfg.IntOpt(
- 'action_executions_ttl', default=None,
- help='Action executions and related objects (live actions, action output '
- 'objects) older than this value (days) will be automatically deleted.'),
+ "action_executions_ttl",
+ default=None,
+ help="Action executions and related objects (live actions, action output "
+ "objects) older than this value (days) will be automatically deleted.",
+ ),
cfg.IntOpt(
- 'action_executions_output_ttl', default=7,
- help='Action execution output objects (ones generated by action output '
- 'streaming) older than this value (days) will be automatically deleted.'),
+ "action_executions_output_ttl",
+ default=7,
+ help="Action execution output objects (ones generated by action output "
+ "streaming) older than this value (days) will be automatically deleted.",
+ ),
cfg.IntOpt(
- 'trigger_instances_ttl', default=None,
- help='Trigger instances older than this value (days) will be automatically deleted.')
+ "trigger_instances_ttl",
+ default=None,
+ help="Trigger instances older than this value (days) will be automatically deleted.",
+ ),
]
- _register_opts(ttl_opts, group='garbagecollector')
+ _register_opts(ttl_opts, group="garbagecollector")
inquiry_opts = [
cfg.BoolOpt(
- 'purge_inquiries', default=False,
- help='Set to True to perform garbage collection on Inquiries (based on '
- 'the TTL value per Inquiry)')
+ "purge_inquiries",
+ default=False,
+ help="Set to True to perform garbage collection on Inquiries (based on "
+ "the TTL value per Inquiry)",
+ )
]
- _register_opts(inquiry_opts, group='garbagecollector')
+ _register_opts(inquiry_opts, group="garbagecollector")
def _register_opts(opts, group=None):
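
Aside: the _register_opts helpers above follow the standard oslo.config pattern of registering options under a named group and later reading them back off the global CONF object. A minimal self-contained sketch of that pattern (the option name and default here are illustrative, not part of this patch):

    from oslo_config import cfg

    CONF = cfg.CONF

    # Register an option under a named group, mirroring _register_opts() above.
    opts = [
        cfg.IntOpt(
            "collection_interval",
            default=600,
            help="How often to run collection (in seconds).",
        )
    ]
    CONF.register_opts(opts, group="garbagecollector")

    # After the config is parsed, values are read via attribute access.
    CONF(args=[], project="example")
    print(CONF.garbagecollector.collection_interval)  # -> 600
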
diff --git a/st2tests/st2tests/fixtures/history_views/__init__.py b/st2tests/st2tests/fixtures/history_views/__init__.py
index dd42395788..24567ead6e 100644
--- a/st2tests/st2tests/fixtures/history_views/__init__.py
+++ b/st2tests/st2tests/fixtures/history_views/__init__.py
@@ -21,12 +21,12 @@
PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)))
-FILES = glob.glob('%s/*.yaml' % PATH)
+FILES = glob.glob("%s/*.yaml" % PATH)
ARTIFACTS = {}
for f in FILES:
f_name = os.path.split(f)[1]
name = six.text_type(os.path.splitext(f_name)[0])
- with open(f, 'r') as fd:
+ with open(f, "r") as fd:
ARTIFACTS[name] = yaml.safe_load(fd)
diff --git a/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py b/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py
index b5184b586c..5b2cc19cc0 100755
--- a/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py
+++ b/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py
@@ -32,16 +32,16 @@ def print_random_chars(chars=1000, selection=ascii_letters + string.digits):
s = []
for _ in range(chars - 1):
s.append(random.choice(selection))
- s.append('@')
- print(''.join(s))
+ s.append("@")
+ print("".join(s))
def main():
parser = argparse.ArgumentParser()
- parser.add_argument('--chars', type=int, metavar='N', default=10)
+ parser.add_argument("--chars", type=int, metavar="N", default=10)
args = parser.parse_args()
print_random_chars(args.chars)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py b/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py
index 57f3f6eea3..40875e1182 100644
--- a/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py
+++ b/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py
@@ -17,6 +17,5 @@
class PrintPythonVersionAction(Action):
-
def run(self, value1):
return {"context_value": value1}
diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py b/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py
index ef42e25d15..acd2832627 100644
--- a/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py
+++ b/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py
@@ -14,8 +14,8 @@
# limitations under the License.
from __future__ import absolute_import
-from invalid import Invalid # noqa
+from invalid import Invalid # noqa
-class Foo():
+class Foo:
pass
diff --git a/st2tests/st2tests/fixtures/packs/executions/__init__.py b/st2tests/st2tests/fixtures/packs/executions/__init__.py
index 3faa0a81ad..ef9bf26a3f 100644
--- a/st2tests/st2tests/fixtures/packs/executions/__init__.py
+++ b/st2tests/st2tests/fixtures/packs/executions/__init__.py
@@ -22,17 +22,17 @@
PATH = os.path.dirname(os.path.realpath(__file__))
-FILES = glob.glob('%s/*.yaml' % PATH)
+FILES = glob.glob("%s/*.yaml" % PATH)
ARTIFACTS = {}
for f in FILES:
f_name = os.path.split(f)[1]
name = six.text_type(os.path.splitext(f_name)[0])
- with open(f, 'r') as fd:
+ with open(f, "r") as fd:
ARTIFACTS[name] = yaml.safe_load(fd)
if isinstance(ARTIFACTS[name], dict):
- ARTIFACTS[name][u'id'] = six.text_type(bson.ObjectId())
+ ARTIFACTS[name]["id"] = six.text_type(bson.ObjectId())
elif isinstance(ARTIFACTS[name], list):
for item in ARTIFACTS[name]:
- item[u'id'] = six.text_type(bson.ObjectId())
+ item["id"] = six.text_type(bson.ObjectId())
diff --git a/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py b/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py
index 0409202903..31258fae4e 100644
--- a/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py
+++ b/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py
@@ -14,15 +14,16 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except:
import json
from st2common.runners.base import AsyncActionRunner
-from st2common.constants.action import (LIVEACTION_STATUS_RUNNING)
+from st2common.constants.action import LIVEACTION_STATUS_RUNNING
-RAISE_PROPERTY = 'raise'
+RAISE_PROPERTY = "raise"
def get_runner():
@@ -31,7 +32,7 @@ def get_runner():
class AsyncTestRunner(AsyncActionRunner):
def __init__(self):
- super(AsyncTestRunner, self).__init__(runner_id='1')
+ super(AsyncTestRunner, self).__init__(runner_id="1")
self.pre_run_called = False
self.run_called = False
self.post_run_called = False
@@ -43,14 +44,11 @@ def run(self, action_params):
self.run_called = True
result = {}
if self.runner_parameters.get(RAISE_PROPERTY, False):
- raise Exception('Raise required.')
+ raise Exception("Raise required.")
else:
- result = {
- 'ran': True,
- 'action_params': action_params
- }
+ result = {"ran": True, "action_params": action_params}
- return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'})
+ return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"})
def post_run(self, status, result):
self.post_run_called = True
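
Note the contract the fixture runner above demonstrates: run() returns a (status, result, context) triple, and an async runner reports LIVEACTION_STATUS_RUNNING because the execution completes out of band. A hedged caller-side sketch (test wiring elided):

    runner = AsyncTestRunner()
    runner.runner_parameters = {}  # no "raise" key, so run() succeeds
    status, result_json, context = runner.run({"param": "value"})

    assert status == LIVEACTION_STATUS_RUNNING
    assert context == {"id": "foo"}  # id a poller would use to track the execution
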
diff --git a/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py b/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py
index 435f7eb9b6..c48bb9aa67 100644
--- a/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py
+++ b/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py
@@ -14,15 +14,16 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except:
import json
from st2common.runners.base import PollingAsyncActionRunner
-from st2common.constants.action import (LIVEACTION_STATUS_RUNNING)
+from st2common.constants.action import LIVEACTION_STATUS_RUNNING
-RAISE_PROPERTY = 'raise'
+RAISE_PROPERTY = "raise"
def get_runner():
@@ -31,7 +32,7 @@ def get_runner():
class PollingAsyncTestRunner(PollingAsyncActionRunner):
def __init__(self):
- super(PollingAsyncTestRunner, self).__init__(runner_id='1')
+ super(PollingAsyncTestRunner, self).__init__(runner_id="1")
self.pre_run_called = False
self.run_called = False
self.post_run_called = False
@@ -43,14 +44,11 @@ def run(self, action_params):
self.run_called = True
result = {}
if self.runner_parameters.get(RAISE_PROPERTY, False):
- raise Exception('Raise required.')
+ raise Exception("Raise required.")
else:
- result = {
- 'ran': True,
- 'action_params': action_params
- }
+ result = {"ran": True, "action_params": action_params}
- return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'})
+ return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"})
def post_run(self, status, result):
self.post_run_called = True
diff --git a/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py b/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py
index fe50e37ae5..5d18a77ccb 100644
--- a/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py
+++ b/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py
@@ -15,9 +15,7 @@
from st2actions.runners.pythonrunner import Action
-__all__ = [
- 'GetLibraryPathAction'
-]
+__all__ = ["GetLibraryPathAction"]
class GetLibraryPathAction(Action):
diff --git a/st2tests/st2tests/fixturesloader.py b/st2tests/st2tests/fixturesloader.py
index df9f2cef7f..dd1446153e 100644
--- a/st2tests/st2tests/fixturesloader.py
+++ b/st2tests/st2tests/fixturesloader.py
@@ -21,16 +21,21 @@
from st2common.content.loader import MetaLoader
-from st2common.models.api.action import (ActionAPI, LiveActionAPI, ActionExecutionStateAPI,
- RunnerTypeAPI, ActionAliasAPI)
+from st2common.models.api.action import (
+ ActionAPI,
+ LiveActionAPI,
+ ActionExecutionStateAPI,
+ RunnerTypeAPI,
+ ActionAliasAPI,
+)
from st2common.models.api.auth import ApiKeyAPI, UserAPI
-from st2common.models.api.execution import (ActionExecutionAPI)
-from st2common.models.api.policy import (PolicyTypeAPI, PolicyAPI)
-from st2common.models.api.rule import (RuleAPI)
+from st2common.models.api.execution import ActionExecutionAPI
+from st2common.models.api.policy import PolicyTypeAPI, PolicyAPI
+from st2common.models.api.rule import RuleAPI
from st2common.models.api.rule_enforcement import RuleEnforcementAPI
from st2common.models.api.sensor import SensorTypeAPI
from st2common.models.api.trace import TraceAPI
-from st2common.models.api.trigger import (TriggerAPI, TriggerTypeAPI, TriggerInstanceAPI)
+from st2common.models.api.trigger import TriggerAPI, TriggerTypeAPI, TriggerInstanceAPI
from st2common.models.db.action import ActionDB
from st2common.models.db.actionalias import ActionAliasDB
@@ -38,13 +43,13 @@
from st2common.models.db.liveaction import LiveActionDB
from st2common.models.db.executionstate import ActionExecutionStateDB
from st2common.models.db.runner import RunnerTypeDB
-from st2common.models.db.execution import (ActionExecutionDB)
-from st2common.models.db.policy import (PolicyTypeDB, PolicyDB)
+from st2common.models.db.execution import ActionExecutionDB
+from st2common.models.db.policy import PolicyTypeDB, PolicyDB
from st2common.models.db.rule import RuleDB
from st2common.models.db.rule_enforcement import RuleEnforcementDB
from st2common.models.db.sensor import SensorTypeDB
from st2common.models.db.trace import TraceDB
-from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB, TriggerInstanceDB)
+from st2common.models.db.trigger import TriggerDB, TriggerTypeDB, TriggerInstanceDB
from st2common.persistence.action import Action
from st2common.persistence.actionalias import ActionAlias
from st2common.persistence.execution import ActionExecution
@@ -52,107 +57,125 @@
from st2common.persistence.auth import ApiKey, User
from st2common.persistence.liveaction import LiveAction
from st2common.persistence.runner import RunnerType
-from st2common.persistence.policy import (PolicyType, Policy)
+from st2common.persistence.policy import PolicyType, Policy
from st2common.persistence.rule import Rule
from st2common.persistence.rule_enforcement import RuleEnforcement
from st2common.persistence.sensor import SensorType
from st2common.persistence.trace import Trace
-from st2common.persistence.trigger import (Trigger, TriggerType, TriggerInstance)
-
-
-ALLOWED_DB_FIXTURES = ['actions', 'actionstates', 'aliases', 'executions', 'liveactions',
- 'policies', 'policytypes', 'rules', 'runners', 'sensors',
- 'triggertypes', 'triggers', 'triggerinstances', 'traces', 'apikeys',
- 'users', 'enforcements']
+from st2common.persistence.trigger import Trigger, TriggerType, TriggerInstance
+
+
+ALLOWED_DB_FIXTURES = [
+ "actions",
+ "actionstates",
+ "aliases",
+ "executions",
+ "liveactions",
+ "policies",
+ "policytypes",
+ "rules",
+ "runners",
+ "sensors",
+ "triggertypes",
+ "triggers",
+ "triggerinstances",
+ "traces",
+ "apikeys",
+ "users",
+ "enforcements",
+]
ALLOWED_FIXTURES = copy.copy(ALLOWED_DB_FIXTURES)
-ALLOWED_FIXTURES.extend(['actionchains', 'workflows'])
+ALLOWED_FIXTURES.extend(["actionchains", "workflows"])
FIXTURE_DB_MODEL = {
- 'actions': ActionDB,
- 'aliases': ActionAliasDB,
- 'actionstates': ActionExecutionStateDB,
- 'apikeys': ApiKeyDB,
- 'enforcements': RuleEnforcementDB,
- 'executions': ActionExecutionDB,
- 'liveactions': LiveActionDB,
- 'policies': PolicyDB,
- 'policytypes': PolicyTypeDB,
- 'rules': RuleDB,
- 'runners': RunnerTypeDB,
- 'sensors': SensorTypeDB,
- 'traces': TraceDB,
- 'triggertypes': TriggerTypeDB,
- 'triggers': TriggerDB,
- 'triggerinstances': TriggerInstanceDB,
- 'users': UserDB
+ "actions": ActionDB,
+ "aliases": ActionAliasDB,
+ "actionstates": ActionExecutionStateDB,
+ "apikeys": ApiKeyDB,
+ "enforcements": RuleEnforcementDB,
+ "executions": ActionExecutionDB,
+ "liveactions": LiveActionDB,
+ "policies": PolicyDB,
+ "policytypes": PolicyTypeDB,
+ "rules": RuleDB,
+ "runners": RunnerTypeDB,
+ "sensors": SensorTypeDB,
+ "traces": TraceDB,
+ "triggertypes": TriggerTypeDB,
+ "triggers": TriggerDB,
+ "triggerinstances": TriggerInstanceDB,
+ "users": UserDB,
}
FIXTURE_API_MODEL = {
- 'actions': ActionAPI,
- 'aliases': ActionAliasAPI,
- 'actionstates': ActionExecutionStateAPI,
- 'apikeys': ApiKeyAPI,
- 'enforcements': RuleEnforcementAPI,
- 'executions': ActionExecutionAPI,
- 'liveactions': LiveActionAPI,
- 'policies': PolicyAPI,
- 'policytypes': PolicyTypeAPI,
- 'rules': RuleAPI,
- 'runners': RunnerTypeAPI,
- 'sensors': SensorTypeAPI,
- 'traces': TraceAPI,
- 'triggertypes': TriggerTypeAPI,
- 'triggers': TriggerAPI,
- 'triggerinstances': TriggerInstanceAPI,
- 'users': UserAPI
+ "actions": ActionAPI,
+ "aliases": ActionAliasAPI,
+ "actionstates": ActionExecutionStateAPI,
+ "apikeys": ApiKeyAPI,
+ "enforcements": RuleEnforcementAPI,
+ "executions": ActionExecutionAPI,
+ "liveactions": LiveActionAPI,
+ "policies": PolicyAPI,
+ "policytypes": PolicyTypeAPI,
+ "rules": RuleAPI,
+ "runners": RunnerTypeAPI,
+ "sensors": SensorTypeAPI,
+ "traces": TraceAPI,
+ "triggertypes": TriggerTypeAPI,
+ "triggers": TriggerAPI,
+ "triggerinstances": TriggerInstanceAPI,
+ "users": UserAPI,
}
FIXTURE_PERSISTENCE_MODEL = {
- 'actions': Action,
- 'aliases': ActionAlias,
- 'actionstates': ActionExecutionState,
- 'apikeys': ApiKey,
- 'enforcements': RuleEnforcement,
- 'executions': ActionExecution,
- 'liveactions': LiveAction,
- 'policies': Policy,
- 'policytypes': PolicyType,
- 'rules': Rule,
- 'runners': RunnerType,
- 'sensors': SensorType,
- 'traces': Trace,
- 'triggertypes': TriggerType,
- 'triggers': Trigger,
- 'triggerinstances': TriggerInstance,
- 'users': User
+ "actions": Action,
+ "aliases": ActionAlias,
+ "actionstates": ActionExecutionState,
+ "apikeys": ApiKey,
+ "enforcements": RuleEnforcement,
+ "executions": ActionExecution,
+ "liveactions": LiveAction,
+ "policies": Policy,
+ "policytypes": PolicyType,
+ "rules": Rule,
+ "runners": RunnerType,
+ "sensors": SensorType,
+ "traces": Trace,
+ "triggertypes": TriggerType,
+ "triggers": Trigger,
+ "triggerinstances": TriggerInstance,
+ "users": User,
}
GIT_SUBMODULES_NOT_CHECKED_OUT_ERROR = """
Git submodule "%s" is not checked out. Make sure to run "git submodule update --init
--recursive" in the repository root directory to check out all the
submodules.
-""".replace('\n', '').strip()
+""".replace(
+ "\n", ""
+).strip()
def get_fixtures_base_path():
- return os.path.join(os.path.dirname(__file__), 'fixtures')
+ return os.path.join(os.path.dirname(__file__), "fixtures")
def get_fixtures_packs_base_path():
- return os.path.join(os.path.dirname(__file__), 'fixtures/packs')
+ return os.path.join(os.path.dirname(__file__), "fixtures/packs")
def get_resources_base_path():
- return os.path.join(os.path.dirname(__file__), 'resources')
+ return os.path.join(os.path.dirname(__file__), "resources")
class FixturesLoader(object):
def __init__(self):
self.meta_loader = MetaLoader()
- def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None,
- use_object_ids=False):
+ def save_fixtures_to_db(
+ self, fixtures_pack="generic", fixtures_dict=None, use_object_ids=False
+ ):
"""
Loads fixtures specified in fixtures_dict into the database
and returns DB models for the fixtures.
@@ -193,17 +216,22 @@ def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None,
for fixture in fixtures:
# Guard against copy and type and similar typos
if fixture in loaded_fixtures:
- msg = 'Fixture "%s" is specified twice, probably a typo.' % (fixture)
+ msg = 'Fixture "%s" is specified twice, probably a typo.' % (
+ fixture
+ )
raise ValueError(msg)
fixture_dict = self.meta_loader.load(
- self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture))
+ self._get_fixture_file_path_abs(
+ fixtures_pack_path, fixture_type, fixture
+ )
+ )
api_model = API_MODEL(**fixture_dict)
db_model = API_MODEL.to_model(api_model)
# Make sure we also set and use object id if that functionality is used
- if use_object_ids and 'id' in fixture_dict:
- db_model.id = fixture_dict['id']
+ if use_object_ids and "id" in fixture_dict:
+ db_model.id = fixture_dict["id"]
db_model = PERSISTENCE_MODEL.add_or_update(db_model)
loaded_fixtures[fixture] = db_model
@@ -212,7 +240,7 @@ def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None,
return db_models
- def load_fixtures(self, fixtures_pack='generic', fixtures_dict=None):
+ def load_fixtures(self, fixtures_pack="generic", fixtures_dict=None):
"""
Loads fixtures specified in fixtures_dict. We
simply want to load the meta into dict objects.
@@ -241,13 +269,16 @@ def load_fixtures(self, fixtures_pack='generic', fixtures_dict=None):
loaded_fixtures = {}
for fixture in fixtures:
fixture_dict = self.meta_loader.load(
- self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture))
+ self._get_fixture_file_path_abs(
+ fixtures_pack_path, fixture_type, fixture
+ )
+ )
loaded_fixtures[fixture] = fixture_dict
all_fixtures[fixture_type] = loaded_fixtures
return all_fixtures
- def load_models(self, fixtures_pack='generic', fixtures_dict=None):
+ def load_models(self, fixtures_pack="generic", fixtures_dict=None):
"""
Loads fixtures specified in fixtures_dict as db models. This method must be
used for fixtures that have associated DB models. We simply want to load the
@@ -281,7 +312,10 @@ def load_models(self, fixtures_pack='generic', fixtures_dict=None):
loaded_models = {}
for fixture in fixtures:
fixture_dict = self.meta_loader.load(
- self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture))
+ self._get_fixture_file_path_abs(
+ fixtures_pack_path, fixture_type, fixture
+ )
+ )
api_model = API_MODEL(**fixture_dict)
db_model = API_MODEL.to_model(api_model)
loaded_models[fixture] = db_model
@@ -289,8 +323,9 @@ def load_models(self, fixtures_pack='generic', fixtures_dict=None):
return all_fixtures
- def delete_fixtures_from_db(self, fixtures_pack='generic', fixtures_dict=None,
- raise_on_fail=False):
+ def delete_fixtures_from_db(
+ self, fixtures_pack="generic", fixtures_dict=None, raise_on_fail=False
+ ):
"""
Deletes fixtures specified in fixtures_dict from the database.
@@ -320,7 +355,10 @@ def delete_fixtures_from_db(self, fixtures_pack='generic', fixtures_dict=None,
PERSISTENCE_MODEL = FIXTURE_PERSISTENCE_MODEL.get(fixture_type, None)
for fixture in fixtures:
fixture_dict = self.meta_loader.load(
- self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture))
+ self._get_fixture_file_path_abs(
+ fixtures_pack_path, fixture_type, fixture
+ )
+ )
# Note that when we have a reference mechanism consistent for
# every model, we can just do a get and delete the object. Until
# then, this model conversions are necessary.
@@ -362,28 +400,36 @@ def _validate_fixtures_pack(self, fixtures_pack):
fixtures_pack_path = self._get_fixtures_pack_path(fixtures_pack)
if not self._is_fixture_pack_exists(fixtures_pack_path):
- raise Exception('Fixtures pack not found ' +
- 'in fixtures path %s.' % get_fixtures_base_path())
+ raise Exception(
+ "Fixtures pack not found "
+ + "in fixtures path %s." % get_fixtures_base_path()
+ )
return fixtures_pack_path
def _validate_fixture_dict(self, fixtures_dict, allowed=ALLOWED_FIXTURES):
fixture_types = list(fixtures_dict.keys())
for fixture_type in fixture_types:
if fixture_type not in allowed:
- raise Exception('Disallowed fixture type: %s. Valid fixture types are: %s' % (
- fixture_type, ", ".join(allowed)))
+ raise Exception(
+ "Disallowed fixture type: %s. Valid fixture types are: %s"
+ % (fixture_type, ", ".join(allowed))
+ )
def _is_fixture_pack_exists(self, fixtures_pack_path):
return os.path.exists(fixtures_pack_path)
- def _get_fixture_file_path_abs(self, fixtures_pack_path, fixtures_type, fixture_name):
+ def _get_fixture_file_path_abs(
+ self, fixtures_pack_path, fixtures_type, fixture_name
+ ):
return os.path.join(fixtures_pack_path, fixtures_type, fixture_name)
def _get_fixtures_pack_path(self, fixtures_pack_name):
return os.path.join(get_fixtures_base_path(), fixtures_pack_name)
def get_fixture_file_path_abs(self, fixtures_pack, fixtures_type, fixture_name):
- return os.path.join(get_fixtures_base_path(), fixtures_pack, fixtures_type, fixture_name)
+ return os.path.join(
+ get_fixtures_base_path(), fixtures_pack, fixtures_type, fixture_name
+ )
def assert_submodules_are_checked_out():
@@ -392,9 +438,9 @@ def assert_submodules_are_checked_out():
root of the directory and that the "st2tests/st2tests/fixtures/packs/test" git repo submodule
used by the tests is checked out.
"""
- pack_path = os.path.join(get_fixtures_packs_base_path(), 'test_content_version/')
+ pack_path = os.path.join(get_fixtures_packs_base_path(), "test_content_version/")
pack_path = os.path.abspath(pack_path)
- submodule_git_dir_or_file_path = os.path.join(pack_path, '.git')
+ submodule_git_dir_or_file_path = os.path.join(pack_path, ".git")
# NOTE: In newer versions of git, that .git is a file and not a directory
if not os.path.exists(submodule_git_dir_or_file_path):
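
The reformatting above leaves FixturesLoader's public API unchanged; typical test-side usage still looks roughly like this (pack and file names are illustrative):

    from st2tests.fixturesloader import FixturesLoader

    loader = FixturesLoader()

    # Load YAML fixtures as plain dicts, without touching the database.
    fixtures = loader.load_fixtures(
        fixtures_pack="generic",
        fixtures_dict={"actions": ["action1.yaml"]},  # illustrative file name
    )
    action_meta = fixtures["actions"]["action1.yaml"]
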
diff --git a/st2tests/st2tests/http.py b/st2tests/st2tests/http.py
index 4dce56f45a..e14672d001 100644
--- a/st2tests/st2tests/http.py
+++ b/st2tests/st2tests/http.py
@@ -18,7 +18,6 @@
class FakeResponse(object):
-
def __init__(self, text, status_code, reason):
self.text = text
self.status_code = status_code
diff --git a/st2tests/st2tests/mocks/action.py b/st2tests/st2tests/mocks/action.py
index f09d5a7f8d..ec8f7842b8 100644
--- a/st2tests/st2tests/mocks/action.py
+++ b/st2tests/st2tests/mocks/action.py
@@ -25,10 +25,7 @@
from python_runner.python_action_wrapper import ActionService
from st2tests.mocks.datastore import MockDatastoreService
-__all__ = [
- 'MockActionWrapper',
- 'MockActionService'
-]
+__all__ = ["MockActionWrapper", "MockActionService"]
class MockActionWrapper(object):
@@ -49,9 +46,11 @@ def __init__(self, action_wrapper):
# We use a Mock class so use can assert logger was called with particular arguments
self._logger = Mock(spec=RootLogger)
- self._datastore_service = MockDatastoreService(logger=self._logger,
- pack_name=self._action_wrapper._pack,
- class_name=self._action_wrapper._class_name)
+ self._datastore_service = MockDatastoreService(
+ logger=self._logger,
+ pack_name=self._action_wrapper._pack,
+ class_name=self._action_wrapper._class_name,
+ )
@property
def datastore_service(self):
diff --git a/st2tests/st2tests/mocks/auth.py b/st2tests/st2tests/mocks/auth.py
index e0624aca42..6f322959cc 100644
--- a/st2tests/st2tests/mocks/auth.py
+++ b/st2tests/st2tests/mocks/auth.py
@@ -18,24 +18,18 @@
from st2auth.backends.base import BaseAuthenticationBackend
# auser:apassword in b64
-DUMMY_CREDS = 'YXVzZXI6YXBhc3N3b3Jk'
+DUMMY_CREDS = "YXVzZXI6YXBhc3N3b3Jk"
-__all__ = [
- 'DUMMY_CREDS',
-
- 'MockAuthBackend',
- 'MockRequest',
-
- 'get_mock_backend'
-]
+__all__ = ["DUMMY_CREDS", "MockAuthBackend", "MockRequest", "get_mock_backend"]
class MockAuthBackend(BaseAuthenticationBackend):
groups = []
def authenticate(self, username, password):
- return ((username == 'auser' and password == 'apassword') or
- (username == 'username' and password == 'password:password'))
+ return (username == "auser" and password == "apassword") or (
+ username == "username" and password == "password:password"
+ )
def get_user(self, username):
return username
@@ -44,7 +38,7 @@ def get_user_groups(self, username):
return self.groups
-class MockRequest():
+class MockRequest:
def __init__(self, ttl):
self.ttl = ttl
diff --git a/st2tests/st2tests/mocks/datastore.py b/st2tests/st2tests/mocks/datastore.py
index fe8156bf9e..0282a18ffd 100644
--- a/st2tests/st2tests/mocks/datastore.py
+++ b/st2tests/st2tests/mocks/datastore.py
@@ -22,9 +22,7 @@
from st2common.services.datastore import BaseDatastoreService
from st2client.models.keyvalue import KeyValuePair
-__all__ = [
- 'MockDatastoreService'
-]
+__all__ = ["MockDatastoreService"]
class MockDatastoreService(BaseDatastoreService):
@@ -35,7 +33,7 @@ class MockDatastoreService(BaseDatastoreService):
def __init__(self, logger, pack_name, class_name, api_username=None):
self._pack_name = pack_name
self._class_name = class_name
- self._username = api_username or 'admin'
+ self._username = api_username or "admin"
# Holds mock KeyValuePair objects
# Key is a KeyValuePair name and value is the KeyValuePair object
@@ -53,18 +51,9 @@ def get_user_info(self):
:rtype: ``dict``
"""
result = {
- 'username': self._username,
- 'rbac': {
- 'is_admin': True,
- 'enabled': True,
- 'roles': [
- 'admin'
- ]
- },
- 'authentication': {
- 'method': 'authentication token',
- 'location': 'header'
- }
+ "username": self._username,
+ "rbac": {"is_admin": True, "enabled": True, "roles": ["admin"]},
+ "authentication": {"method": "authentication token", "location": "header"},
}
return result
@@ -101,12 +90,16 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False):
kvp = self._datastore_items[name]
return kvp.value
- def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False):
+ def set_value(
+ self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False
+ ):
"""
Store a value in a dictionary which is local to this class.
"""
if ttl:
- raise ValueError('MockDatastoreService.set_value doesn\'t support "ttl" argument')
+ raise ValueError(
+ 'MockDatastoreService.set_value doesn\'t support "ttl" argument'
+ )
name = self._get_full_key_name(name=name, local=local)
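
For context, MockDatastoreService is an in-memory stand-in for the real datastore service, so pack tests can exercise get_value/set_value without a running API. A small sketch under that assumption:

    from logging import getLogger

    from st2tests.mocks.datastore import MockDatastoreService

    service = MockDatastoreService(
        logger=getLogger(__name__), pack_name="examples", class_name="MyAction"
    )
    service.set_value(name="greeting", value="hello")
    assert service.get_value(name="greeting") == "hello"  # ttl is unsupported here
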
diff --git a/st2tests/st2tests/mocks/execution.py b/st2tests/st2tests/mocks/execution.py
index 1fdf8a4262..00e3c8ef11 100644
--- a/st2tests/st2tests/mocks/execution.py
+++ b/st2tests/st2tests/mocks/execution.py
@@ -21,13 +21,10 @@
from st2common.models.db.execution import ActionExecutionDB
-__all__ = [
- 'MockExecutionPublisher'
-]
+__all__ = ["MockExecutionPublisher"]
class MockExecutionPublisher(object):
-
@classmethod
def publish_update(cls, payload):
try:
@@ -39,7 +36,6 @@ def publish_update(cls, payload):
class MockExecutionPublisherNonBlocking(object):
-
@classmethod
def publish_update(cls, payload):
try:
diff --git a/st2tests/st2tests/mocks/liveaction.py b/st2tests/st2tests/mocks/liveaction.py
index 753224d9ea..2b329e6b25 100644
--- a/st2tests/st2tests/mocks/liveaction.py
+++ b/st2tests/st2tests/mocks/liveaction.py
@@ -26,14 +26,10 @@
from st2common.constants import action as action_constants
from st2common.models.db.liveaction import LiveActionDB
-__all__ = [
- 'MockLiveActionPublisher',
- 'MockLiveActionPublisherNonBlocking'
-]
+__all__ = ["MockLiveActionPublisher", "MockLiveActionPublisherNonBlocking"]
class MockLiveActionPublisher(object):
-
@classmethod
def process(cls, payload):
ex_req = scheduling.get_scheduler_entrypoint().process(payload)
@@ -106,7 +102,6 @@ def wait_all(cls):
class MockLiveActionPublisherSchedulingQueueOnly(object):
-
@classmethod
def process(cls, payload):
scheduling.get_scheduler_entrypoint().process(payload)
diff --git a/st2tests/st2tests/mocks/runners/async_runner.py b/st2tests/st2tests/mocks/runners/async_runner.py
index 0409202903..31258fae4e 100644
--- a/st2tests/st2tests/mocks/runners/async_runner.py
+++ b/st2tests/st2tests/mocks/runners/async_runner.py
@@ -14,15 +14,16 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except:
import json
from st2common.runners.base import AsyncActionRunner
-from st2common.constants.action import (LIVEACTION_STATUS_RUNNING)
+from st2common.constants.action import LIVEACTION_STATUS_RUNNING
-RAISE_PROPERTY = 'raise'
+RAISE_PROPERTY = "raise"
def get_runner():
@@ -31,7 +32,7 @@ def get_runner():
class AsyncTestRunner(AsyncActionRunner):
def __init__(self):
- super(AsyncTestRunner, self).__init__(runner_id='1')
+ super(AsyncTestRunner, self).__init__(runner_id="1")
self.pre_run_called = False
self.run_called = False
self.post_run_called = False
@@ -43,14 +44,11 @@ def run(self, action_params):
self.run_called = True
result = {}
if self.runner_parameters.get(RAISE_PROPERTY, False):
- raise Exception('Raise required.')
+ raise Exception("Raise required.")
else:
- result = {
- 'ran': True,
- 'action_params': action_params
- }
+ result = {"ran": True, "action_params": action_params}
- return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'})
+ return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"})
def post_run(self, status, result):
self.post_run_called = True
diff --git a/st2tests/st2tests/mocks/runners/polling_async_runner.py b/st2tests/st2tests/mocks/runners/polling_async_runner.py
index 435f7eb9b6..c48bb9aa67 100644
--- a/st2tests/st2tests/mocks/runners/polling_async_runner.py
+++ b/st2tests/st2tests/mocks/runners/polling_async_runner.py
@@ -14,15 +14,16 @@
# limitations under the License.
from __future__ import absolute_import
+
try:
import simplejson as json
except:
import json
from st2common.runners.base import PollingAsyncActionRunner
-from st2common.constants.action import (LIVEACTION_STATUS_RUNNING)
+from st2common.constants.action import LIVEACTION_STATUS_RUNNING
-RAISE_PROPERTY = 'raise'
+RAISE_PROPERTY = "raise"
def get_runner():
@@ -31,7 +32,7 @@ def get_runner():
class PollingAsyncTestRunner(PollingAsyncActionRunner):
def __init__(self):
- super(PollingAsyncTestRunner, self).__init__(runner_id='1')
+ super(PollingAsyncTestRunner, self).__init__(runner_id="1")
self.pre_run_called = False
self.run_called = False
self.post_run_called = False
@@ -43,14 +44,11 @@ def run(self, action_params):
self.run_called = True
result = {}
if self.runner_parameters.get(RAISE_PROPERTY, False):
- raise Exception('Raise required.')
+ raise Exception("Raise required.")
else:
- result = {
- 'ran': True,
- 'action_params': action_params
- }
+ result = {"ran": True, "action_params": action_params}
- return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'})
+ return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"})
def post_run(self, status, result):
self.post_run_called = True
diff --git a/st2tests/st2tests/mocks/runners/runner.py b/st2tests/st2tests/mocks/runners/runner.py
index 40d07516c6..b89b75b712 100644
--- a/st2tests/st2tests/mocks/runners/runner.py
+++ b/st2tests/st2tests/mocks/runners/runner.py
@@ -17,12 +17,9 @@
import json
from st2common.runners.base import ActionRunner
-from st2common.constants.action import (LIVEACTION_STATUS_SUCCEEDED)
+from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED
-__all__ = [
- 'get_runner',
- 'MockActionRunner'
-]
+__all__ = ["get_runner", "MockActionRunner"]
def get_runner(config=None):
@@ -31,7 +28,7 @@ def get_runner(config=None):
class MockActionRunner(ActionRunner):
def __init__(self):
- super(MockActionRunner, self).__init__(runner_id='1')
+ super(MockActionRunner, self).__init__(runner_id="1")
self.pre_run_called = False
self.run_called = False
@@ -45,22 +42,15 @@ def run(self, action_params):
self.run_called = True
result = {}
- if self.runner_parameters.get('raise', False):
- raise Exception('Raise required.')
+ if self.runner_parameters.get("raise", False):
+ raise Exception("Raise required.")
- default_result = {
- 'ran': True,
- 'action_params': action_params
- }
- default_context = {
- 'third_party_system': {
- 'ref_id': '1234'
- }
- }
+ default_result = {"ran": True, "action_params": action_params}
+ default_context = {"third_party_system": {"ref_id": "1234"}}
- status = self.runner_parameters.get('mock_status', LIVEACTION_STATUS_SUCCEEDED)
- result = self.runner_parameters.get('mock_result', default_result)
- context = self.runner_parameters.get('mock_context', default_context)
+ status = self.runner_parameters.get("mock_status", LIVEACTION_STATUS_SUCCEEDED)
+ result = self.runner_parameters.get("mock_result", default_result)
+ context = self.runner_parameters.get("mock_context", default_context)
return (status, json.dumps(result), context)
diff --git a/st2tests/st2tests/mocks/sensor.py b/st2tests/st2tests/mocks/sensor.py
index 1f06787b14..c65825786c 100644
--- a/st2tests/st2tests/mocks/sensor.py
+++ b/st2tests/st2tests/mocks/sensor.py
@@ -27,10 +27,7 @@
from st2reactor.container.sensor_wrapper import SensorService
from st2tests.mocks.datastore import MockDatastoreService
-__all__ = [
- 'MockSensorWrapper',
- 'MockSensorService'
-]
+__all__ = ["MockSensorWrapper", "MockSensorService"]
class MockSensorWrapper(object):
@@ -54,9 +51,11 @@ def __init__(self, sensor_wrapper):
# Holds a list of triggers which were dispatched
self.dispatched_triggers = []
- self._datastore_service = MockDatastoreService(logger=self._logger,
- pack_name=self._sensor_wrapper._pack,
- class_name=self._sensor_wrapper._class_name)
+ self._datastore_service = MockDatastoreService(
+ logger=self._logger,
+ pack_name=self._sensor_wrapper._pack,
+ class_name=self._sensor_wrapper._class_name,
+ )
@property
def datastore_service(self):
@@ -74,14 +73,11 @@ def get_logger(self, name):
def dispatch(self, trigger, payload=None, trace_tag=None):
trace_context = TraceContext(trace_tag=trace_tag) if trace_tag else None
- return self.dispatch_with_context(trigger=trigger, payload=payload,
- trace_context=trace_context)
+ return self.dispatch_with_context(
+ trigger=trigger, payload=payload, trace_context=trace_context
+ )
def dispatch_with_context(self, trigger, payload=None, trace_context=None):
- item = {
- 'trigger': trigger,
- 'payload': payload,
- 'trace_context': trace_context
- }
+ item = {"trigger": trigger, "payload": payload, "trace_context": trace_context}
self.dispatched_triggers.append(item)
return item
diff --git a/st2tests/st2tests/mocks/workflow.py b/st2tests/st2tests/mocks/workflow.py
index ef50b66389..051bf5cb83 100644
--- a/st2tests/st2tests/mocks/workflow.py
+++ b/st2tests/st2tests/mocks/workflow.py
@@ -23,13 +23,10 @@
from st2common.models.db import workflow as wf_ex_db
-__all__ = [
- 'MockWorkflowExecutionPublisher'
-]
+__all__ = ["MockWorkflowExecutionPublisher"]
class MockWorkflowExecutionPublisher(object):
-
@classmethod
def publish_create(cls, payload):
try:
diff --git a/st2tests/st2tests/pack_resource.py b/st2tests/st2tests/pack_resource.py
index 7d51d74219..51f5899218 100644
--- a/st2tests/st2tests/pack_resource.py
+++ b/st2tests/st2tests/pack_resource.py
@@ -19,9 +19,7 @@
from unittest2 import TestCase
-__all__ = [
- 'BasePackResourceTestCase'
-]
+__all__ = ["BasePackResourceTestCase"]
class BasePackResourceTestCase(TestCase):
@@ -39,16 +37,16 @@ def get_fixture_content(self, fixture_path):
:type fixture_path: ``str``
"""
base_pack_path = self._get_base_pack_path()
- fixtures_path = os.path.join(base_pack_path, 'tests/fixtures/')
+ fixtures_path = os.path.join(base_pack_path, "tests/fixtures/")
fixture_path = os.path.join(fixtures_path, fixture_path)
- with open(fixture_path, 'r') as fp:
+ with open(fixture_path, "r") as fp:
content = fp.read()
return content
def _get_base_pack_path(self):
test_file_path = inspect.getfile(self.__class__)
- base_pack_path = os.path.join(os.path.dirname(test_file_path), '..')
+ base_pack_path = os.path.join(os.path.dirname(test_file_path), "..")
base_pack_path = os.path.abspath(base_pack_path)
return base_pack_path
diff --git a/st2tests/st2tests/policies/concurrency.py b/st2tests/st2tests/policies/concurrency.py
index e7494a134b..ecd6ffb51a 100644
--- a/st2tests/st2tests/policies/concurrency.py
+++ b/st2tests/st2tests/policies/concurrency.py
@@ -20,11 +20,12 @@
class FakeConcurrencyApplicator(BaseConcurrencyApplicator):
-
def __init__(self, policy_ref, policy_type, *args, **kwargs):
- super(FakeConcurrencyApplicator, self).__init__(policy_ref=policy_ref,
- policy_type=policy_type,
- threshold=kwargs.get('threshold', 0))
+ super(FakeConcurrencyApplicator, self).__init__(
+ policy_ref=policy_ref,
+ policy_type=policy_type,
+ threshold=kwargs.get("threshold", 0),
+ )
def get_threshold(self):
return self.threshold
@@ -35,7 +36,8 @@ def apply_before(self, target):
target = action_utils.update_liveaction_status(
status=action_constants.LIVEACTION_STATUS_CANCELED,
liveaction_id=target.id,
- publish=False)
+ publish=False,
+ )
return target
diff --git a/st2tests/st2tests/policies/mock_exception.py b/st2tests/st2tests/policies/mock_exception.py
index 298a8cb7bb..673eccbb54 100644
--- a/st2tests/st2tests/policies/mock_exception.py
+++ b/st2tests/st2tests/policies/mock_exception.py
@@ -18,9 +18,8 @@
class RaiseExceptionApplicator(base.ResourcePolicyApplicator):
-
def apply_before(self, target):
- raise Exception('For honor!!!!')
+ raise Exception("For honor!!!!")
def apply_after(self, target):
return target
diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py b/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py
index 4994db82b3..6a124573e9 100644
--- a/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py
+++ b/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py
@@ -18,4 +18,4 @@
class Echoer(Action):
def run(self, action_input):
- return {'action_input': action_input}
+ return {"action_input": action_input}
diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py b/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py
index 926a56c73f..2597f811ad 100644
--- a/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py
+++ b/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py
@@ -18,14 +18,10 @@
class Test(object):
- foo = 'bar'
+ foo = "bar"
class NonSimpleTypeAction(Action):
def run(self):
- result = [
- {'a': '1'},
- {'c': 2, 'h': 3},
- {'e': Test()}
- ]
+ result = [{"a": "1"}, {"c": 2, "h": 3}, {"e": Test()}]
return result
diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py b/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py
index 3034e0352a..cacb89d005 100644
--- a/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py
+++ b/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py
@@ -30,35 +30,39 @@ def run(self, **kwargs):
except Exception:
pass
- self.logger.info('test info log message')
- self.logger.debug('test debug log message')
- self.logger.error('test error log message')
+ self.logger.info("test info log message")
+ self.logger.debug("test debug log message")
+ self.logger.error("test error log message")
return PascalRowAction._compute_pascal_row(**kwargs)
@staticmethod
def _compute_pascal_row(row_index=0):
- print('Pascal row action')
+ print("Pascal row action")
- if row_index == 'a':
- return False, 'This is supposed to fail, don\'t worry!!'
- elif row_index == 'b':
+ if row_index == "a":
+ return False, "This is suppose to fail don't worry!!"
+ elif row_index == "b":
return None
- elif row_index == 'complex_type':
+ elif row_index == "complex_type":
result = PascalRowAction()
return (False, result)
- elif row_index == 'c':
+ elif row_index == "c":
return False, None
- elif row_index == 'd':
- return 'succeeded', [1, 2, 3, 4]
- elif row_index == 'e':
+ elif row_index == "d":
+ return "succeeded", [1, 2, 3, 4]
+ elif row_index == "e":
return [1, 2]
elif row_index == 5:
- return [math.factorial(row_index) /
- (math.factorial(i) * math.factorial(row_index - i))
- for i in range(row_index + 1)]
- elif row_index == 'f':
- raise ValueError('Duplicate traceback test')
+ return [
+ math.factorial(row_index)
+ / (math.factorial(i) * math.factorial(row_index - i))
+ for i in range(row_index + 1)
+ ]
+ elif row_index == "f":
+ raise ValueError("Duplicate traceback test")
else:
- return True, [math.factorial(row_index) /
- (math.factorial(i) * math.factorial(row_index - i))
- for i in range(row_index + 1)]
+ return True, [
+ math.factorial(row_index)
+ / (math.factorial(i) * math.factorial(row_index - i))
+ for i in range(row_index + 1)
+ ]
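
The happy-path branches above compute binomial coefficients, C(n, i) = n! / (i! * (n - i)!). A quick worked check of the reformatted comprehension:

    import math

    def pascal_row(n):
        # C(n, i) = n! / (i! * (n - i)!) for i in 0..n
        return [
            math.factorial(n) / (math.factorial(i) * math.factorial(n - i))
            for i in range(n + 1)
        ]

    assert pascal_row(5) == [1.0, 5.0, 10.0, 10.0, 5.0, 1.0]
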
diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py b/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py
index f1f888f069..0bf6856145 100644
--- a/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py
+++ b/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py
@@ -22,5 +22,5 @@ class PrintConfigItemAction(Action):
def run(self):
print(self.config)
# Verify .get() still works
- print(self.config.get('item1', 'default_value'))
- print(self.config['key'])
+ print(self.config.get("item1", "default_value"))
+ print(self.config["key"])
diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py b/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py
index 06c0d2f30a..9838e5bfb6 100644
--- a/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py
+++ b/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py
@@ -24,7 +24,7 @@
class PrintToStdoutAndStderrAction(Action):
def run(self, stdout_count=3, stderr_count=3):
for index in range(0, stdout_count):
- sys.stdout.write('stdout line %s\n' % (index))
+ sys.stdout.write("stdout line %s\n" % (index))
for index in range(0, stderr_count):
- sys.stderr.write('stderr line %s\n' % (index))
+ sys.stderr.write("stderr line %s\n" % (index))
diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py b/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py
index 717549347b..ffe7b69b3b 100644
--- a/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py
+++ b/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py
@@ -22,5 +22,5 @@
class PythonPathsAction(Action):
def run(self):
- print('sys.path: %s' % (sys.path))
- print('PYTHONPATH: %s' % (os.environ.get('PYTHONPATH')))
+ print("sys.path: %s" % (sys.path))
+ print("PYTHONPATH: %s" % (os.environ.get("PYTHONPATH")))
diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/test.py b/st2tests/st2tests/resources/packs/pythonactions/actions/test.py
index d95939990f..eeed54fbb0 100644
--- a/st2tests/st2tests/resources/packs/pythonactions/actions/test.py
+++ b/st2tests/st2tests/resources/packs/pythonactions/actions/test.py
@@ -22,4 +22,4 @@
class TestAction(Action):
def run(self):
- return 'test action'
+ return "test action"
diff --git a/st2tests/st2tests/sensors.py b/st2tests/st2tests/sensors.py
index 0b6f31e6b6..52c0451f45 100644
--- a/st2tests/st2tests/sensors.py
+++ b/st2tests/st2tests/sensors.py
@@ -18,9 +18,7 @@
from st2tests.mocks.sensor import MockSensorService
from st2tests.pack_resource import BasePackResourceTestCase
-__all__ = [
- 'BaseSensorTestCase'
-]
+__all__ = ["BaseSensorTestCase"]
class BaseSensorTestCase(BasePackResourceTestCase):
@@ -37,22 +35,20 @@ def setUp(self):
super(BaseSensorTestCase, self).setUp()
class_name = self.sensor_cls.__name__
- sensor_wrapper = MockSensorWrapper(pack='tests', class_name=class_name)
+ sensor_wrapper = MockSensorWrapper(pack="tests", class_name=class_name)
self.sensor_service = MockSensorService(sensor_wrapper=sensor_wrapper)
def get_sensor_instance(self, config=None, poll_interval=None):
"""
Retrieve instance of the sensor class.
"""
- kwargs = {
- 'sensor_service': self.sensor_service
- }
+ kwargs = {"sensor_service": self.sensor_service}
if config:
- kwargs['config'] = config
+ kwargs["config"] = config
if poll_interval is not None:
- kwargs['poll_interval'] = poll_interval
+ kwargs["poll_interval"] = poll_interval
instance = self.sensor_cls(**kwargs) # pylint: disable=not-callable
return instance
@@ -79,15 +75,15 @@ def assertTriggerDispatched(self, trigger, payload=None, trace_context=None):
"""
dispatched_triggers = self.get_dispatched_triggers()
for item in dispatched_triggers:
- trigger_matches = (item['trigger'] == trigger)
+ trigger_matches = item["trigger"] == trigger
if payload:
- payload_matches = (item['payload'] == payload)
+ payload_matches = item["payload"] == payload
else:
payload_matches = True
if trace_context:
- trace_context_matches = (item['trace_context'] == trace_context)
+ trace_context_matches = item["trace_context"] == trace_context
else:
trace_context_matches = True
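
A pack test built on BaseSensorTestCase typically pairs get_sensor_instance with the assertTriggerDispatched helper shown above; the sensor class, config, and trigger reference below are hypothetical:

    from st2tests.sensors import BaseSensorTestCase

    from my_pack.sensors.file_watch_sensor import FileWatchSensor  # hypothetical

    class FileWatchSensorTestCase(BaseSensorTestCase):
        sensor_cls = FileWatchSensor

        def test_poll_dispatches_trigger(self):
            sensor = self.get_sensor_instance(config={"path": "/tmp/watched"})
            sensor.poll()  # assumes the sensor dispatches a trigger on poll
            self.assertTriggerDispatched(trigger="my_pack.file_changed")
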
diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml
index ac38037d6c..06abc65227 100644
--- a/st2tests/testpacks/checks/actions/check_loadavg.yaml
+++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml
@@ -4,8 +4,8 @@
description: "Check CPU Load Average on a Host"
enabled: true
entry_point: "checks/check_loadavg.py"
- parameters:
- period:
+ parameters:
+ period:
type: "string"
description: "Time period for load avg: 5,10,15 minutes, or 'all'"
default: "all"
diff --git a/st2tests/testpacks/checks/actions/checks/check_loadavg.py b/st2tests/testpacks/checks/actions/checks/check_loadavg.py
index 4a56834832..9439679df3 100755
--- a/st2tests/testpacks/checks/actions/checks/check_loadavg.py
+++ b/st2tests/testpacks/checks/actions/checks/check_loadavg.py
@@ -23,40 +23,40 @@
def print_load_avg(args):
period = args[1]
- loadavg_file = '/proc/loadavg'
- cpuinfo_file = '/proc/cpuinfo'
+ loadavg_file = "/proc/loadavg"
+ cpuinfo_file = "/proc/cpuinfo"
cpus = 0
try:
- fh = open(loadavg_file, 'r')
+ fh = open(loadavg_file, "r")
load = fh.readline().split()[0:3]
fh.close()
except:
- sys.stderr.write('Error opening %s\n' % loadavg_file)
+ sys.stderr.write("Error opening %s\n" % loadavg_file)
sys.exit(2)
try:
- fh = open(cpuinfo_file, 'r')
+ fh = open(cpuinfo_file, "r")
for line in fh:
- if 'processor' in line:
+ if "processor" in line:
cpus += 1
fh.close()
except:
- sys.stderr.write('Error opening %s\n' % cpuinfo_file)
+ sys.stderr.write("Error opening %s\n" % cpuinfo_file)
- one_min = '1 min load/core: %s' % str(float(load[0]) / cpus)
- five_min = '5 min load/core: %s' % str(float(load[1]) / cpus)
- fifteen_min = '15 min load/core: %s' % str(float(load[2]) / cpus)
+ one_min = "1 min load/core: %s" % str(float(load[0]) / cpus)
+ five_min = "5 min load/core: %s" % str(float(load[1]) / cpus)
+ fifteen_min = "15 min load/core: %s" % str(float(load[2]) / cpus)
- if period == '1' or period == 'one':
+ if period == "1" or period == "one":
print(one_min)
- elif period == '5' or period == 'five':
+ elif period == "5" or period == "five":
print(five_min)
- elif period == '15' or period == 'fifteen':
+ elif period == "15" or period == "fifteen":
print(fifteen_min)
else:
print(one_min + " " + five_min + " " + fifteen_min)
-if __name__ == '__main__':
+if __name__ == "__main__":
print_load_avg(sys.argv)
diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh
index 5320dc2f36..2e6eadf6a2 100755
--- a/st2tests/testpacks/errorcheck/actions/exit-code.sh
+++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh
@@ -6,4 +6,4 @@ if [ -n "$1" ]
exit_code=$1
fi
-exit $exit_code
+exit $exit_code
diff --git a/test-requirements.txt b/test-requirements.txt
index 6ca0e9608d..c004342bc8 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -5,6 +5,8 @@ st2flake8==0.1.0
astroid==2.4.2
pylint==2.6.0
pylint-plugin-utils>=0.4
+black==20.8b1
+pre-commit==2.1.0
bandit==1.5.1
ipython<6.0.0
isort>=4.2.5
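
Since black (pinned above) produces nearly all of the churn in this patch, here is a minimal before/after illustration of its most visible effects, assuming black 20.8b1 defaults:

    # Before black:
    #     __all__ = [
    #         'MockActionWrapper',
    #         'MockActionService'
    #     ]
    #
    # After black: quotes are normalized to double quotes and short literals
    # are collapsed onto one line when they fit within the line-length limit.
    __all__ = ["MockActionWrapper", "MockActionService"]
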
diff --git a/tools/config_gen.py b/tools/config_gen.py
index e705161ea3..309bdf608f 100755
--- a/tools/config_gen.py
+++ b/tools/config_gen.py
@@ -24,57 +24,57 @@
from oslo_config import cfg
-CONFIGS = ['st2actions.config',
- 'st2actions.scheduler.config',
- 'st2actions.notifier.config',
- 'st2actions.workflows.config',
- 'st2api.config',
- 'st2stream.config',
- 'st2auth.config',
- 'st2common.config',
- 'st2exporter.config',
- 'st2reactor.rules.config',
- 'st2reactor.sensor.config',
- 'st2reactor.timer.config',
- 'st2reactor.garbage_collector.config']
-
-SKIP_GROUPS = ['api_pecan', 'rbac', 'results_tracker']
+CONFIGS = [
+ "st2actions.config",
+ "st2actions.scheduler.config",
+ "st2actions.notifier.config",
+ "st2actions.workflows.config",
+ "st2api.config",
+ "st2stream.config",
+ "st2auth.config",
+ "st2common.config",
+ "st2exporter.config",
+ "st2reactor.rules.config",
+ "st2reactor.sensor.config",
+ "st2reactor.timer.config",
+ "st2reactor.garbage_collector.config",
+]
+
+SKIP_GROUPS = ["api_pecan", "rbac", "results_tracker"]
# We group auth options together to make it a bit more clear what applies where
AUTH_OPTIONS = {
- 'common': [
- 'enable',
- 'mode',
- 'logging',
- 'api_url',
- 'token_ttl',
- 'service_token_ttl',
- 'sso',
- 'sso_backend',
- 'sso_backend_kwargs',
- 'debug'
+ "common": [
+ "enable",
+ "mode",
+ "logging",
+ "api_url",
+ "token_ttl",
+ "service_token_ttl",
+ "sso",
+ "sso_backend",
+ "sso_backend_kwargs",
+ "debug",
+ ],
+ "standalone": [
+ "host",
+ "port",
+ "use_ssl",
+ "cert",
+ "key",
+ "backend",
+ "backend_kwargs",
],
- 'standalone': [
- 'host',
- 'port',
- 'use_ssl',
- 'cert',
- 'key',
- 'backend',
- 'backend_kwargs'
- ]
}
# Some of the config values change depending on the environment where this script is ran so we
# set them to static values to ensure consistent and stable output
STATIC_OPTION_VALUES = {
- 'actionrunner': {
- 'virtualenv_binary': '/usr/bin/virtualenv',
- 'python_binary': '/usr/bin/python',
+ "actionrunner": {
+ "virtualenv_binary": "/usr/bin/virtualenv",
+ "python_binary": "/usr/bin/python",
},
- 'webui': {
- 'webui_base_url': 'https://localhost'
- }
+ "webui": {"webui_base_url": "https://localhost"},
}
COMMON_AUTH_OPTIONS_COMMENT = """
@@ -112,22 +112,28 @@ def _clear_config():
def _read_group(opt_group):
all_options = list(opt_group._opts.values())
- if opt_group.name == 'auth':
+ if opt_group.name == "auth":
print(COMMON_AUTH_OPTIONS_COMMENT)
- print('')
- common_options = [option for option in all_options if option['opt'].name in
- AUTH_OPTIONS['common']]
+ print("")
+ common_options = [
+ option
+ for option in all_options
+ if option["opt"].name in AUTH_OPTIONS["common"]
+ ]
_print_options(opt_group=opt_group, options=common_options)
- print('')
+ print("")
print(STANDALONE_AUTH_OPTIONS_COMMENT)
- print('')
- standalone_options = [option for option in all_options if option['opt'].name in
- AUTH_OPTIONS['standalone']]
+ print("")
+ standalone_options = [
+ option
+ for option in all_options
+ if option["opt"].name in AUTH_OPTIONS["standalone"]
+ ]
_print_options(opt_group=opt_group, options=standalone_options)
if len(common_options) + len(standalone_options) != len(all_options):
- msg = ('Not all options are declared in AUTH_OPTIONS dict, please update it')
+ msg = "Not all options are declared in AUTH_OPTIONS dict, please update it"
raise Exception(msg)
else:
options = all_options
@@ -137,33 +143,35 @@ def _read_group(opt_group):
def _read_groups(opt_groups):
opt_groups = collections.OrderedDict(sorted(opt_groups.items()))
for name, opt_group in six.iteritems(opt_groups):
- print('[%s]' % name)
+ print("[%s]" % name)
_read_group(opt_group)
- print('')
+ print("")
def _print_options(opt_group, options):
- for opt in sorted(options, key=lambda x: x['opt'].name):
- opt = opt['opt']
+ for opt in sorted(options, key=lambda x: x["opt"].name):
+ opt = opt["opt"]
# Special case for options which could change during this script run
- static_option_value = STATIC_OPTION_VALUES.get(opt_group.name, {}).get(opt.name, None)
+ static_option_value = STATIC_OPTION_VALUES.get(opt_group.name, {}).get(
+ opt.name, None
+ )
if static_option_value:
opt.default = static_option_value
# Special handling for list options
if isinstance(opt, cfg.ListOpt):
if opt.default:
- value = ','.join(opt.default)
+ value = ",".join(opt.default)
else:
- value = ''
+ value = ""
- value += ' # comma separated list allowed here.'
+ value += " # comma separated list allowed here."
else:
value = opt.default
- print('# %s' % opt.help)
- print('%s = %s' % (opt.name, value))
+ print(("# %s" % opt.help).strip())
+ print(("%s = %s" % (opt.name, value)).strip())
def main(args):
@@ -176,5 +184,5 @@ def main(args):
_read_groups(opt_groups)
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv)
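
One functional tweak hides in the config_gen.py reformat above: the generated sample config now strips trailing whitespace from each emitted line, so help strings that end in a space no longer produce trailing blanks. A tiny illustration:

    opt_help = "comma separated list allowed here. "  # note trailing space
    line = ("# %s" % opt_help).strip()
    assert line == "# comma separated list allowed here."
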
diff --git a/tools/diff-db-disk.py b/tools/diff-db-disk.py
index ec09a76709..a9e65d72ea 100755
--- a/tools/diff-db-disk.py
+++ b/tools/diff-db-disk.py
@@ -47,20 +47,20 @@
from st2common.persistence.action import Action
registrar = ResourceRegistrar()
-registrar.ALLOWED_EXTENSIONS = ['.yaml', '.yml', '.json']
+registrar.ALLOWED_EXTENSIONS = [".yaml", ".yml", ".json"]
meta_loader = MetaLoader()
API_MODELS_ARTIFACT_TYPES = {
- 'actions': ActionAPI,
- 'sensors': SensorTypeAPI,
- 'rules': RuleAPI
+ "actions": ActionAPI,
+ "sensors": SensorTypeAPI,
+ "rules": RuleAPI,
}
API_MODELS_PERSISTENT_MODELS = {
Action: ActionAPI,
SensorType: SensorTypeAPI,
- Rule: RuleAPI
+ Rule: RuleAPI,
}
@@ -77,13 +77,15 @@ def _get_api_models_from_db(persistence_model, pack_dir=None):
filters = {}
if pack_dir:
pack_name = os.path.basename(os.path.normpath(pack_dir))
- filters = {'pack': pack_name}
+ filters = {"pack": pack_name}
models = persistence_model.query(**filters)
models_dict = {}
for model in models:
- model_pack = getattr(model, 'pack', None) or DEFAULT_PACK_NAME
- model_ref = ResourceReference.to_string_reference(name=model.name, pack=model_pack)
- if getattr(model, 'id', None):
+ model_pack = getattr(model, "pack", None) or DEFAULT_PACK_NAME
+ model_ref = ResourceReference.to_string_reference(
+ name=model.name, pack=model_pack
+ )
+ if getattr(model, "id", None):
del model.id
API_MODEL = API_MODELS_PERSISTENT_MODELS[persistence_model]
models_dict[model_ref] = API_MODEL.from_model(model)
@@ -107,15 +109,14 @@ def _get_api_models_from_disk(artifact_type, pack_dir=None):
artifacts_paths = registrar.get_resources_from_pack(pack_path)
for artifact_path in artifacts_paths:
artifact = meta_loader.load(artifact_path)
- if artifact_type == 'sensors':
+ if artifact_type == "sensors":
sensors_dir = os.path.dirname(artifact_path)
- sensor_file_path = os.path.join(sensors_dir, artifact['entry_point'])
- artifact['artifact_uri'] = 'file://' + sensor_file_path
- name = artifact.get('name', None) or artifact.get('class_name', None)
- if not artifact.get('pack', None):
- artifact['pack'] = pack_name
- ref = ResourceReference.to_string_reference(name=name,
- pack=pack_name)
+ sensor_file_path = os.path.join(sensors_dir, artifact["entry_point"])
+ artifact["artifact_uri"] = "file://" + sensor_file_path
+ name = artifact.get("name", None) or artifact.get("class_name", None)
+ if not artifact.get("pack", None):
+ artifact["pack"] = pack_name
+ ref = ResourceReference.to_string_reference(name=name, pack=pack_name)
API_MODEL = API_MODELS_ARTIFACT_TYPES[artifact_type]
# Following conversions are required because we add some fields with
# default values in db model. If we don't do these conversions,
@@ -128,42 +129,49 @@ def _get_api_models_from_disk(artifact_type, pack_dir=None):
return artifacts_dict
-def _content_diff(artifact_type=None, artifact_in_disk=None, artifact_in_db=None,
- verbose=False):
+def _content_diff(
+ artifact_type=None, artifact_in_disk=None, artifact_in_db=None, verbose=False
+):
artifact_in_disk_str = json.dumps(
- artifact_in_disk.__json__(), sort_keys=True,
- indent=4, separators=(',', ': ')
+ artifact_in_disk.__json__(), sort_keys=True, indent=4, separators=(",", ": ")
)
artifact_in_db_str = json.dumps(
- artifact_in_db.__json__(), sort_keys=True,
- indent=4, separators=(',', ': ')
+ artifact_in_db.__json__(), sort_keys=True, indent=4, separators=(",", ": ")
+ )
+ diffs = difflib.context_diff(
+ artifact_in_db_str.splitlines(),
+ artifact_in_disk_str.splitlines(),
+ fromfile="DB contents",
+ tofile="Disk contents",
)
- diffs = difflib.context_diff(artifact_in_db_str.splitlines(),
- artifact_in_disk_str.splitlines(),
- fromfile='DB contents', tofile='Disk contents')
printed = False
for diff in diffs:
if not printed:
- identifier = getattr(artifact_in_db, 'ref', getattr(artifact_in_db, 'name'))
- print('%s %s in db differs from what is in disk.' % (artifact_type.upper(),
- identifier))
+ identifier = getattr(artifact_in_db, "ref", getattr(artifact_in_db, "name"))
+ print(
+ "%s %s in db differs from what is in disk."
+ % (artifact_type.upper(), identifier)
+ )
printed = True
print(diff)
if verbose:
- print('\n\nOriginal contents:')
- print('===================\n')
- print('Artifact in db:\n\n%s\n\n' % artifact_in_db_str)
- print('Artifact in disk:\n\n%s\n\n' % artifact_in_disk_str)
+ print("\n\nOriginal contents:")
+ print("===================\n")
+ print("Artifact in db:\n\n%s\n\n" % artifact_in_db_str)
+ print("Artifact in disk:\n\n%s\n\n" % artifact_in_disk_str)
-def _diff(persistence_model, artifact_type, pack_dir=None, verbose=True,
- content_diff=True):
+def _diff(
+ persistence_model, artifact_type, pack_dir=None, verbose=True, content_diff=True
+):
artifacts_in_db_dict = _get_api_models_from_db(persistence_model, pack_dir=pack_dir)
artifacts_in_disk_dict = _get_api_models_from_disk(artifact_type, pack_dir=pack_dir)
# print(artifacts_in_disk_dict)
- all_artifacts = set(list(artifacts_in_db_dict.keys()) + list(artifacts_in_disk_dict.keys()))
+ all_artifacts = set(
+ list(artifacts_in_db_dict.keys()) + list(artifacts_in_disk_dict.keys())
+ )
for artifact in all_artifacts:
artifact_in_db = artifacts_in_db_dict.get(artifact, None)
@@ -172,76 +180,96 @@ def _diff(persistence_model, artifact_type, pack_dir=None, verbose=True,
artifact_in_db_pretty_json = None
if verbose:
- print('******************************************************************************')
- print('Checking if artifact %s is present in both disk and db.' % artifact)
+ print(
+ "******************************************************************************"
+ )
+ print("Checking if artifact %s is present in both disk and db." % artifact)
if not artifact_in_db:
- print('##############################################################################')
- print('%s %s in disk not available in db.' % (artifact_type.upper(), artifact))
+ print(
+ "##############################################################################"
+ )
+ print(
+ "%s %s in disk not available in db." % (artifact_type.upper(), artifact)
+ )
artifact_in_disk_pretty_json = json.dumps(
- artifact_in_disk.__json__(), sort_keys=True,
- indent=4, separators=(',', ': ')
+ artifact_in_disk.__json__(),
+ sort_keys=True,
+ indent=4,
+ separators=(",", ": "),
)
if verbose:
- print('File contents: \n')
+ print("File contents: \n")
print(artifact_in_disk_pretty_json)
continue
if not artifact_in_disk:
- print('##############################################################################')
- print('%s %s in db not available in disk.' % (artifact_type.upper(), artifact))
+ print(
+ "##############################################################################"
+ )
+ print(
+ "%s %s in db not available in disk." % (artifact_type.upper(), artifact)
+ )
artifact_in_db_pretty_json = json.dumps(
- artifact_in_db.__json__(), sort_keys=True,
- indent=4, separators=(',', ': ')
+ artifact_in_db.__json__(),
+ sort_keys=True,
+ indent=4,
+ separators=(",", ": "),
)
if verbose:
- print('DB contents: \n')
+ print("DB contents: \n")
print(artifact_in_db_pretty_json)
continue
if verbose:
- print('Artifact %s exists in both disk and db.' % artifact)
+ print("Artifact %s exists in both disk and db." % artifact)
if content_diff:
if verbose:
- print('Performing content diff for artifact %s.' % artifact)
+ print("Performing content diff for artifact %s." % artifact)
- _content_diff(artifact_type=artifact_type,
- artifact_in_disk=artifact_in_disk,
- artifact_in_db=artifact_in_db,
- verbose=verbose)
+ _content_diff(
+ artifact_type=artifact_type,
+ artifact_in_disk=artifact_in_disk,
+ artifact_in_db=artifact_in_db,
+ verbose=verbose,
+ )
def _diff_actions(pack_dir=None, verbose=False, content_diff=True):
- _diff(Action, 'actions', pack_dir=pack_dir,
- verbose=verbose, content_diff=content_diff)
+ _diff(
+ Action, "actions", pack_dir=pack_dir, verbose=verbose, content_diff=content_diff
+ )
def _diff_sensors(pack_dir=None, verbose=False, content_diff=True):
- _diff(SensorType, 'sensors', pack_dir=pack_dir,
- verbose=verbose, content_diff=content_diff)
+ _diff(
+ SensorType,
+ "sensors",
+ pack_dir=pack_dir,
+ verbose=verbose,
+ content_diff=content_diff,
+ )
def _diff_rules(pack_dir=None, verbose=True, content_diff=True):
- _diff(Rule, 'rules', pack_dir=pack_dir,
- verbose=verbose, content_diff=content_diff)
+ _diff(Rule, "rules", pack_dir=pack_dir, verbose=verbose, content_diff=content_diff)
def main():
monkey_patch()
cli_opts = [
- cfg.BoolOpt('sensors', default=False,
- help='diff sensor alone.'),
- cfg.BoolOpt('actions', default=False,
- help='diff actions alone.'),
- cfg.BoolOpt('rules', default=False,
- help='diff rules alone.'),
- cfg.BoolOpt('all', default=False,
- help='diff sensors, actions and rules.'),
- cfg.BoolOpt('verbose', default=False),
- cfg.BoolOpt('simple', default=False,
- help='In simple mode, tool only tells you if content is missing.' +
- 'It doesn\'t show you content diff between disk and db.'),
- cfg.StrOpt('pack-dir', default=None, help='Path to specific pack to diff.')
+ cfg.BoolOpt("sensors", default=False, help="diff sensor alone."),
+ cfg.BoolOpt("actions", default=False, help="diff actions alone."),
+ cfg.BoolOpt("rules", default=False, help="diff rules alone."),
+ cfg.BoolOpt("all", default=False, help="diff sensors, actions and rules."),
+ cfg.BoolOpt("verbose", default=False),
+ cfg.BoolOpt(
+ "simple",
+ default=False,
+ help="In simple mode, tool only tells you if content is missing."
+ + "It doesn't show you content diff between disk and db.",
+ ),
+ cfg.StrOpt("pack-dir", default=None, help="Path to specific pack to diff."),
]
do_register_cli_opts(cli_opts)
config.parse_args()
@@ -254,23 +282,35 @@ def main():
content_diff = not cfg.CONF.simple
if cfg.CONF.all:
- _diff_sensors(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff)
- _diff_actions(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff)
- _diff_rules(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff)
+ _diff_sensors(
+ pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff
+ )
+ _diff_actions(
+ pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff
+ )
+ _diff_rules(
+ pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff
+ )
return
if cfg.CONF.sensors:
- _diff_sensors(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff)
+ _diff_sensors(
+ pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff
+ )
if cfg.CONF.actions:
- _diff_actions(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff)
+ _diff_actions(
+ pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff
+ )
if cfg.CONF.rules:
- _diff_rules(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff)
+ _diff_rules(
+ pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff
+ )
# Disconnect from db.
db_teardown()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
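
The reworked _content_diff boils down to one reusable idea: serialize both sides to stable JSON (sorted keys, fixed separators) so the diff only reflects real field changes, then feed the lines to difflib.context_diff. A self-contained sketch with invented payloads:

    import difflib
    import json

    def content_diff(in_db, on_disk):
        # Stable serialization so the diff only shows genuine field changes.
        db_str = json.dumps(in_db, sort_keys=True, indent=4, separators=(",", ": "))
        disk_str = json.dumps(on_disk, sort_keys=True, indent=4, separators=(",", ": "))
        for line in difflib.context_diff(
            db_str.splitlines(),
            disk_str.splitlines(),
            fromfile="DB contents",
            tofile="Disk contents",
        ):
            print(line)

    content_diff({"name": "run", "enabled": True}, {"name": "run", "enabled": False})
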
diff --git a/tools/direct_queue_publisher.py b/tools/direct_queue_publisher.py
index bc01242085..0da7dd0b08 100755
--- a/tools/direct_queue_publisher.py
+++ b/tools/direct_queue_publisher.py
@@ -22,26 +22,27 @@
def main(queue, payload):
- connection = pika.BlockingConnection(pika.ConnectionParameters(
- host='localhost',
- credentials=pika.credentials.PlainCredentials(username='guest', password='guest')))
+ connection = pika.BlockingConnection(
+ pika.ConnectionParameters(
+ host="localhost",
+ credentials=pika.credentials.PlainCredentials(
+ username="guest", password="guest"
+ ),
+ )
+ )
channel = connection.channel()
channel.queue_declare(queue=queue, durable=True)
- channel.basic_publish(exchange='',
- routing_key=queue,
- body=payload)
+ channel.basic_publish(exchange="", routing_key=queue, body=payload)
print("Sent %s" % payload)
connection.close()
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='Direct queue publisher')
- parser.add_argument('--queue', required=True,
- help='Routing key to use')
- parser.add_argument('--payload', required=True,
- help='Message payload')
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Direct queue publisher")
+ parser.add_argument("--queue", required=True, help="Routing key to use")
+ parser.add_argument("--payload", required=True, help="Message payload")
args = parser.parse_args()
main(queue=args.queue, payload=args.payload)
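
For reference, the consuming counterpart of this direct-queue publisher looks roughly like the following with pika 1.x; the broker location and the "demo" queue name are assumptions standing in for the script's --queue argument:

    import pika

    connection = pika.BlockingConnection(pika.ConnectionParameters(host="localhost"))
    channel = connection.channel()
    # Must match the publisher's declaration, including durability.
    channel.queue_declare(queue="demo", durable=True)

    def on_message(ch, method, properties, body):
        print("Received %s" % body)
        ch.basic_ack(delivery_tag=method.delivery_tag)

    channel.basic_consume(queue="demo", on_message_callback=on_message)
    channel.start_consuming()
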
diff --git a/tools/enumerate-runners.py b/tools/enumerate-runners.py
index 9610407411..9cae10cd18 100755
--- a/tools/enumerate-runners.py
+++ b/tools/enumerate-runners.py
@@ -20,15 +20,18 @@
from st2common.runners import get_backend_driver
from st2common import config
+
config.parse_args()
runner_names = get_available_backends()
-print('Available / installed action runners:')
+print("Available / installed action runners:")
for name in runner_names:
runner_driver = get_backend_driver(name)
runner_instance = runner_driver.get_runner()
runner_metadata = runner_driver.get_metadata()
- print('- %s (runner_module=%s,cls=%s)' % (name, runner_metadata['runner_module'],
- runner_instance.__class__))
+ print(
+ "- %s (runner_module=%s,cls=%s)"
+ % (name, runner_metadata["runner_module"], runner_instance.__class__)
+ )
diff --git a/tools/json2yaml.py b/tools/json2yaml.py
index 29959949e8..5aecb3711e 100755
--- a/tools/json2yaml.py
+++ b/tools/json2yaml.py
@@ -21,6 +21,7 @@
from __future__ import absolute_import
import argparse
import fnmatch
+
try:
import simplejson as json
except ImportError:
@@ -33,7 +34,7 @@
PRINT = pprint.pprint
-YAML_HEADER = '---'
+YAML_HEADER = "---"
def get_files_matching_pattern(dir_, pattern):
@@ -47,47 +48,47 @@ def get_files_matching_pattern(dir_, pattern):
def json_2_yaml_convert(filename):
data = None
try:
- with open(filename, 'r') as json_file:
+ with open(filename, "r") as json_file:
data = json.load(json_file)
except:
- PRINT('Failed on {}'.format(filename))
+ PRINT("Failed on {}".format(filename))
traceback.print_exc()
- return (filename, '')
- new_filename = os.path.splitext(filename)[0] + '.yaml'
- with open(new_filename, 'w') as yaml_file:
- yaml_file.write(YAML_HEADER + '\n')
+ return (filename, "")
+ new_filename = os.path.splitext(filename)[0] + ".yaml"
+ with open(new_filename, "w") as yaml_file:
+ yaml_file.write(YAML_HEADER + "\n")
yaml_file.write(yaml.safe_dump(data, default_flow_style=False))
return (filename, new_filename)
def git_rm(filename):
try:
- subprocess.check_call(['git', 'rm', filename])
+ subprocess.check_call(["git", "rm", filename])
except subprocess.CalledProcessError:
- PRINT('Failed to git rm {}'.format(filename))
+ PRINT("Failed to git rm {}".format(filename))
traceback.print_exc()
return (False, filename)
return (True, filename)
def main(dir_, skip_convert):
- files = get_files_matching_pattern(dir_, '*.json')
+ files = get_files_matching_pattern(dir_, "*.json")
if skip_convert:
PRINT(files)
return
results = [json_2_yaml_convert(filename) for filename in files]
- PRINT('*** conversion done ***')
- PRINT(['converted {} to {}'.format(result[0], result[1]) for result in results])
+ PRINT("*** conversion done ***")
+ PRINT(["converted {} to {}".format(result[0], result[1]) for result in results])
results = [git_rm(filename) for filename, new_filename in results if new_filename]
- PRINT('*** git rm done ***')
+ PRINT("*** git rm done ***")
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='json2yaml converter.')
- parser.add_argument('--dir', '-d', required=True,
- help='The dir to look for json.')
- parser.add_argument('--skipconvert', '-s', action='store_true',
- help='Skip conversion')
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="json2yaml converter.")
+ parser.add_argument("--dir", "-d", required=True, help="The dir to look for json.")
+ parser.add_argument(
+ "--skipconvert", "-s", action="store_true", help="Skip conversion"
+ )
args = parser.parse_args()
main(dir_=args.dir, skip_convert=args.skipconvert)
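
The hunks above call get_files_matching_pattern without showing its body. One plausible shape for such a helper, sketched for context rather than taken from the file:

    import fnmatch
    import os

    def get_files_matching_pattern(dir_, pattern):
        # Recursively collect files under dir_ whose basename matches pattern.
        matches = []
        for root, _dirnames, filenames in os.walk(dir_):
            for filename in fnmatch.filter(filenames, pattern):
                matches.append(os.path.join(root, filename))
        return matches

    print(get_files_matching_pattern(".", "*.json"))
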
diff --git a/tools/list_group_members.py b/tools/list_group_members.py
index e811eabd00..9cf575b62e 100755
--- a/tools/list_group_members.py
+++ b/tools/list_group_members.py
@@ -31,24 +31,26 @@ def main(group_id=None):
if not group_id:
group_ids = list(coordinator.get_groups().get())
- group_ids = [item.decode('utf-8') for item in group_ids]
+ group_ids = [item.decode("utf-8") for item in group_ids]
- print('Available groups (%s):' % (len(group_ids)))
+ print("Available groups (%s):" % (len(group_ids)))
for group_id in group_ids:
- print(' - %s' % (group_id))
- print('')
+ print(" - %s" % (group_id))
+ print("")
else:
group_ids = [group_id]
for group_id in group_ids:
member_ids = list(coordinator.get_members(group_id).get())
- member_ids = [member_id.decode('utf-8') for member_id in member_ids]
+ member_ids = [member_id.decode("utf-8") for member_id in member_ids]
print('Members in group "%s" (%s):' % (group_id, len(member_ids)))
for member_id in member_ids:
- capabilities = coordinator.get_member_capabilities(group_id, member_id).get()
- print(' - %s (capabilities=%s)' % (member_id, str(capabilities)))
+ capabilities = coordinator.get_member_capabilities(
+ group_id, member_id
+ ).get()
+ print(" - %s (capabilities=%s)" % (member_id, str(capabilities)))
def do_register_cli_opts(opts, ignore_errors=False):
@@ -60,11 +62,13 @@ def do_register_cli_opts(opts, ignore_errors=False):
raise
-if __name__ == '__main__':
+if __name__ == "__main__":
cli_opts = [
- cfg.StrOpt('group-id', default=None,
- help='If provided, only list members for that group.'),
-
+ cfg.StrOpt(
+ "group-id",
+ default=None,
+ help="If provided, only list members for that group.",
+ ),
]
do_register_cli_opts(cli_opts)
config.parse_args()
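
The listing logic above sits on a tooz-style coordinator behind st2's coordination wrapper. An approximation that talks to tooz directly, using the in-memory zake backend so it runs without a real ZooKeeper (backend URL, group and member ids are illustrative):

    from tooz import coordination

    coordinator = coordination.get_coordinator("zake://", b"member-1")
    coordinator.start()
    coordinator.create_group(b"demo").get()
    coordinator.join_group(b"demo").get()

    groups = coordinator.get_groups().get()
    print("Available groups (%s):" % (len(groups)))
    for group_id in groups:
        members = coordinator.get_members(group_id).get()
        print('Members in group "%s" (%s):' % (group_id.decode("utf-8"), len(members)))

    coordinator.stop()
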
diff --git a/tools/log_watcher.py b/tools/log_watcher.py
index cafcb4efec..b16af95cc1 100755
--- a/tools/log_watcher.py
+++ b/tools/log_watcher.py
@@ -27,25 +27,9 @@
LOG_ALERT_PERCENT = 5 # default.
-EVILS = [
- 'info',
- 'debug',
- 'warning',
- 'exception',
- 'error',
- 'audit'
-]
-
-LOG_VARS = [
- 'LOG',
- 'Log',
- 'log',
- 'LOGGER',
- 'Logger',
- 'logger',
- 'logging',
- 'LOGGING'
-]
+EVILS = ["info", "debug", "warning", "exception", "error", "audit"]
+
+LOG_VARS = ["LOG", "Log", "log", "LOGGER", "Logger", "logger", "logging", "LOGGING"]
FILE_LOG_COUNT = collections.defaultdict()
FILE_LINE_COUNT = collections.defaultdict()
@@ -55,25 +39,25 @@ def _parse_args(args):
global LOG_ALERT_PERCENT
params = {}
if len(args) > 1:
- params['alert_percent'] = args[1]
+ params["alert_percent"] = args[1]
LOG_ALERT_PERCENT = int(args[1])
return params
def _skip_file(filename):
- if filename.startswith('.') or filename.startswith('_'):
+ if filename.startswith(".") or filename.startswith("_"):
return True
def _get_files(dir_path):
if not os.path.exists(dir_path):
- print('Directory %s doesn\'t exist.' % dir_path)
+ print("Directory %s doesn't exist." % dir_path)
files = []
- exclude = set(['virtualenv', 'build', '.tox'])
+ exclude = set(["virtualenv", "build", ".tox"])
for root, dirnames, filenames in os.walk(dir_path):
dirnames[:] = [d for d in dirnames if d not in exclude]
- for filename in fnmatch.filter(filenames, '*.py'):
+ for filename in fnmatch.filter(filenames, "*.py"):
if not _skip_file(filename):
files.append(os.path.join(root, filename))
return files
@@ -84,7 +68,7 @@ def _build_regex():
regex_strings = {}
regexes = {}
for level in EVILS:
- regex_string = '|'.join([r'\.'.join([log, level]) for log in LOG_VARS])
+ regex_string = "|".join([r"\.".join([log, level]) for log in LOG_VARS])
regex_strings[level] = regex_string
# print('Level: %s, regex_string: %s' % (level, regex_strings[level]))
regexes[level] = re.compile(regex_strings[level])
@@ -98,7 +82,7 @@ def _regex_match(line, regexes):
def _build_str_matchers():
match_strings = {}
for level in EVILS:
- match_strings[level] = ['.'.join([log, level]) for log in LOG_VARS]
+ match_strings[level] = [".".join([log, level]) for log in LOG_VARS]
return match_strings
@@ -107,8 +91,10 @@ def _get_log_count_dict():
def _alert(fil, lines, logs, logs_level):
- print('WARNING: Too many logs!!!: File: %s, total lines: %d, log lines: %d, percent: %f, '
- 'logs: %s' % (fil, lines, logs, float(logs) / lines * 100, logs_level))
+ print(
+ "WARNING: Too many logs!!!: File: %s, total lines: %d, log lines: %d, percent: %f, "
+ "logs: %s" % (fil, lines, logs, float(logs) / lines * 100, logs_level)
+ )
def _match(line, match_strings):
@@ -117,7 +103,7 @@ def _match(line, match_strings):
if line.startswith(match_string):
# print('Line: %s, match: %s' % (line, match_string))
return True, level, line
- return False, 'UNKNOWN', line
+ return False, "UNKNOWN", line
def _detect_log_lines(fil, matchers):
@@ -148,23 +134,45 @@ def _post_process(file_dir):
if total_log_count > 0:
if float(total_log_count) / lines * 100 > LOG_ALERT_PERCENT:
if file_dir in fil:
- fil = fil[len(file_dir) + 1:]
- alerts.append([fil, lines, total_log_count, float(total_log_count) / lines * 100,
- log_lines_count_level['audit'],
- log_lines_count_level['exception'],
- log_lines_count_level['error'],
- log_lines_count_level['warning'],
- log_lines_count_level['info'],
- log_lines_count_level['debug']])
+ fil = fil[len(file_dir) + 1 :]
+ alerts.append(
+ [
+ fil,
+ lines,
+ total_log_count,
+ float(total_log_count) / lines * 100,
+ log_lines_count_level["audit"],
+ log_lines_count_level["exception"],
+ log_lines_count_level["error"],
+ log_lines_count_level["warning"],
+ log_lines_count_level["info"],
+ log_lines_count_level["debug"],
+ ]
+ )
# sort by percent
alerts.sort(key=lambda alert: alert[3], reverse=True)
- print(tabulate(alerts, headers=['File', 'Lines', 'Logs', 'Percent', 'adt', 'exc', 'err', 'wrn',
- 'inf', 'dbg']))
+ print(
+ tabulate(
+ alerts,
+ headers=[
+ "File",
+ "Lines",
+ "Logs",
+ "Percent",
+ "adt",
+ "exc",
+ "err",
+ "wrn",
+ "inf",
+ "dbg",
+ ],
+ )
+ )
def main(args):
params = _parse_args(args)
- file_dir = params.get('dir', os.getcwd())
+ file_dir = params.get("dir", os.getcwd())
files = _get_files(file_dir)
matchers = _build_str_matchers()
for f in files:
@@ -172,5 +180,5 @@ def main(args):
_post_process(file_dir)
-if __name__ == '__main__':
+if __name__ == "__main__":
main(sys.argv)
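
The collapsed EVILS and LOG_VARS lists feed _build_regex, which cross-joins them into one alternation pattern per log level. Spelled out on toy inputs:

    import re

    LOG_VARS = ["LOG", "log", "logger"]
    level = "debug"
    # Produces r"LOG\.debug|log\.debug|logger\.debug"
    pattern = "|".join([r"\.".join([log, level]) for log in LOG_VARS])
    regex = re.compile(pattern)

    print(regex.search('LOG.debug("boot")') is not None)          # True
    print(regex.search('print("no logging here")') is not None)   # False
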
diff --git a/tools/migrate_messaging_setup.py b/tools/migrate_messaging_setup.py
index 095af26e0d..3fea8cab83 100755
--- a/tools/migrate_messaging_setup.py
+++ b/tools/migrate_messaging_setup.py
@@ -36,11 +36,13 @@ class Migrate_0_13_x_to_1_1_0(object):
# changes or changes in durability properties.
OLD_QS = [
# Name changed in 1.1
- reactor.get_trigger_cud_queue('st2.trigger.watch.timers', routing_key='#'),
+ reactor.get_trigger_cud_queue("st2.trigger.watch.timers", routing_key="#"),
# Split to multiple queues in 1.1
- reactor.get_trigger_cud_queue('st2.trigger.watch.sensorwrapper', routing_key='#'),
+ reactor.get_trigger_cud_queue(
+ "st2.trigger.watch.sensorwrapper", routing_key="#"
+ ),
# Name changed in 1.1
- reactor.get_trigger_cud_queue('st2.trigger.watch.webhooks', routing_key='#')
+ reactor.get_trigger_cud_queue("st2.trigger.watch.webhooks", routing_key="#"),
]
def migrate(self):
@@ -53,7 +55,7 @@ def _cleanup_old_queues(self):
try:
bound_q.delete()
except:
- print('Failed to delete %s.' % q.name)
+ print("Failed to delete %s." % q.name)
traceback.print_exc()
@@ -62,10 +64,10 @@ def main():
migrator = Migrate_0_13_x_to_1_1_0()
migrator.migrate()
except:
- print('Messaging setup migration failed.')
+ print("Messaging setup migration failed.")
traceback.print_exc()
-if __name__ == '__main__':
+if __name__ == "__main__":
config.parse_args(args={})
main()
diff --git a/tools/migrate_rules_to_include_pack.py b/tools/migrate_rules_to_include_pack.py
index 8afd3faa15..1acdd26383 100755
--- a/tools/migrate_rules_to_include_pack.py
+++ b/tools/migrate_rules_to_include_pack.py
@@ -31,8 +31,11 @@
class Migration(object):
- class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin,
- stormbase.ContentPackResourceMixin):
+ class RuleDB(
+ stormbase.StormFoundationDB,
+ stormbase.TagsMixin,
+ stormbase.ContentPackResourceMixin,
+ ):
"""Specifies the action to invoke on the occurrence of a Trigger. It
also includes the transformation to perform to match the impedance
between the payload of a TriggerInstance and the input of an action.
@@ -43,22 +46,23 @@ class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin,
status: enabled or disabled. If disabled, occurrence of the trigger
does not lead to execution of an action and vice-versa.
"""
+
name = me.StringField(required=True)
ref = me.StringField(required=True)
description = me.StringField()
pack = me.StringField(
- required=False,
- help_text='Name of the content pack.',
- unique_with='name')
+ required=False, help_text="Name of the content pack.", unique_with="name"
+ )
trigger = me.StringField()
criteria = stormbase.EscapedDictField()
action = me.EmbeddedDocumentField(ActionExecutionSpecDB)
- enabled = me.BooleanField(required=True, default=True,
- help_text=u'Flag indicating whether the rule is enabled.')
+ enabled = me.BooleanField(
+ required=True,
+ default=True,
+ help_text="Flag indicating whether the rule is enabled.",
+ )
- meta = {
- 'indexes': stormbase.TagsMixin.get_indexes()
- }
+ meta = {"indexes": stormbase.TagsMixin.get_indexes()}
# specialized access objects
@@ -76,15 +80,17 @@ class RuleDB(stormbase.StormBaseDB, stormbase.TagsMixin):
status: enabled or disabled. If disabled, occurrence of the trigger
does not lead to execution of an action and vice-versa.
"""
+
trigger = me.StringField()
criteria = stormbase.EscapedDictField()
action = me.EmbeddedDocumentField(ActionExecutionSpecDB)
- enabled = me.BooleanField(required=True, default=True,
- help_text=u'Flag indicating whether the rule is enabled.')
+ enabled = me.BooleanField(
+ required=True,
+ default=True,
+ help_text="Flag indicating whether the rule is enabled.",
+ )
- meta = {
- 'indexes': stormbase.TagsMixin.get_indexes()
- }
+ meta = {"indexes": stormbase.TagsMixin.get_indexes()}
rule_access_without_pack = MongoDBAccess(RuleDB)
@@ -100,7 +106,7 @@ def _get_impl(cls):
@classmethod
def _get_by_object(cls, object):
# For Rule name is unique.
- name = getattr(object, 'name', '')
+ name = getattr(object, "name", "")
return cls.get_by_name(name)
@@ -126,13 +132,14 @@ def migrate_rules():
action=rule.action,
enabled=rule.enabled,
pack=DEFAULT_PACK_NAME,
- ref=ResourceReference.to_string_reference(pack=DEFAULT_PACK_NAME,
- name=rule.name)
+ ref=ResourceReference.to_string_reference(
+ pack=DEFAULT_PACK_NAME, name=rule.name
+ ),
)
- print('Migrating rule: %s to rule: %s' % (rule.name, rule_with_pack.ref))
+ print("Migrating rule: %s to rule: %s" % (rule.name, rule_with_pack.ref))
RuleWithPack.add_or_update(rule_with_pack)
except Exception as e:
- print('Migration failed. %s' % six.text_type(e))
+ print("Migration failed. %s" % six.text_type(e))
def main():
@@ -148,5 +155,5 @@ def main():
db_teardown()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/tools/migrate_triggers_to_include_ref_count.py b/tools/migrate_triggers_to_include_ref_count.py
index 3e8f1b79f0..af98a00a07 100755
--- a/tools/migrate_triggers_to_include_ref_count.py
+++ b/tools/migrate_triggers_to_include_ref_count.py
@@ -27,7 +27,6 @@
class TriggerMigrator(object):
-
def _get_trigger_with_parameters(self):
"""
All TriggerDB that has a parameter.
@@ -38,7 +37,7 @@ def _get_rules_for_trigger(self, trigger_ref):
"""
All rules that reference the supplied trigger_ref.
"""
- return Rule.get_all(**{'trigger': trigger_ref})
+ return Rule.get_all(**{"trigger": trigger_ref})
def _update_trigger_ref_count(self, trigger_db, ref_count):
"""
@@ -56,7 +55,7 @@ def migrate(self):
trigger_ref = trigger_db.get_reference().ref
rules = self._get_rules_for_trigger(trigger_ref=trigger_ref)
ref_count = len(rules)
- print('Updating Trigger %s to ref_count %s' % (trigger_ref, ref_count))
+ print("Updating Trigger %s to ref_count %s" % (trigger_ref, ref_count))
self._update_trigger_ref_count(trigger_db=trigger_db, ref_count=ref_count)
@@ -76,5 +75,5 @@ def main():
teartown()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/tools/queue_consumer.py b/tools/queue_consumer.py
index bf19cbf1d5..69de164fbd 100755
--- a/tools/queue_consumer.py
+++ b/tools/queue_consumer.py
@@ -37,28 +37,31 @@ def __init__(self, connection, queue):
self.queue = queue
def get_consumers(self, Consumer, channel):
- return [Consumer(queues=[self.queue],
- accept=['pickle'],
- callbacks=[self.process_task])]
+ return [
+ Consumer(
+ queues=[self.queue], accept=["pickle"], callbacks=[self.process_task]
+ )
+ ]
def process_task(self, body, message):
- print('===================================================')
- print('Received message')
- print('message.properties:')
+ print("===================================================")
+ print("Received message")
+ print("message.properties:")
pprint(message.properties)
- print('message.delivery_info:')
+ print("message.delivery_info:")
pprint(message.delivery_info)
- print('body:')
+ print("body:")
pprint(body)
- print('===================================================')
+ print("===================================================")
message.ack()
-def main(queue, exchange, routing_key='#'):
- exchange = Exchange(exchange, type='topic')
- queue = Queue(name=queue, exchange=exchange, routing_key=routing_key,
- auto_delete=True)
+def main(queue, exchange, routing_key="#"):
+ exchange = Exchange(exchange, type="topic")
+ queue = Queue(
+ name=queue, exchange=exchange, routing_key=routing_key, auto_delete=True
+ )
with transport_utils.get_connection() as connection:
connection.connect()
@@ -66,13 +69,11 @@ def main(queue, exchange, routing_key='#'):
watcher.run()
-if __name__ == '__main__':
+if __name__ == "__main__":
config.parse_args(args={})
- parser = argparse.ArgumentParser(description='Queue consumer')
- parser.add_argument('--exchange', required=True,
- help='Exchange to listen on')
- parser.add_argument('--routing-key', default='#',
- help='Routing key')
+ parser = argparse.ArgumentParser(description="Queue consumer")
+ parser.add_argument("--exchange", required=True, help="Exchange to listen on")
+ parser.add_argument("--routing-key", default="#", help="Routing key")
args = parser.parse_args()
queue_name = args.exchange + str(random.randint(1, 10000))
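
The consumer above is a standard kombu ConsumerMixin; st2 only supplies the connection via its transport_utils factory. A minimal direct-kombu version (the AMQP URL is an assumption, st2 derives it from config):

    from kombu import Connection, Exchange, Queue
    from kombu.mixins import ConsumerMixin

    class Watcher(ConsumerMixin):
        def __init__(self, connection, queue):
            self.connection = connection  # ConsumerMixin expects this attribute
            self.queue = queue

        def get_consumers(self, Consumer, channel):
            return [
                Consumer(queues=[self.queue], accept=["json"], callbacks=[self.process_task])
            ]

        def process_task(self, body, message):
            print(body)
            message.ack()

    exchange = Exchange("demo", type="topic")
    queue = Queue(name="demo.q", exchange=exchange, routing_key="#", auto_delete=True)
    with Connection("amqp://guest:guest@localhost//") as connection:
        Watcher(connection, queue).run()
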
diff --git a/tools/queue_producer.py b/tools/queue_producer.py
index c936088676..01476a26b8 100755
--- a/tools/queue_producer.py
+++ b/tools/queue_producer.py
@@ -30,22 +30,20 @@
def main(exchange, routing_key, payload):
- exchange = Exchange(exchange, type='topic')
+ exchange = Exchange(exchange, type="topic")
publisher = PoolPublisher()
publisher.publish(payload=payload, exchange=exchange, routing_key=routing_key)
eventlet.sleep(0.5)
-if __name__ == '__main__':
+if __name__ == "__main__":
config.parse_args(args={})
- parser = argparse.ArgumentParser(description='Queue producer')
- parser.add_argument('--exchange', required=True,
- help='Exchange to publish the message to')
- parser.add_argument('--routing-key', required=True,
- help='Routing key to use')
- parser.add_argument('--payload', required=True,
- help='Message payload')
+ parser = argparse.ArgumentParser(description="Queue producer")
+ parser.add_argument(
+ "--exchange", required=True, help="Exchange to publish the message to"
+ )
+ parser.add_argument("--routing-key", required=True, help="Routing key to use")
+ parser.add_argument("--payload", required=True, help="Message payload")
args = parser.parse_args()
- main(exchange=args.exchange, routing_key=args.routing_key,
- payload=args.payload)
+ main(exchange=args.exchange, routing_key=args.routing_key, payload=args.payload)
diff --git a/tools/st2-analyze-links.py b/tools/st2-analyze-links.py
index 4daeeafa44..f66c158dea 100644
--- a/tools/st2-analyze-links.py
+++ b/tools/st2-analyze-links.py
@@ -44,8 +44,10 @@
try:
from graphviz import Digraph
except ImportError:
- msg = ('Missing "graphviz" dependency. You can install it using pip: \n'
- 'pip install graphviz')
+ msg = (
+ 'Missing "graphviz" dependency. You can install it using pip: \n'
+ "pip install graphviz"
+ )
raise ImportError(msg)
@@ -59,18 +61,20 @@ def do_register_cli_opts(opts, ignore_errors=False):
class RuleLink(object):
-
def __init__(self, source_action_ref, rule_ref, dest_action_ref):
self._source_action_ref = source_action_ref
self._rule_ref = rule_ref
self._dest_action_ref = dest_action_ref
def __str__(self):
- return '(%s -> %s -> %s)' % (self._source_action_ref, self._rule_ref, self._dest_action_ref)
+ return "(%s -> %s -> %s)" % (
+ self._source_action_ref,
+ self._rule_ref,
+ self._dest_action_ref,
+ )
class LinksAnalyzer(object):
-
def __init__(self):
self._rule_link_by_action_ref = {}
self._rules = {}
@@ -81,25 +85,30 @@ def analyze(self, root_action_ref, link_tigger_ref):
for rule in rules:
source_action_ref = self._get_source_action_ref(rule)
if not source_action_ref:
- print('No source_action_ref for rule %s' % rule.ref)
+ print("No source_action_ref for rule %s" % rule.ref)
continue
rule_links = self._rules.get(source_action_ref, None)
if rule_links is None:
rule_links = []
self._rules[source_action_ref] = rule_links
- rule_links.append(RuleLink(source_action_ref=source_action_ref, rule_ref=rule.ref,
- dest_action_ref=rule.action.ref))
+ rule_links.append(
+ RuleLink(
+ source_action_ref=source_action_ref,
+ rule_ref=rule.ref,
+ dest_action_ref=rule.action.ref,
+ )
+ )
analyzed = self._do_analyze(action_ref=root_action_ref)
for (depth, rule_link) in analyzed:
- print('%s%s' % (' ' * depth, rule_link))
+ print("%s%s" % (" " * depth, rule_link))
return analyzed
def _get_source_action_ref(self, rule):
criteria = rule.criteria
- source_action_ref = criteria.get('trigger.action_name', None)
+ source_action_ref = criteria.get("trigger.action_name", None)
if not source_action_ref:
- source_action_ref = criteria.get('trigger.action_ref', None)
- return source_action_ref['pattern'] if source_action_ref else None
+ source_action_ref = criteria.get("trigger.action_ref", None)
+ return source_action_ref["pattern"] if source_action_ref else None
def _do_analyze(self, action_ref, rule_links=None, processed=None, depth=0):
if processed is None:
@@ -111,24 +120,32 @@ def _do_analyze(self, action_ref, rule_links=None, processed=None, depth=0):
rule_links.append((depth, rule_link))
if rule_link._dest_action_ref in processed:
continue
- self._do_analyze(rule_link._dest_action_ref, rule_links=rule_links,
- processed=processed, depth=depth + 1)
+ self._do_analyze(
+ rule_link._dest_action_ref,
+ rule_links=rule_links,
+ processed=processed,
+ depth=depth + 1,
+ )
return rule_links
class Grapher(object):
def generate_graph(self, rule_links, out_file):
- graph_label = 'Rule based visualizer'
+ graph_label = "Rule based visualizer"
graph_attr = {
- 'rankdir': 'TD',
- 'labelloc': 't',
- 'fontsize': '15',
- 'label': graph_label
+ "rankdir": "TD",
+ "labelloc": "t",
+ "fontsize": "15",
+ "label": graph_label,
}
node_attr = {}
- dot = Digraph(comment='Rule based links visualization',
- node_attr=node_attr, graph_attr=graph_attr, format='png')
+ dot = Digraph(
+ comment="Rule based links visualization",
+ node_attr=node_attr,
+ graph_attr=graph_attr,
+ format="png",
+ )
nodes = set()
for _, rule_link in rule_links:
@@ -139,10 +156,14 @@ def generate_graph(self, rule_links, out_file):
if rule_link._dest_action_ref not in nodes:
nodes.add(rule_link._dest_action_ref)
dot.node(rule_link._dest_action_ref, rule_link._dest_action_ref)
- dot.edge(rule_link._source_action_ref, rule_link._dest_action_ref, constraint='true',
- label=rule_link._rule_ref)
+ dot.edge(
+ rule_link._source_action_ref,
+ rule_link._dest_action_ref,
+ constraint="true",
+ label=rule_link._rule_ref,
+ )
output_path = os.path.join(os.getcwd(), out_file)
- dot.format = 'png'
+ dot.format = "png"
dot.render(output_path)
@@ -150,11 +171,13 @@ def main():
monkey_patch()
cli_opts = [
- cfg.StrOpt('action_ref', default=None,
- help='Root action to begin analysis.'),
- cfg.StrOpt('link_trigger_ref', default='core.st2.generic.actiontrigger',
- help='Root action to begin analysis.'),
- cfg.StrOpt('out_file', default='pipeline')
+ cfg.StrOpt("action_ref", default=None, help="Root action to begin analysis."),
+ cfg.StrOpt(
+ "link_trigger_ref",
+ default="core.st2.generic.actiontrigger",
+ help="Root action to begin analysis.",
+ ),
+ cfg.StrOpt("out_file", default="pipeline"),
]
do_register_cli_opts(cli_opts)
config.parse_args()
@@ -163,5 +186,5 @@ def main():
Grapher().generate_graph(rule_links, cfg.CONF.out_file)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
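
The Grapher changes reduce to the usual graphviz Digraph workflow: add each action ref once as a node, then one labeled edge per rule link, and render to PNG. In isolation (node and edge names invented; rendering requires the Graphviz binaries on PATH):

    from graphviz import Digraph

    graph_attr = {"rankdir": "TD", "labelloc": "t", "fontsize": "15", "label": "demo"}
    dot = Digraph(comment="Rule based links visualization", graph_attr=graph_attr, format="png")

    dot.node("pack.action_a", "pack.action_a")
    dot.node("pack.action_b", "pack.action_b")
    dot.edge("pack.action_a", "pack.action_b", constraint="true", label="pack.rule_1")

    dot.render("pipeline")  # writes "pipeline" (source) and "pipeline.png"
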
diff --git a/tools/st2-inject-trigger-instances.py b/tools/st2-inject-trigger-instances.py
index a20ba8bcbb..79b0f18e25 100755
--- a/tools/st2-inject-trigger-instances.py
+++ b/tools/st2-inject-trigger-instances.py
@@ -49,7 +49,9 @@ def do_register_cli_opts(opts, ignore_errors=False):
raise
-def _inject_instances(trigger, rate_per_trigger, duration, payload=None, max_throughput=False):
+def _inject_instances(
+ trigger, rate_per_trigger, duration, payload=None, max_throughput=False
+):
payload = payload or {}
start = date_utils.get_datetime_utc_now()
@@ -72,37 +74,54 @@ def _inject_instances(trigger, rate_per_trigger, duration, payload=None, max_thr
actual_rate = int(count / elapsed)
- print('%s: Emitted %d triggers in %d seconds (actual rate=%s triggers / second)' %
- (trigger, count, elapsed, actual_rate))
+ print(
+ "%s: Emitted %d triggers in %d seconds (actual rate=%s triggers / second)"
+ % (trigger, count, elapsed, actual_rate)
+ )
# NOTE: Due to the overhead of the dispatcher.dispatch call, we allow for 10% deviation from
# the requested rate before warning.
if rate_per_trigger and (actual_rate < (rate_per_trigger * 0.9)):
- print('')
- print('Warning, requested rate was %s triggers / second, but only achieved %s '
- 'triggers / second' % (rate_per_trigger, actual_rate))
- print('Too increase the throuput you will likely need to run multiple instances of '
- 'this script in parallel.')
+ print("")
+ print(
+ "Warning, requested rate was %s triggers / second, but only achieved %s "
+ "triggers / second" % (rate_per_trigger, actual_rate)
+ )
+ print(
+ "Too increase the throuput you will likely need to run multiple instances of "
+ "this script in parallel."
+ )
def main():
monkey_patch()
cli_opts = [
- cfg.IntOpt('rate', default=100,
- help='Rate of trigger injection measured in instances in per sec.' +
- ' Assumes a default exponential distribution in time so arrival is poisson.'),
- cfg.ListOpt('triggers', required=False,
- help='List of triggers for which instances should be fired.' +
- ' Uniform distribution will be followed if there is more than one' +
- 'trigger.'),
- cfg.StrOpt('schema_file', default=None,
- help='Path to schema file defining trigger and payload.'),
- cfg.IntOpt('duration', default=60,
- help='Duration of stress test in seconds.'),
- cfg.BoolOpt('max-throughput', default=False,
- help='If True, "rate" argument will be ignored and this script will try to '
- 'saturize the CPU and achieve max utilization.')
+ cfg.IntOpt(
+ "rate",
+ default=100,
+ help="Rate of trigger injection measured in instances in per sec."
+ + " Assumes a default exponential distribution in time so arrival is poisson.",
+ ),
+ cfg.ListOpt(
+ "triggers",
+ required=False,
+ help="List of triggers for which instances should be fired."
+ + " Uniform distribution will be followed if there is more than one"
+ + "trigger.",
+ ),
+ cfg.StrOpt(
+ "schema_file",
+ default=None,
+ help="Path to schema file defining trigger and payload.",
+ ),
+ cfg.IntOpt("duration", default=60, help="Duration of stress test in seconds."),
+ cfg.BoolOpt(
+ "max-throughput",
+ default=False,
+ help='If True, "rate" argument will be ignored and this script will try to '
+ "saturize the CPU and achieve max utilization.",
+ ),
]
do_register_cli_opts(cli_opts)
config.parse_args()
@@ -112,15 +131,20 @@ def main():
trigger_payload_schema = {}
if not triggers:
- if (cfg.CONF.schema_file is None or cfg.CONF.schema_file == '' or
- not os.path.exists(cfg.CONF.schema_file)):
- print('Either "triggers" need to be provided or a schema file containing' +
- ' triggers should be provided.')
+ if (
+ cfg.CONF.schema_file is None
+ or cfg.CONF.schema_file == ""
+ or not os.path.exists(cfg.CONF.schema_file)
+ ):
+ print(
+ 'Either "triggers" needs to be provided or a schema file containing'
+ + " triggers should be provided."
+ )
return
with open(cfg.CONF.schema_file) as fd:
trigger_payload_schema = yaml.safe_load(fd)
triggers = list(trigger_payload_schema.keys())
- print('Triggers=%s' % triggers)
+ print("Triggers=%s" % triggers)
rate = cfg.CONF.rate
rate_per_trigger = int(rate / len(triggers))
@@ -135,11 +159,17 @@ def main():
for trigger in triggers:
payload = trigger_payload_schema.get(trigger, {})
- dispatcher_pool.spawn(_inject_instances, trigger, rate_per_trigger, duration,
- payload=payload, max_throughput=max_throughput)
+ dispatcher_pool.spawn(
+ _inject_instances,
+ trigger,
+ rate_per_trigger,
+ duration,
+ payload=payload,
+ max_throughput=max_throughput,
+ )
eventlet.sleep(random.uniform(0, 1))
dispatcher_pool.waitall()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
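
Stripped of eventlet and the trigger dispatcher, the injection loop is plain rate pacing: fire, count, sleep 1/rate, then compare the achieved rate against the requested one. A sketch of that core using the standard time module:

    import time

    def inject(fire, rate_per_sec, duration):
        # Pace calls to roughly rate_per_sec for duration seconds.
        start = time.time()
        count = 0
        while time.time() - start < duration:
            fire()
            count += 1
            time.sleep(1.0 / rate_per_sec)
        elapsed = time.time() - start
        actual_rate = int(count / elapsed)
        print(
            "Emitted %d triggers in %d seconds (actual rate=%s triggers / second)"
            % (count, elapsed, actual_rate)
        )
        if actual_rate < rate_per_sec * 0.9:
            print("Warning: achieved %s/s against a requested %s/s" % (actual_rate, rate_per_sec))

    inject(lambda: None, rate_per_sec=50, duration=2)
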
diff --git a/tools/visualize_action_chain.py b/tools/visualize_action_chain.py
index 9981bd956c..c6742c460d 100755
--- a/tools/visualize_action_chain.py
+++ b/tools/visualize_action_chain.py
@@ -26,8 +26,10 @@
try:
from graphviz import Digraph
except ImportError:
- msg = ('Missing "graphviz" dependency. You can install it using pip: \n'
- 'pip install graphviz')
+ msg = (
+ 'Missing "graphviz" dependency. You can install it using pip: \n'
+ "pip install graphviz"
+ )
raise ImportError(msg)
from st2common.content.loader import MetaLoader
@@ -41,25 +43,29 @@ def main(metadata_path, output_path, print_source=False):
meta_loader = MetaLoader()
data = meta_loader.load(metadata_path)
- action_name = data['name']
- entry_point = data['entry_point']
+ action_name = data["name"]
+ entry_point = data["entry_point"]
workflow_metadata_path = os.path.join(metadata_dir, entry_point)
chainspec = meta_loader.load(workflow_metadata_path)
- chain_holder = ChainHolder(chainspec, 'workflow')
+ chain_holder = ChainHolder(chainspec, "workflow")
- graph_label = '%s action-chain workflow visualization' % (action_name)
+ graph_label = "%s action-chain workflow visualization" % (action_name)
graph_attr = {
- 'rankdir': 'TD',
- 'labelloc': 't',
- 'fontsize': '15',
- 'label': graph_label
+ "rankdir": "TD",
+ "labelloc": "t",
+ "fontsize": "15",
+ "label": graph_label,
}
node_attr = {}
- dot = Digraph(comment='Action chain work-flow visualization',
- node_attr=node_attr, graph_attr=graph_attr, format='png')
+ dot = Digraph(
+ comment="Action chain work-flow visualization",
+ node_attr=node_attr,
+ graph_attr=graph_attr,
+ format="png",
+ )
# dot.body.extend(['rankdir=TD', 'size="10,5"'])
# Add all nodes
@@ -74,23 +80,35 @@ def main(metadata_path, output_path, print_source=False):
nodes = [node]
while nodes:
previous_node = nodes.pop()
- success_node = chain_holder.get_next_node(curr_node_name=previous_node.name,
- condition='on-success')
- failure_node = chain_holder.get_next_node(curr_node_name=previous_node.name,
- condition='on-failure')
+ success_node = chain_holder.get_next_node(
+ curr_node_name=previous_node.name, condition="on-success"
+ )
+ failure_node = chain_holder.get_next_node(
+ curr_node_name=previous_node.name, condition="on-failure"
+ )
# Add success node (if any)
if success_node:
- dot.edge(previous_node.name, success_node.name, constraint='true',
- color='green', label='on success')
+ dot.edge(
+ previous_node.name,
+ success_node.name,
+ constraint="true",
+ color="green",
+ label="on success",
+ )
if success_node.name not in processed_nodes:
nodes.append(success_node)
processed_nodes.add(success_node.name)
# Add failure node (if any)
if failure_node:
- dot.edge(previous_node.name, failure_node.name, constraint='true',
- color='red', label='on failure')
+ dot.edge(
+ previous_node.name,
+ failure_node.name,
+ constraint="true",
+ color="red",
+ label="on failure",
+ )
if failure_node.name not in processed_nodes:
nodes.append(failure_node)
processed_nodes.add(failure_node.name)
@@ -103,21 +121,36 @@ def main(metadata_path, output_path, print_source=False):
else:
output_path = output_path or os.path.join(os.getcwd(), action_name)
- dot.format = 'png'
+ dot.format = "png"
dot.render(output_path)
- print('Graph saved at %s' % (output_path + '.png'))
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser(description='Action chain visualization')
- parser.add_argument('--metadata-path', action='store', required=True,
- help='Path to the workflow action metadata file')
- parser.add_argument('--output-path', action='store', required=False,
- help='Output directory for the generated image')
- parser.add_argument('--print-source', action='store_true', default=False,
- help='Print graphviz source code to the stdout')
+ print("Graph saved at %s" % (output_path + ".png"))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Action chain visualization")
+ parser.add_argument(
+ "--metadata-path",
+ action="store",
+ required=True,
+ help="Path to the workflow action metadata file",
+ )
+ parser.add_argument(
+ "--output-path",
+ action="store",
+ required=False,
+ help="Output directory for the generated image",
+ )
+ parser.add_argument(
+ "--print-source",
+ action="store_true",
+ default=False,
+ help="Print graphviz source code to the stdout",
+ )
args = parser.parse_args()
- main(metadata_path=args.metadata_path, output_path=args.output_path,
- print_source=args.print_source)
+ main(
+ metadata_path=args.metadata_path,
+ output_path=args.output_path,
+ print_source=args.print_source,
+ )
diff --git a/tox.ini b/tox.ini
index 451ceee8e1..de40b85878 100644
--- a/tox.ini
+++ b/tox.ini
@@ -71,7 +71,7 @@ commands =
[testenv:py36-integration]
basepython = python3.6
-setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner
+setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner
VIRTUALENV_DIR = {envdir}
passenv = NOSE_WITH_TIMER TRAVIS ST2_CI
install_command = pip install -U --force-reinstall {opts} {packages}